Commit 3aeb0cea authored by Artem Serga's avatar Artem Serga

Add full support for SQ.Enum field

parent 2bdb4fba
from flask.ext.admin.babel import gettext
import warnings
from flask.ext.admin.babel import gettext
from flask.ext.admin.model import filters
from flask.ext.admin.contrib.sqlamodel import tools
......@@ -90,30 +91,45 @@ class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter):
class FilterConverter(filters.BaseFilterConverter):
strings = (FilterEqual, FilterNotEqual, FilterLike, FilterNotLike)
numeric = (FilterEqual, FilterNotEqual, FilterGreater, FilterSmaller)
bool = (BooleanEqualFilter, BooleanNotEqualFilter)
enum = (FilterEqual, FilterNotEqual)
def convert(self, type_name, column, name):
def convert(self, type_name, column, name, **kwargs):
if type_name in self.converters:
return self.converters[type_name](column, name)
return self.converters[type_name](column, name, **kwargs)
return None
@filters.convert('String', 'Unicode', 'Text', 'UnicodeText')
def conv_string(self, column, name):
return [f(column, name) for f in self.strings]
def conv_string(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.strings]
@filters.convert('Boolean')
def conv_bool(self, column, name):
return [BooleanEqualFilter(column, name),
BooleanNotEqualFilter(column, name)]
def conv_bool(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.bool]
@filters.convert('Integer', 'SmallInteger', 'Numeric', 'Float')
def conv_int(self, column, name):
return [f(column, name) for f in self.numeric]
def conv_int(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.numeric]
@filters.convert('Date')
def conv_date(self, column, name):
return [f(column, name, data_type='datepicker') for f in self.numeric]
def conv_date(self, column, name, **kwargs):
return [f(column, name, data_type='datepicker', **kwargs) for f in self.numeric]
@filters.convert('DateTime')
def conv_datetime(self, column, name):
return [f(column, name, data_type='datetimepicker') for f in self.numeric]
def conv_datetime(self, column, name, **kwargs):
return [f(column, name, data_type='datetimepicker', **kwargs) for f in self.numeric]
@filters.convert('Enum', 'ENUM')
def conv_enum(self, column, name, options=None, **kwargs):
if not options:
warnings.warn(
'You can make SQ field with `Enum` type '
'more human readable in the form by using '
'`column_choices` in your `ModelView`'
)
options = [
(v, v)
for v in column.type.enums
]
return [f(column, name, options, **kwargs) for f in self.enum]
......@@ -2,6 +2,7 @@ from wtforms import fields, validators
from sqlalchemy import Boolean, Column
from flask.ext.admin import form
from flask.ext.admin.form import Select2Field
from flask.ext.admin.tools import get_property
from flask.ext.admin.model.form import (converts, ModelConverterBase,
InlineFormAdmin, InlineModelConverterBase)
......@@ -186,6 +187,16 @@ class AdminModelConverter(ModelConverterBase):
if override:
return override(**kwargs)
# Check choices
if mapper.class_ == self.view.model and self.view.form_choices:
choices = self.view.form_choices.get(column.key)
if choices:
return Select2Field(
choices=choices,
allow_blank=column.nullable,
**kwargs
)
# Run converter
converter = self.get_converter(column)
......
......@@ -218,6 +218,18 @@ class ModelView(BaseModelView):
column_type_formatters = DEFAULT_FORMATTERS
form_choices = None
"""
Map choices to form fields
Example::
class MyModelView(BaseModelView):
form_choices = {'my_form_field': [
('db_value', 'display_value'),
]
"""
def __init__(self, model, session,
name=None, category=None, endpoint=None, url=None):
"""
......@@ -243,6 +255,9 @@ class ModelView(BaseModelView):
self._filter_joins = dict()
if self.form_choices is None:
self.form_choices = {}
super(ModelView, self).__init__(model, name, category, endpoint, url)
# Primary key
......@@ -453,9 +468,11 @@ class ModelView(BaseModelView):
column = columns[0]
if self._need_join(column.table):
visible_name = '%s / %s' % (self.get_column_name(column.table.name),
self.get_column_name(column.name))
if self._need_join(column.table) and name not in self.column_labels:
visible_name = '%s / %s' % (
self.get_column_name(column.table.name),
self.get_column_name(column.name)
)
else:
if not isinstance(name, basestring):
visible_name = self.get_column_name(name.property.key)
......@@ -463,16 +480,19 @@ class ModelView(BaseModelView):
visible_name = self.get_column_name(name)
type_name = type(column.type).__name__
flt = self.filter_converter.convert(type_name,
column,
visible_name)
if flt:
# If there's relation to other table, do it
if join_tables:
self._filter_joins[column.table.name] = join_tables
elif self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table]
if join_tables:
self._filter_joins[column.table.name] = join_tables
flt = self.filter_converter.convert(
type_name,
column,
visible_name,
options=self.column_choices.get(name),
)
if flt and not join_tables and self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table]
return flt
......
......@@ -110,6 +110,45 @@ class Select2Field(fields.SelectField):
"""
widget = Select2Widget()
def __init__(self, label=None, validators=None, coerce=unicode,
choices=None, allow_blank=False, blank_text=None, **kwargs):
super(Select2Field, self).__init__(
label, validators, coerce, choices, **kwargs
)
self.allow_blank = allow_blank
self.blank_text = blank_text or ' '
def iter_choices(self):
if self.allow_blank:
yield (u'__None', self.blank_text, self.data is None)
for value, label in self.choices:
yield (value, label, self.coerce(value) == self.data)
def process_data(self, value):
if value is None:
self.data = None
else:
try:
self.data = self.coerce(value)
except (ValueError, TypeError):
self.data = None
def process_formdata(self, valuelist):
if valuelist:
if valuelist[0] == '__None':
self.data = None
else:
try:
self.data = self.coerce(valuelist[0])
except ValueError:
raise ValueError(self.gettext(u'Invalid Choice: could not coerce'))
def pre_validate(self, form):
if self.allow_blank and self.data is None:
return
super(Select2Field, self).pre_validate(form)
class DatePickerWidget(widgets.TextInput):
"""
......
......@@ -180,6 +180,20 @@ class BaseModelView(BaseView, ActionsMixin):
class MyModelView(BaseModelView):
column_searchable_list = ('name', 'email')
"""
column_choices = None
"""
Map choices to columns in list view
Example::
class MyModelView(BaseModelView):
column_choices = {
'my_column': [
('db_value', 'display_value'),
]
}
"""
column_filters = None
"""
......@@ -339,6 +353,10 @@ class BaseModelView(BaseView, ActionsMixin):
self._list_columns = self.get_list_columns()
self._sortable_columns = self.get_sortable_columns()
# Labels
if self.column_labels is None:
self.column_labels = {}
# Forms
self._create_form_class = self.get_create_form()
self._edit_form_class = self.get_edit_form()
......@@ -349,6 +367,15 @@ class BaseModelView(BaseView, ActionsMixin):
# Search
self._search_supported = self.init_search()
# Choices
if self.column_choices:
self._column_choices_map = dict([
(column, dict(choices))
for column, choices in self.column_choices.items()
])
else:
self.column_choices = self._column_choices_map = dict()
# Filters
self._filters = self.get_filters()
......@@ -499,15 +526,14 @@ class BaseModelView(BaseView, ActionsMixin):
collection = []
for n in self.column_filters:
if not self.is_valid_filter(n):
if self.is_valid_filter(n):
collection.append(n)
else:
flt = self.scaffold_filters(n)
if flt:
collection.extend(flt)
else:
raise Exception('Unsupported filter type %s' % n)
else:
collection.append(n)
return collection
else:
return None
......@@ -794,6 +820,10 @@ class BaseModelView(BaseView, ActionsMixin):
value = rec_getattr(model, name)
choices_map = self._column_choices_map.get(name, {})
if choices_map:
return choices_map.get(value) or value
type_fmt = self.column_type_formatters.get(type(value))
if type_fmt is not None:
value = type_fmt(value)
......
......@@ -33,6 +33,7 @@ def create_models(db):
test3 = db.Column(db.Text)
test4 = db.Column(db.UnicodeText)
bool_field = db.Column(db.Boolean)
enum_field = db.Column(db.Enum('model1_v1', 'model1_v1'), nullable=True)
class Model2(db.Model):
def __init__(self, string_field=None, int_field=None, bool_field=None, model1=None):
......@@ -45,6 +46,9 @@ def create_models(db):
string_field = db.Column(db.String)
int_field = db.Column(db.Integer)
bool_field = db.Column(db.Boolean)
enum_field = db.Column(db.Enum('model2_v1', 'model2_v2'), nullable=True)
# Relation
model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
model1 = db.relationship(Model1)
......@@ -162,8 +166,10 @@ def test_exclude_columns():
Model1, Model2 = create_models(db)
view = CustomModelView(Model1, db.session,
column_exclude_list=['test2', 'test4'])
view = CustomModelView(
Model1, db.session,
column_exclude_list=['test2', 'test4', 'enum_field']
)
admin.add_view(view)
eq_(
......@@ -210,8 +216,10 @@ def test_column_filters():
Model1, Model2 = create_models(db)
view = CustomModelView(Model1, db.session,
column_filters=['test1'])
view = CustomModelView(
Model1, db.session,
column_filters=['test1']
)
admin.add_view(view)
eq_(len(view._filters), 4)
......@@ -258,6 +266,10 @@ def test_column_filters():
(16, 'equals'),
(17, 'not equal'),
],
'Model1 / Enum Field': [
(18, u'equals'),
(19, u'not equal'),
]
})
# Test filter with a dot
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment