Commit a8ac91c5 authored by Paul Brown's avatar Paul Brown

add empty, between, and list filters + tests to mongoengine and peewee

parent 591c38e2
from flask.ext.admin.babel import gettext
import datetime
from flask.ext.admin.babel import lazy_gettext
from flask.ext.admin.model import filters
from .tools import parse_like_term
from mongoengine.queryset import Q
class BaseMongoEngineFilter(filters.BaseFilter):
"""
......@@ -33,7 +35,7 @@ class FilterEqual(BaseMongoEngineFilter):
return query.filter(**flt)
def operation(self):
return gettext('equals')
return lazy_gettext('equals')
class FilterNotEqual(BaseMongoEngineFilter):
......@@ -42,7 +44,7 @@ class FilterNotEqual(BaseMongoEngineFilter):
return query.filter(**flt)
def operation(self):
return gettext('not equal')
return lazy_gettext('not equal')
class FilterLike(BaseMongoEngineFilter):
......@@ -52,7 +54,7 @@ class FilterLike(BaseMongoEngineFilter):
return query.filter(**flt)
def operation(self):
return gettext('contains')
return lazy_gettext('contains')
class FilterNotLike(BaseMongoEngineFilter):
......@@ -62,7 +64,7 @@ class FilterNotLike(BaseMongoEngineFilter):
return query.filter(**flt)
def operation(self):
return gettext('not contains')
return lazy_gettext('not contains')
class FilterGreater(BaseMongoEngineFilter):
......@@ -71,7 +73,7 @@ class FilterGreater(BaseMongoEngineFilter):
return query.filter(**flt)
def operation(self):
return gettext('greater than')
return lazy_gettext('greater than')
class FilterSmaller(BaseMongoEngineFilter):
......@@ -80,8 +82,44 @@ class FilterSmaller(BaseMongoEngineFilter):
return query.filter(**flt)
def operation(self):
return gettext('smaller than')
return lazy_gettext('smaller than')
class FilterEmpty(BaseMongoEngineFilter, filters.BaseBooleanFilter):
def apply(self, query, value):
if value == '1':
flt = {'%s' % self.column.name: None}
else:
flt = {'%s__ne' % self.column.name: None}
return query.filter(**flt)
def operation(self):
return lazy_gettext('empty')
class FilterInList(BaseMongoEngineFilter):
def __init__(self, column, name, options=None, data_type=None):
super(FilterInList, self).__init__(column, name, options, data_type='select2-tags')
def clean(self, value):
return [v.strip() for v in value.split(',') if v.strip()]
def apply(self, query, value):
flt = {'%s__in' % self.column.name: value}
return query.filter(**flt)
def operation(self):
return lazy_gettext('in list')
class FilterNotInList(FilterInList):
def apply(self, query, value):
flt = {'%s__nin' % self.column.name: value}
return query.filter(**flt)
def operation(self):
return lazy_gettext('not in list')
# Customized type filters
class BooleanEqualFilter(FilterEqual, filters.BaseBooleanFilter):
......@@ -96,10 +134,66 @@ class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter):
return query.filter(**flt)
class DateTimeEqualFilter(FilterEqual, filters.BaseDateTimeFilter):
pass
class DateTimeNotEqualFilter(FilterNotEqual, filters.BaseDateTimeFilter):
pass
class DateTimeGreaterFilter(FilterGreater, filters.BaseDateTimeFilter):
pass
class DateTimeSmallerFilter(FilterSmaller, filters.BaseDateTimeFilter):
pass
class DateTimeBetweenFilter(BaseMongoEngineFilter):
def __init__(self, column, name, options=None, data_type=None):
super(DateTimeBetweenFilter, self).__init__(column, name, options, data_type='datetimerangepicker')
def clean(self, value):
return [datetime.datetime.strptime(range, '%Y-%m-%d %H:%M:%S') for range in value.split(' to ')]
def apply(self, query, value):
start, end = value
flt = {'%s__gte' % self.column.name: start, '%s__lte' % self.column.name: end}
return query.filter(**flt)
def operation(self):
return lazy_gettext('between')
def validate(self, value):
try:
value = [datetime.datetime.strptime(range, '%Y-%m-%d %H:%M:%S') for range in value.split(' to ')]
if (len(value) == 2) and (value[0] <= value[1]):
return True
else:
return False
except ValueError:
return False
class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
def apply(self, query, value):
start, end = value
return query.filter(Q(**{'%s__not__gte' % self.column.name: start}) |
Q(**{'%s__not__lte' % self.column.name: end}))
def operation(self):
return lazy_gettext('not between')
# Base peewee filter field converter
class FilterConverter(filters.BaseFilterConverter):
strings = (FilterEqual, FilterNotEqual, FilterLike, FilterNotLike)
numeric = (FilterEqual, FilterNotEqual, FilterGreater, FilterSmaller)
strings = (FilterEqual, FilterNotEqual, FilterLike, FilterNotLike, FilterEmpty, FilterInList, FilterNotInList)
numeric = (FilterEqual, FilterNotEqual, FilterGreater, FilterSmaller, FilterEmpty, FilterInList, FilterNotInList)
bool = (BooleanEqualFilter, BooleanNotEqualFilter)
datetime_filters = (DateTimeEqualFilter, DateTimeNotEqualFilter, DateTimeGreaterFilter,
DateTimeSmallerFilter, DateTimeBetweenFilter, DateTimeNotBetweenFilter,
FilterEmpty)
def convert(self, type_name, column, name):
if type_name in self.converters:
......@@ -107,24 +201,18 @@ class FilterConverter(filters.BaseFilterConverter):
return None
@filters.convert('StringField', 'EmailField')
@filters.convert('StringField', 'EmailField', 'URLField')
def conv_string(self, column, name):
return [f(column, name) for f in self.strings]
@filters.convert('BooleanField')
def conv_bool(self, column, name):
return [BooleanEqualFilter(column, name),
BooleanNotEqualFilter(column, name)]
return [f(column, name) for f in self.bool]
@filters.convert('IntField', 'DecimalField', 'FloatField')
@filters.convert('IntField', 'DecimalField', 'FloatField', 'LongField')
def conv_int(self, column, name):
return [f(column, name) for f in self.numeric]
@filters.convert('DateField')
def conv_date(self, column, name):
return [f(column, name, data_type='datepicker') for f in self.numeric]
@filters.convert('DateTimeField')
@filters.convert('DateTimeField', 'ComplexDateTimeField')
def conv_datetime(self, column, name):
return [f(column, name, data_type='datetimepicker')
for f in self.numeric]
return [f(column, name) for f in self.datetime_filters]
......@@ -433,7 +433,7 @@ class ModelView(BaseModelView):
if self._filters:
for flt, flt_name, value in filters:
f = self._filters[flt]
query = f.apply(query, value)
query = f.apply(query, f.clean(value))
# Search
if self._search_supported and search:
......
This diff is collapsed.
......@@ -313,7 +313,7 @@ class ModelView(BaseModelView):
f = self._filters[flt]
query = self._handle_join(query, f.column, joins)
query = f.apply(query, value)
query = f.apply(query, f.clean(value))
# Get count
count = query.count()
......
......@@ -13,6 +13,7 @@ from flask.ext.admin.contrib.mongoengine import ModelView
from . import setup
from datetime import datetime
class CustomModelView(ModelView):
def __init__(self, model,
......@@ -32,6 +33,7 @@ def create_models(db):
test2 = db.StringField(max_length=20)
test3 = db.StringField()
test4 = db.StringField()
datetime_field = db.DateTimeField()
def __str__(self):
return self.test1
......@@ -127,20 +129,25 @@ def test_column_filters():
Model1, Model2 = create_models(db)
# fill DB with values
model1_obj1 = Model1(test1=u'test1_val_1', test2=u'test2_val_1')
model1_obj1.save()
model1_obj2 = Model1(test1=u'test1_val_2', test2=u'test2_val_2')
model1_obj2.save()
model2_obj1 = Model2(string_field=u'string_field_val_1', int_field=5000)
model2_obj1.save()
model2_obj2 = Model2(string_field=u'string_field_val_2', int_field=9000)
model2_obj2.save()
Model1('test1_val_1', 'test2_val_1').save()
Model1('test1_val_2', 'test2_val_2').save()
Model1('test1_val_3', 'test2_val_3').save()
Model1('test1_val_4', 'test2_val_4').save()
Model1(None, 'empty_obj').save()
Model2('string_field_val_1', None).save()
Model2('string_field_val_2', None).save()
Model2('string_field_val_3', 5000).save()
Model2('string_field_val_4', 9000).save()
Model1('datetime_obj1', datetime_field=datetime(2014,4,3,1,9,0)).save()
Model1('datetime_obj2', datetime_field=datetime(2013,3,2,0,8,0)).save()
# Test string filter
view = CustomModelView(Model1, column_filters=['test1'])
admin.add_view(view)
eq_(len(view._filters), 4)
eq_(len(view._filters), 7)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']],
[
......@@ -148,6 +155,9 @@ def test_column_filters():
(1, 'not equal'),
(2, 'contains'),
(3, 'not contains'),
(4, 'empty'),
(5, 'in list'),
(6, 'not in list'),
])
# Make some test clients
......@@ -181,6 +191,40 @@ def test_column_filters():
ok_('test2_val_1' not in data)
ok_('test1_val_2' in data)
# string - empty
rv = client.get('/admin/model1/?flt0_4=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('empty_obj' in data)
ok_('test1_val_1' not in data)
ok_('test1_val_2' not in data)
# string - not empty
rv = client.get('/admin/model1/?flt0_4=0')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('empty_obj' not in data)
ok_('test1_val_1' in data)
ok_('test1_val_2' in data)
# string - in list
rv = client.get('/admin/model1/?flt0_5=test1_val_1%2Ctest1_val_2')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test2_val_1' in data)
ok_('test2_val_2' in data)
ok_('test1_val_3' not in data)
ok_('test1_val_4' not in data)
# string - not in list
rv = client.get('/admin/model1/?flt0_6=test1_val_1%2Ctest1_val_2')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test2_val_1' not in data)
ok_('test2_val_2' not in data)
ok_('test1_val_3' in data)
ok_('test1_val_4' in data)
# Test numeric filter
view = CustomModelView(Model2, column_filters=['int_field'])
admin.add_view(view)
......@@ -191,35 +235,149 @@ def test_column_filters():
(1, 'not equal'),
(2, 'greater than'),
(3, 'smaller than'),
(4, 'empty'),
(5, 'in list'),
(6, 'not in list'),
])
# integer - equals
rv = client.get('/admin/model2/?flt0_0=5000')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' in data)
ok_('string_field_val_2' not in data)
ok_('string_field_val_3' in data)
ok_('string_field_val_4' not in data)
# integer - not equal
rv = client.get('/admin/model2/?flt0_1=5000')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' not in data)
ok_('string_field_val_2' in data)
ok_('string_field_val_3' not in data)
ok_('string_field_val_4' in data)
# integer - greater
rv = client.get('/admin/model2/?flt0_2=6000')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' not in data)
ok_('string_field_val_2' in data)
ok_('string_field_val_3' not in data)
ok_('string_field_val_4' in data)
# integer - smaller
rv = client.get('/admin/model2/?flt0_3=6000')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_3' in data)
ok_('string_field_val_4' not in data)
# integer - empty
rv = client.get('/admin/model2/?flt0_4=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' in data)
ok_('string_field_val_2' in data)
ok_('string_field_val_3' not in data)
ok_('string_field_val_4' not in data)
# integer - not empty
rv = client.get('/admin/model2/?flt0_4=0')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' not in data)
ok_('string_field_val_2' not in data)
ok_('string_field_val_3' in data)
ok_('string_field_val_4' in data)
# integer - in list
rv = client.get('/admin/model2/?flt0_5=5000%2C9000')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' not in data)
ok_('string_field_val_2' not in data)
ok_('string_field_val_3' in data)
ok_('string_field_val_4' in data)
# integer - not in list
rv = client.get('/admin/model2/?flt0_6=5000%2C9000')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('string_field_val_1' in data)
ok_('string_field_val_2' in data)
ok_('string_field_val_3' not in data)
ok_('string_field_val_4' not in data)
# Test datetime filter
view = CustomModelView(Model1,
column_filters=['datetime_field'],
endpoint="_datetime")
admin.add_view(view)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Datetime Field']],
[
(0, 'equals'),
(1, 'not equal'),
(2, 'greater than'),
(3, 'smaller than'),
(4, 'between'),
(5, 'not between'),
(6, 'empty'),
])
# datetime - equals
rv = client.get('/admin/_datetime/?flt0_0=2014-04-03+01%3A09%3A00')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' in data)
ok_('datetime_obj2' not in data)
# datetime - not equal
rv = client.get('/admin/_datetime/?flt0_1=2014-04-03+01%3A09%3A00')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' not in data)
ok_('datetime_obj2' in data)
# datetime - greater
rv = client.get('/admin/_datetime/?flt0_2=2014-04-03+01%3A08%3A00')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' in data)
ok_('datetime_obj2' not in data)
# datetime - smaller
rv = client.get('/admin/_datetime/?flt0_3=2014-04-03+01%3A08%3A00')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' not in data)
ok_('datetime_obj2' in data)
# datetime - between
rv = client.get('/admin/_datetime/?flt0_4=2014-04-02+00%3A00%3A00+to+2014-11-20+23%3A59%3A59')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' in data)
ok_('datetime_obj2' not in data)
# datetime - not between
rv = client.get('/admin/_datetime/?flt0_5=2014-04-02+00%3A00%3A00+to+2014-11-20+23%3A59%3A59')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('datetime_obj1' not in data)
ok_('datetime_obj2' in data)
# datetime - empty
rv = client.get('/admin/_datetime/?flt0_6=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test1_val_1' in data)
ok_('datetime_obj1' not in data)
ok_('datetime_obj2' not in data)
# datetime - not empty
rv = client.get('/admin/_datetime/?flt0_6=0')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test1_val_1' not in data)
ok_('datetime_obj1' in data)
ok_('datetime_obj2' in data)
def test_default_sort():
app, db, admin = setup()
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment