Commit 2323b5e0 authored by Serge S. Koval's avatar Serge S. Koval

Merge pull request #697 from pawl/master

Fix Time/DateTime Equals Filters For SQLAlchemy+SQLite
parents 60ec7cb6 8a25ce88
import warnings import warnings
import time
import datetime
from flask.ext.admin.babel import lazy_gettext from flask.ext.admin.babel import lazy_gettext
from flask.ext.admin.model import filters from flask.ext.admin.model import filters
...@@ -81,11 +83,49 @@ class FilterSmaller(BaseSQLAFilter): ...@@ -81,11 +83,49 @@ class FilterSmaller(BaseSQLAFilter):
# Customized type filters # Customized type filters
class BooleanEqualFilter(FilterEqual, filters.BaseBooleanFilter): class BooleanEqualFilter(FilterEqual, filters.BaseBooleanFilter):
pass pass
class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter): class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter):
pass pass
class DateEqualFilter(FilterEqual, filters.BaseDateFilter):
def clean(self, value):
return datetime.datetime.strptime(value, '%Y-%m-%d').date()
def validate(self, value):
try:
self.clean(value)
return True
except ValueError:
return False
class DateTimeEqualFilter(FilterEqual, filters.BaseDateTimeFilter):
def clean(self, value):
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
def validate(self, value):
try:
self.clean(value)
return True
except ValueError:
return False
class TimeEqualFilter(FilterEqual, filters.BaseTimeFilter):
def clean(self, value):
timetuple = time.strptime(value, '%H:%M:%S')
return datetime.time(timetuple.tm_hour,
timetuple.tm_min,
timetuple.tm_sec)
def validate(self, value):
try:
self.clean(value)
return True
except ValueError:
return False
# Base SQLA filter field converter # Base SQLA filter field converter
class FilterConverter(filters.BaseFilterConverter): class FilterConverter(filters.BaseFilterConverter):
...@@ -114,15 +154,24 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -114,15 +154,24 @@ class FilterConverter(filters.BaseFilterConverter):
@filters.convert('date') @filters.convert('date')
def conv_date(self, column, name, **kwargs): def conv_date(self, column, name, **kwargs):
return [f(column, name, data_type='datepicker', **kwargs) for f in self.numeric] return [DateEqualFilter(column, name),
FilterNotEqual(column, name, data_type='datepicker', **kwargs),
FilterGreater(column, name, data_type='datepicker', **kwargs),
FilterSmaller(column, name, data_type='datepicker', **kwargs)]
@filters.convert('datetime') @filters.convert('datetime')
def conv_datetime(self, column, name, **kwargs): def conv_datetime(self, column, name, **kwargs):
return [f(column, name, data_type='datetimepicker', **kwargs) for f in self.numeric] return [DateTimeEqualFilter(column, name),
FilterNotEqual(column, name, data_type='datetimepicker', **kwargs),
FilterGreater(column, name, data_type='datetimepicker', **kwargs),
FilterSmaller(column, name, data_type='datetimepicker', **kwargs)]
@filters.convert('time') @filters.convert('time')
def conv_time(self, column, name, **kwargs): def conv_time(self, column, name, **kwargs):
return [f(column, name, data_type='timepicker', **kwargs) for f in self.numeric] return [TimeEqualFilter(column, name),
FilterNotEqual(column, name, data_type='timepicker', **kwargs),
FilterGreater(column, name, data_type='timepicker', **kwargs),
FilterSmaller(column, name, data_type='timepicker', **kwargs)]
@filters.convert('enum') @filters.convert('enum')
def conv_enum(self, column, name, options=None, **kwargs): def conv_enum(self, column, name, options=None, **kwargs):
......
...@@ -1114,6 +1114,8 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1114,6 +1114,8 @@ class BaseModelView(BaseView, ActionsMixin):
if flt.validate(value): if flt.validate(value):
filters.append((pos, (idx, flt.clean(value)))) filters.append((pos, (idx, flt.clean(value))))
else:
flash(gettext('Invalid Filter Value: %(value)s', value=value))
# Sort filters # Sort filters
return [v[1] for v in sorted(filters, key=lambda n: n[0])] return [v[1] for v in sorted(filters, key=lambda n: n[0])]
......
...@@ -100,7 +100,7 @@ class BaseDateFilter(BaseFilter): ...@@ -100,7 +100,7 @@ class BaseDateFilter(BaseFilter):
""" """
Base Date filter. Uses client-side date picker control. Base Date filter. Uses client-side date picker control.
""" """
def __init__(self, name, options=None): def __init__(self, name, options=None, data_type=None):
super(BaseDateFilter, self).__init__(name, super(BaseDateFilter, self).__init__(name,
options, options,
data_type='datepicker') data_type='datepicker')
...@@ -112,9 +112,9 @@ class BaseDateFilter(BaseFilter): ...@@ -112,9 +112,9 @@ class BaseDateFilter(BaseFilter):
class BaseDateTimeFilter(BaseFilter): class BaseDateTimeFilter(BaseFilter):
""" """
Base DateTime filter. Uses client-side date picker control. Base DateTime filter. Uses client-side date time picker control.
""" """
def __init__(self, name, options=None): def __init__(self, name, options=None, data_type=None):
super(BaseDateTimeFilter, self).__init__(name, super(BaseDateTimeFilter, self).__init__(name,
options, options,
data_type='datetimepicker') data_type='datetimepicker')
...@@ -122,6 +122,20 @@ class BaseDateTimeFilter(BaseFilter): ...@@ -122,6 +122,20 @@ class BaseDateTimeFilter(BaseFilter):
def validate(self, value): def validate(self, value):
# TODO: Validation # TODO: Validation
return True return True
class BaseTimeFilter(BaseFilter):
"""
Base Time filter. Uses client-side time picker control.
"""
def __init__(self, name, options=None, data_type=None):
super(BaseTimeFilter, self).__init__(name,
options,
data_type='timepicker')
def validate(self, value):
# TODO: Validation
return True
def convert(*args): def convert(*args):
......
...@@ -9,6 +9,7 @@ from flask.ext.admin.contrib.sqla import ModelView ...@@ -9,6 +9,7 @@ from flask.ext.admin.contrib.sqla import ModelView
from . import setup from . import setup
from datetime import datetime, time, date
class CustomModelView(ModelView): class CustomModelView(ModelView):
def __init__(self, model, session, def __init__(self, model, session,
...@@ -23,12 +24,16 @@ class CustomModelView(ModelView): ...@@ -23,12 +24,16 @@ class CustomModelView(ModelView):
def create_models(db): def create_models(db):
class Model1(db.Model): class Model1(db.Model):
def __init__(self, test1=None, test2=None, test3=None, test4=None, bool_field=False): def __init__(self, test1=None, test2=None, test3=None, test4=None,
bool_field=False, date_field=None, time_field=None, datetime_field=None):
self.test1 = test1 self.test1 = test1
self.test2 = test2 self.test2 = test2
self.test3 = test3 self.test3 = test3
self.test4 = test4 self.test4 = test4
self.bool_field = bool_field self.bool_field = bool_field
self.date_field = date_field
self.time_field = time_field
self.datetime_field = datetime_field
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
test1 = db.Column(db.String(20)) test1 = db.Column(db.String(20))
...@@ -37,7 +42,11 @@ def create_models(db): ...@@ -37,7 +42,11 @@ def create_models(db):
test4 = db.Column(db.UnicodeText) test4 = db.Column(db.UnicodeText)
bool_field = db.Column(db.Boolean) bool_field = db.Column(db.Boolean)
enum_field = db.Column(db.Enum('model1_v1', 'model1_v1'), nullable=True) enum_field = db.Column(db.Enum('model1_v1', 'model1_v1'), nullable=True)
date_field = db.Column(db.Date)
time_field = db.Column(db.Time)
datetime_field = db.Column(db.DateTime)
def __unicode__(self): def __unicode__(self):
return self.test1 return self.test1
...@@ -178,7 +187,7 @@ def test_exclude_columns(): ...@@ -178,7 +187,7 @@ def test_exclude_columns():
view = CustomModelView( view = CustomModelView(
Model1, db.session, Model1, db.session,
column_exclude_list=['test2', 'test4', 'enum_field'] column_exclude_list=['test2', 'test4', 'enum_field', 'date_field', 'time_field', 'datetime_field']
) )
admin.add_view(view) admin.add_view(view)
...@@ -355,9 +364,19 @@ def test_column_filters(): ...@@ -355,9 +364,19 @@ def test_column_filters():
model2_obj2 = Model2('model2_obj2', model1=model1_obj1) model2_obj2 = Model2('model2_obj2', model1=model1_obj1)
model2_obj3 = Model2('model2_obj3') model2_obj3 = Model2('model2_obj3')
model2_obj4 = Model2('model2_obj4') model2_obj4 = Model2('model2_obj4')
date_obj1 = Model1('date_obj1', date_field=date(2014,11,17))
date_obj2 = Model1('date_obj2', date_field=date(2013,10,16))
time_obj1 = Model1('time_obj1', time_field=time(11,10,9))
time_obj2 = Model1('time_obj2', time_field=time(10,9,8))
datetime_obj1 = Model1('datetime_obj1', datetime_field=datetime(2014,4,3,1,9,0))
datetime_obj2 = Model1('datetime_obj2', datetime_field=datetime(2013,3,2,0,8,0))
db.session.add_all([ db.session.add_all([
model1_obj1, model1_obj2, model1_obj3, model1_obj4, model1_obj1, model1_obj2, model1_obj3, model1_obj4,
model2_obj1, model2_obj2, model2_obj3, model2_obj4, model2_obj1, model2_obj2, model2_obj3, model2_obj4,
date_obj1, time_obj1, datetime_obj1,
date_obj2, time_obj2, datetime_obj2
]) ])
db.session.commit() db.session.commit()
...@@ -423,7 +442,57 @@ def test_column_filters(): ...@@ -423,7 +442,57 @@ def test_column_filters():
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('model1_obj1' in data) ok_('model1_obj1' in data)
ok_('model1_obj2' not in data) ok_('model1_obj2' not in data)
# Test date, time, and datetime filters
view = CustomModelView(Model1, db.session,
column_filters=['date_field', 'datetime_field', 'time_field'],
endpoint="_datetime")
admin.add_view(view)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Date Field']],
[
(0, 'equals'),
(1, 'not equal'),
(2, 'greater than'),
(3, 'smaller than')
])
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Datetime Field']],
[
(4, 'equals'),
(5, 'not equal'),
(6, 'greater than'),
(7, 'smaller than')
])
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Time Field']],
[
(8, 'equals'),
(9, 'not equal'),
(10, 'greater than'),
(11, 'smaller than')
])
# date - equals
rv = client.get('/admin/_datetime/?flt0_0=2014-11-17')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('date_obj1' in data)
ok_('date_obj2' not in data)
# datetime - equals
rv = client.get('/admin/_datetime/?flt0_4=2014-04-03+01%3A09%3A00')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' in data)
ok_('datetime_obj2' not in data)
# time - equals
rv = client.get('/admin/_datetime/?flt0_8=11%3A10%3A09')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('time_obj1' in data)
ok_('time_obj2' not in data)
def test_url_args(): def test_url_args():
app, db, admin = setup() app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment