Commit e81f39fe authored by Serge S. Koval's avatar Serge S. Koval

Merge pull request #717 from pawl/newfiltersmongoengine

add empty, between, and list filters + tests to mongoengine and peewee
parents 591c38e2 0eff8c1e
Custom filter with SQLAlchemy backend example. Example of custom filters for the SQLAlchemy backend.
To run this example: To run this example:
...@@ -14,11 +14,11 @@ To run this example: ...@@ -14,11 +14,11 @@ To run this example:
3. Install requirements:: 3. Install requirements::
pip install -r 'examples/custom-filter/requirements.txt' pip install -r 'examples/sqla-custom-filter/requirements.txt'
4. Run the application:: 4. Run the application::
python examples/custom-filter/app.py python examples/sqla-custom-filter/app.py
The first time you run this example, a sample sqlite database gets populated automatically. To suppress this behaviour, The first time you run this example, a sample sqlite database gets populated automatically. To suppress this behaviour,
comment the following lines in app.py::: comment the following lines in app.py:::
......
from flask.ext.admin.babel import gettext import datetime
from flask.ext.admin.babel import lazy_gettext
from flask.ext.admin.model import filters from flask.ext.admin.model import filters
from .tools import parse_like_term from .tools import parse_like_term
from mongoengine.queryset import Q
class BaseMongoEngineFilter(filters.BaseFilter): class BaseMongoEngineFilter(filters.BaseFilter):
""" """
...@@ -33,7 +35,7 @@ class FilterEqual(BaseMongoEngineFilter): ...@@ -33,7 +35,7 @@ class FilterEqual(BaseMongoEngineFilter):
return query.filter(**flt) return query.filter(**flt)
def operation(self): def operation(self):
return gettext('equals') return lazy_gettext('equals')
class FilterNotEqual(BaseMongoEngineFilter): class FilterNotEqual(BaseMongoEngineFilter):
...@@ -42,7 +44,7 @@ class FilterNotEqual(BaseMongoEngineFilter): ...@@ -42,7 +44,7 @@ class FilterNotEqual(BaseMongoEngineFilter):
return query.filter(**flt) return query.filter(**flt)
def operation(self): def operation(self):
return gettext('not equal') return lazy_gettext('not equal')
class FilterLike(BaseMongoEngineFilter): class FilterLike(BaseMongoEngineFilter):
...@@ -52,7 +54,7 @@ class FilterLike(BaseMongoEngineFilter): ...@@ -52,7 +54,7 @@ class FilterLike(BaseMongoEngineFilter):
return query.filter(**flt) return query.filter(**flt)
def operation(self): def operation(self):
return gettext('contains') return lazy_gettext('contains')
class FilterNotLike(BaseMongoEngineFilter): class FilterNotLike(BaseMongoEngineFilter):
...@@ -62,7 +64,7 @@ class FilterNotLike(BaseMongoEngineFilter): ...@@ -62,7 +64,7 @@ class FilterNotLike(BaseMongoEngineFilter):
return query.filter(**flt) return query.filter(**flt)
def operation(self): def operation(self):
return gettext('not contains') return lazy_gettext('not contains')
class FilterGreater(BaseMongoEngineFilter): class FilterGreater(BaseMongoEngineFilter):
...@@ -71,7 +73,7 @@ class FilterGreater(BaseMongoEngineFilter): ...@@ -71,7 +73,7 @@ class FilterGreater(BaseMongoEngineFilter):
return query.filter(**flt) return query.filter(**flt)
def operation(self): def operation(self):
return gettext('greater than') return lazy_gettext('greater than')
class FilterSmaller(BaseMongoEngineFilter): class FilterSmaller(BaseMongoEngineFilter):
...@@ -80,7 +82,43 @@ class FilterSmaller(BaseMongoEngineFilter): ...@@ -80,7 +82,43 @@ class FilterSmaller(BaseMongoEngineFilter):
return query.filter(**flt) return query.filter(**flt)
def operation(self): def operation(self):
return gettext('smaller than') return lazy_gettext('smaller than')
class FilterEmpty(BaseMongoEngineFilter, filters.BaseBooleanFilter):
def apply(self, query, value):
if value == '1':
flt = {'%s' % self.column.name: None}
else:
flt = {'%s__ne' % self.column.name: None}
return query.filter(**flt)
def operation(self):
return lazy_gettext('empty')
class FilterInList(BaseMongoEngineFilter):
def __init__(self, column, name, options=None, data_type=None):
super(FilterInList, self).__init__(column, name, options, data_type='select2-tags')
def clean(self, value):
return [v.strip() for v in value.split(',') if v.strip()]
def apply(self, query, value):
flt = {'%s__in' % self.column.name: value}
return query.filter(**flt)
def operation(self):
return lazy_gettext('in list')
class FilterNotInList(FilterInList):
def apply(self, query, value):
flt = {'%s__nin' % self.column.name: value}
return query.filter(**flt)
def operation(self):
return lazy_gettext('not in list')
# Customized type filters # Customized type filters
...@@ -96,10 +134,66 @@ class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter): ...@@ -96,10 +134,66 @@ class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter):
return query.filter(**flt) return query.filter(**flt)
class DateTimeEqualFilter(FilterEqual, filters.BaseDateTimeFilter):
pass
class DateTimeNotEqualFilter(FilterNotEqual, filters.BaseDateTimeFilter):
pass
class DateTimeGreaterFilter(FilterGreater, filters.BaseDateTimeFilter):
pass
class DateTimeSmallerFilter(FilterSmaller, filters.BaseDateTimeFilter):
pass
class DateTimeBetweenFilter(BaseMongoEngineFilter):
def __init__(self, column, name, options=None, data_type=None):
super(DateTimeBetweenFilter, self).__init__(column, name, options, data_type='datetimerangepicker')
def clean(self, value):
return [datetime.datetime.strptime(range, '%Y-%m-%d %H:%M:%S') for range in value.split(' to ')]
def apply(self, query, value):
start, end = value
flt = {'%s__gte' % self.column.name: start, '%s__lte' % self.column.name: end}
return query.filter(**flt)
def operation(self):
return lazy_gettext('between')
def validate(self, value):
try:
value = [datetime.datetime.strptime(range, '%Y-%m-%d %H:%M:%S') for range in value.split(' to ')]
if (len(value) == 2) and (value[0] <= value[1]):
return True
else:
return False
except ValueError:
return False
class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
def apply(self, query, value):
start, end = value
return query.filter(Q(**{'%s__not__gte' % self.column.name: start}) |
Q(**{'%s__not__lte' % self.column.name: end}))
def operation(self):
return lazy_gettext('not between')
# Base peewee filter field converter # Base peewee filter field converter
class FilterConverter(filters.BaseFilterConverter): class FilterConverter(filters.BaseFilterConverter):
strings = (FilterEqual, FilterNotEqual, FilterLike, FilterNotLike) strings = (FilterEqual, FilterNotEqual, FilterLike, FilterNotLike, FilterEmpty, FilterInList, FilterNotInList)
numeric = (FilterEqual, FilterNotEqual, FilterGreater, FilterSmaller) numeric = (FilterEqual, FilterNotEqual, FilterGreater, FilterSmaller, FilterEmpty, FilterInList, FilterNotInList)
bool = (BooleanEqualFilter, BooleanNotEqualFilter)
datetime_filters = (DateTimeEqualFilter, DateTimeNotEqualFilter, DateTimeGreaterFilter,
DateTimeSmallerFilter, DateTimeBetweenFilter, DateTimeNotBetweenFilter,
FilterEmpty)
def convert(self, type_name, column, name): def convert(self, type_name, column, name):
if type_name in self.converters: if type_name in self.converters:
...@@ -107,24 +201,18 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -107,24 +201,18 @@ class FilterConverter(filters.BaseFilterConverter):
return None return None
@filters.convert('StringField', 'EmailField') @filters.convert('StringField', 'EmailField', 'URLField')
def conv_string(self, column, name): def conv_string(self, column, name):
return [f(column, name) for f in self.strings] return [f(column, name) for f in self.strings]
@filters.convert('BooleanField') @filters.convert('BooleanField')
def conv_bool(self, column, name): def conv_bool(self, column, name):
return [BooleanEqualFilter(column, name), return [f(column, name) for f in self.bool]
BooleanNotEqualFilter(column, name)]
@filters.convert('IntField', 'DecimalField', 'FloatField') @filters.convert('IntField', 'DecimalField', 'FloatField', 'LongField')
def conv_int(self, column, name): def conv_int(self, column, name):
return [f(column, name) for f in self.numeric] return [f(column, name) for f in self.numeric]
@filters.convert('DateField') @filters.convert('DateTimeField', 'ComplexDateTimeField')
def conv_date(self, column, name):
return [f(column, name, data_type='datepicker') for f in self.numeric]
@filters.convert('DateTimeField')
def conv_datetime(self, column, name): def conv_datetime(self, column, name):
return [f(column, name, data_type='datetimepicker') return [f(column, name) for f in self.datetime_filters]
for f in self.numeric]
...@@ -433,7 +433,7 @@ class ModelView(BaseModelView): ...@@ -433,7 +433,7 @@ class ModelView(BaseModelView):
if self._filters: if self._filters:
for flt, flt_name, value in filters: for flt, flt_name, value in filters:
f = self._filters[flt] f = self._filters[flt]
query = f.apply(query, value) query = f.apply(query, f.clean(value))
# Search # Search
if self._search_supported and search: if self._search_supported and search:
......
This diff is collapsed.
...@@ -313,7 +313,7 @@ class ModelView(BaseModelView): ...@@ -313,7 +313,7 @@ class ModelView(BaseModelView):
f = self._filters[flt] f = self._filters[flt]
query = self._handle_join(query, f.column, joins) query = self._handle_join(query, f.column, joins)
query = f.apply(query, value) query = f.apply(query, f.clean(value))
# Get count # Get count
count = query.count() count = query.count()
......
...@@ -13,6 +13,7 @@ from flask.ext.admin.contrib.mongoengine import ModelView ...@@ -13,6 +13,7 @@ from flask.ext.admin.contrib.mongoengine import ModelView
from . import setup from . import setup
from datetime import datetime
class CustomModelView(ModelView): class CustomModelView(ModelView):
def __init__(self, model, def __init__(self, model,
...@@ -32,6 +33,7 @@ def create_models(db): ...@@ -32,6 +33,7 @@ def create_models(db):
test2 = db.StringField(max_length=20) test2 = db.StringField(max_length=20)
test3 = db.StringField() test3 = db.StringField()
test4 = db.StringField() test4 = db.StringField()
datetime_field = db.DateTimeField()
def __str__(self): def __str__(self):
return self.test1 return self.test1
...@@ -127,20 +129,25 @@ def test_column_filters(): ...@@ -127,20 +129,25 @@ def test_column_filters():
Model1, Model2 = create_models(db) Model1, Model2 = create_models(db)
# fill DB with values # fill DB with values
model1_obj1 = Model1(test1=u'test1_val_1', test2=u'test2_val_1') Model1('test1_val_1', 'test2_val_1').save()
model1_obj1.save() Model1('test1_val_2', 'test2_val_2').save()
model1_obj2 = Model1(test1=u'test1_val_2', test2=u'test2_val_2') Model1('test1_val_3', 'test2_val_3').save()
model1_obj2.save() Model1('test1_val_4', 'test2_val_4').save()
model2_obj1 = Model2(string_field=u'string_field_val_1', int_field=5000) Model1(None, 'empty_obj').save()
model2_obj1.save()
model2_obj2 = Model2(string_field=u'string_field_val_2', int_field=9000) Model2('string_field_val_1', None).save()
model2_obj2.save() Model2('string_field_val_2', None).save()
Model2('string_field_val_3', 5000).save()
Model2('string_field_val_4', 9000).save()
Model1('datetime_obj1', datetime_field=datetime(2014,4,3,1,9,0)).save()
Model1('datetime_obj2', datetime_field=datetime(2013,3,2,0,8,0)).save()
# Test string filter # Test string filter
view = CustomModelView(Model1, column_filters=['test1']) view = CustomModelView(Model1, column_filters=['test1'])
admin.add_view(view) admin.add_view(view)
eq_(len(view._filters), 4) eq_(len(view._filters), 7)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']], eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']],
[ [
...@@ -148,6 +155,9 @@ def test_column_filters(): ...@@ -148,6 +155,9 @@ def test_column_filters():
(1, 'not equal'), (1, 'not equal'),
(2, 'contains'), (2, 'contains'),
(3, 'not contains'), (3, 'not contains'),
(4, 'empty'),
(5, 'in list'),
(6, 'not in list'),
]) ])
# Make some test clients # Make some test clients
...@@ -181,6 +191,40 @@ def test_column_filters(): ...@@ -181,6 +191,40 @@ def test_column_filters():
ok_('test2_val_1' not in data) ok_('test2_val_1' not in data)
ok_('test1_val_2' in data) ok_('test1_val_2' in data)
# string - empty
rv = client.get('/admin/model1/?flt0_4=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('empty_obj' in data)
ok_('test1_val_1' not in data)
ok_('test1_val_2' not in data)
# string - not empty
rv = client.get('/admin/model1/?flt0_4=0')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('empty_obj' not in data)
ok_('test1_val_1' in data)
ok_('test1_val_2' in data)
# string - in list
rv = client.get('/admin/model1/?flt0_5=test1_val_1%2Ctest1_val_2')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test2_val_1' in data)
ok_('test2_val_2' in data)
ok_('test1_val_3' not in data)
ok_('test1_val_4' not in data)
# string - not in list
rv = client.get('/admin/model1/?flt0_6=test1_val_1%2Ctest1_val_2')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test2_val_1' not in data)
ok_('test2_val_2' not in data)
ok_('test1_val_3' in data)
ok_('test1_val_4' in data)
# Test numeric filter # Test numeric filter
view = CustomModelView(Model2, column_filters=['int_field']) view = CustomModelView(Model2, column_filters=['int_field'])
admin.add_view(view) admin.add_view(view)
...@@ -191,35 +235,149 @@ def test_column_filters(): ...@@ -191,35 +235,149 @@ def test_column_filters():
(1, 'not equal'), (1, 'not equal'),
(2, 'greater than'), (2, 'greater than'),
(3, 'smaller than'), (3, 'smaller than'),
(4, 'empty'),
(5, 'in list'),
(6, 'not in list'),
]) ])
# integer - equals # integer - equals
rv = client.get('/admin/model2/?flt0_0=5000') rv = client.get('/admin/model2/?flt0_0=5000')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('string_field_val_1' in data) ok_('string_field_val_3' in data)
ok_('string_field_val_2' not in data) ok_('string_field_val_4' not in data)
# integer - not equal # integer - not equal
rv = client.get('/admin/model2/?flt0_1=5000') rv = client.get('/admin/model2/?flt0_1=5000')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('string_field_val_1' not in data) ok_('string_field_val_3' not in data)
ok_('string_field_val_2' in data) ok_('string_field_val_4' in data)
# integer - greater # integer - greater
rv = client.get('/admin/model2/?flt0_2=6000') rv = client.get('/admin/model2/?flt0_2=6000')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('string_field_val_1' not in data) ok_('string_field_val_3' not in data)
ok_('string_field_val_2' in data) ok_('string_field_val_4' in data)
# integer - smaller # integer - smaller
rv = client.get('/admin/model2/?flt0_3=6000') rv = client.get('/admin/model2/?flt0_3=6000')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('string_field_val_3' in data)
ok_('string_field_val_4' not in data)
# integer - empty
rv = client.get('/admin/model2/?flt0_4=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' in data) ok_('string_field_val_1' in data)
ok_('string_field_val_2' in data)
ok_('string_field_val_3' not in data)
ok_('string_field_val_4' not in data)
# integer - not empty
rv = client.get('/admin/model2/?flt0_4=0')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' not in data)
ok_('string_field_val_2' not in data) ok_('string_field_val_2' not in data)
ok_('string_field_val_3' in data)
ok_('string_field_val_4' in data)
# integer - in list
rv = client.get('/admin/model2/?flt0_5=5000%2C9000')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' not in data)
ok_('string_field_val_2' not in data)
ok_('string_field_val_3' in data)
ok_('string_field_val_4' in data)
# integer - not in list
rv = client.get('/admin/model2/?flt0_6=5000%2C9000')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' in data)
ok_('string_field_val_2' in data)
ok_('string_field_val_3' not in data)
ok_('string_field_val_4' not in data)
# Test datetime filter
view = CustomModelView(Model1,
column_filters=['datetime_field'],
endpoint="_datetime")
admin.add_view(view)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Datetime Field']],
[
(0, 'equals'),
(1, 'not equal'),
(2, 'greater than'),
(3, 'smaller than'),
(4, 'between'),
(5, 'not between'),
(6, 'empty'),
])
# datetime - equals
rv = client.get('/admin/_datetime/?flt0_0=2014-04-03+01%3A09%3A00')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' in data)
ok_('datetime_obj2' not in data)
# datetime - not equal
rv = client.get('/admin/_datetime/?flt0_1=2014-04-03+01%3A09%3A00')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' not in data)
ok_('datetime_obj2' in data)
# datetime - greater
rv = client.get('/admin/_datetime/?flt0_2=2014-04-03+01%3A08%3A00')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' in data)
ok_('datetime_obj2' not in data)
# datetime - smaller
rv = client.get('/admin/_datetime/?flt0_3=2014-04-03+01%3A08%3A00')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' not in data)
ok_('datetime_obj2' in data)
# datetime - between
rv = client.get('/admin/_datetime/?flt0_4=2014-04-02+00%3A00%3A00+to+2014-11-20+23%3A59%3A59')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' in data)
ok_('datetime_obj2' not in data)
# datetime - not between
rv = client.get('/admin/_datetime/?flt0_5=2014-04-02+00%3A00%3A00+to+2014-11-20+23%3A59%3A59')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' not in data)
ok_('datetime_obj2' in data)
# datetime - empty
rv = client.get('/admin/_datetime/?flt0_6=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test1_val_1' in data)
ok_('datetime_obj1' not in data)
ok_('datetime_obj2' not in data)
# datetime - not empty
rv = client.get('/admin/_datetime/?flt0_6=0')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test1_val_1' not in data)
ok_('datetime_obj1' in data)
ok_('datetime_obj2' in data)
def test_default_sort(): def test_default_sort():
app, db, admin = setup() app, db, admin = setup()
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment