Commit b310562b authored by Serge S. Koval's avatar Serge S. Koval

Merge pull request #189 from ArtemSerga/_choices

Add full support for SQLAlchemy.Enum field
parents 2bdb4fba 3aeb0cea
from flask.ext.admin.babel import gettext
import warnings
from flask.ext.admin.babel import gettext
from flask.ext.admin.model import filters
from flask.ext.admin.contrib.sqlamodel import tools
......@@ -90,30 +91,45 @@ class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter):
class FilterConverter(filters.BaseFilterConverter):
strings = (FilterEqual, FilterNotEqual, FilterLike, FilterNotLike)
numeric = (FilterEqual, FilterNotEqual, FilterGreater, FilterSmaller)
bool = (BooleanEqualFilter, BooleanNotEqualFilter)
enum = (FilterEqual, FilterNotEqual)
def convert(self, type_name, column, name):
def convert(self, type_name, column, name, **kwargs):
if type_name in self.converters:
return self.converters[type_name](column, name)
return self.converters[type_name](column, name, **kwargs)
return None
@filters.convert('String', 'Unicode', 'Text', 'UnicodeText')
def conv_string(self, column, name):
return [f(column, name) for f in self.strings]
def conv_string(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.strings]
@filters.convert('Boolean')
def conv_bool(self, column, name):
return [BooleanEqualFilter(column, name),
BooleanNotEqualFilter(column, name)]
def conv_bool(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.bool]
@filters.convert('Integer', 'SmallInteger', 'Numeric', 'Float')
def conv_int(self, column, name):
return [f(column, name) for f in self.numeric]
def conv_int(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.numeric]
@filters.convert('Date')
def conv_date(self, column, name):
return [f(column, name, data_type='datepicker') for f in self.numeric]
def conv_date(self, column, name, **kwargs):
return [f(column, name, data_type='datepicker', **kwargs) for f in self.numeric]
@filters.convert('DateTime')
def conv_datetime(self, column, name):
return [f(column, name, data_type='datetimepicker') for f in self.numeric]
def conv_datetime(self, column, name, **kwargs):
return [f(column, name, data_type='datetimepicker', **kwargs) for f in self.numeric]
@filters.convert('Enum', 'ENUM')
def conv_enum(self, column, name, options=None, **kwargs):
if not options:
warnings.warn(
'You can make SQ field with `Enum` type '
'more human readable in the form by using '
'`column_choices` in your `ModelView`'
)
options = [
(v, v)
for v in column.type.enums
]
return [f(column, name, options, **kwargs) for f in self.enum]
......@@ -2,6 +2,7 @@ from wtforms import fields, validators
from sqlalchemy import Boolean, Column
from flask.ext.admin import form
from flask.ext.admin.form import Select2Field
from flask.ext.admin.tools import get_property
from flask.ext.admin.model.form import (converts, ModelConverterBase,
InlineFormAdmin, InlineModelConverterBase)
......@@ -186,6 +187,16 @@ class AdminModelConverter(ModelConverterBase):
if override:
return override(**kwargs)
# Check choices
if mapper.class_ == self.view.model and self.view.form_choices:
choices = self.view.form_choices.get(column.key)
if choices:
return Select2Field(
choices=choices,
allow_blank=column.nullable,
**kwargs
)
# Run converter
converter = self.get_converter(column)
......
......@@ -218,6 +218,18 @@ class ModelView(BaseModelView):
column_type_formatters = DEFAULT_FORMATTERS
form_choices = None
"""
Map choices to form fields
Example::
class MyModelView(BaseModelView):
form_choices = {'my_form_field': [
('db_value', 'display_value'),
]
"""
def __init__(self, model, session,
name=None, category=None, endpoint=None, url=None):
"""
......@@ -243,6 +255,9 @@ class ModelView(BaseModelView):
self._filter_joins = dict()
if self.form_choices is None:
self.form_choices = {}
super(ModelView, self).__init__(model, name, category, endpoint, url)
# Primary key
......@@ -453,9 +468,11 @@ class ModelView(BaseModelView):
column = columns[0]
if self._need_join(column.table):
visible_name = '%s / %s' % (self.get_column_name(column.table.name),
self.get_column_name(column.name))
if self._need_join(column.table) and name not in self.column_labels:
visible_name = '%s / %s' % (
self.get_column_name(column.table.name),
self.get_column_name(column.name)
)
else:
if not isinstance(name, basestring):
visible_name = self.get_column_name(name.property.key)
......@@ -463,16 +480,19 @@ class ModelView(BaseModelView):
visible_name = self.get_column_name(name)
type_name = type(column.type).__name__
flt = self.filter_converter.convert(type_name,
column,
visible_name)
if flt:
# If there's relation to other table, do it
if join_tables:
self._filter_joins[column.table.name] = join_tables
elif self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table]
if join_tables:
self._filter_joins[column.table.name] = join_tables
flt = self.filter_converter.convert(
type_name,
column,
visible_name,
options=self.column_choices.get(name),
)
if flt and not join_tables and self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table]
return flt
......
......@@ -110,6 +110,45 @@ class Select2Field(fields.SelectField):
"""
widget = Select2Widget()
def __init__(self, label=None, validators=None, coerce=unicode,
choices=None, allow_blank=False, blank_text=None, **kwargs):
super(Select2Field, self).__init__(
label, validators, coerce, choices, **kwargs
)
self.allow_blank = allow_blank
self.blank_text = blank_text or ' '
def iter_choices(self):
if self.allow_blank:
yield (u'__None', self.blank_text, self.data is None)
for value, label in self.choices:
yield (value, label, self.coerce(value) == self.data)
def process_data(self, value):
if value is None:
self.data = None
else:
try:
self.data = self.coerce(value)
except (ValueError, TypeError):
self.data = None
def process_formdata(self, valuelist):
if valuelist:
if valuelist[0] == '__None':
self.data = None
else:
try:
self.data = self.coerce(valuelist[0])
except ValueError:
raise ValueError(self.gettext(u'Invalid Choice: could not coerce'))
def pre_validate(self, form):
if self.allow_blank and self.data is None:
return
super(Select2Field, self).pre_validate(form)
class DatePickerWidget(widgets.TextInput):
"""
......
......@@ -180,6 +180,20 @@ class BaseModelView(BaseView, ActionsMixin):
class MyModelView(BaseModelView):
column_searchable_list = ('name', 'email')
"""
column_choices = None
"""
Map choices to columns in list view
Example::
class MyModelView(BaseModelView):
column_choices = {
'my_column': [
('db_value', 'display_value'),
]
}
"""
column_filters = None
"""
......@@ -339,6 +353,10 @@ class BaseModelView(BaseView, ActionsMixin):
self._list_columns = self.get_list_columns()
self._sortable_columns = self.get_sortable_columns()
# Labels
if self.column_labels is None:
self.column_labels = {}
# Forms
self._create_form_class = self.get_create_form()
self._edit_form_class = self.get_edit_form()
......@@ -349,6 +367,15 @@ class BaseModelView(BaseView, ActionsMixin):
# Search
self._search_supported = self.init_search()
# Choices
if self.column_choices:
self._column_choices_map = dict([
(column, dict(choices))
for column, choices in self.column_choices.items()
])
else:
self.column_choices = self._column_choices_map = dict()
# Filters
self._filters = self.get_filters()
......@@ -499,15 +526,14 @@ class BaseModelView(BaseView, ActionsMixin):
collection = []
for n in self.column_filters:
if not self.is_valid_filter(n):
if self.is_valid_filter(n):
collection.append(n)
else:
flt = self.scaffold_filters(n)
if flt:
collection.extend(flt)
else:
raise Exception('Unsupported filter type %s' % n)
else:
collection.append(n)
return collection
else:
return None
......@@ -794,6 +820,10 @@ class BaseModelView(BaseView, ActionsMixin):
value = rec_getattr(model, name)
choices_map = self._column_choices_map.get(name, {})
if choices_map:
return choices_map.get(value) or value
type_fmt = self.column_type_formatters.get(type(value))
if type_fmt is not None:
value = type_fmt(value)
......
......@@ -33,6 +33,7 @@ def create_models(db):
test3 = db.Column(db.Text)
test4 = db.Column(db.UnicodeText)
bool_field = db.Column(db.Boolean)
enum_field = db.Column(db.Enum('model1_v1', 'model1_v1'), nullable=True)
class Model2(db.Model):
def __init__(self, string_field=None, int_field=None, bool_field=None, model1=None):
......@@ -45,6 +46,9 @@ def create_models(db):
string_field = db.Column(db.String)
int_field = db.Column(db.Integer)
bool_field = db.Column(db.Boolean)
enum_field = db.Column(db.Enum('model2_v1', 'model2_v2'), nullable=True)
# Relation
model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
model1 = db.relationship(Model1)
......@@ -162,8 +166,10 @@ def test_exclude_columns():
Model1, Model2 = create_models(db)
view = CustomModelView(Model1, db.session,
column_exclude_list=['test2', 'test4'])
view = CustomModelView(
Model1, db.session,
column_exclude_list=['test2', 'test4', 'enum_field']
)
admin.add_view(view)
eq_(
......@@ -210,8 +216,10 @@ def test_column_filters():
Model1, Model2 = create_models(db)
view = CustomModelView(Model1, db.session,
column_filters=['test1'])
view = CustomModelView(
Model1, db.session,
column_filters=['test1']
)
admin.add_view(view)
eq_(len(view._filters), 4)
......@@ -258,6 +266,10 @@ def test_column_filters():
(16, 'equals'),
(17, 'not equal'),
],
'Model1 / Enum Field': [
(18, u'equals'),
(19, u'not equal'),
]
})
# Test filter with a dot
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment