Commit 76124b50 authored by Serge S. Koval's avatar Serge S. Koval

Merge pull request #139 from ArtemSerga/4

Fix filters, presented as dot-separated string
parents 7e990f16 e5756416
......@@ -394,8 +394,20 @@ class ModelView(BaseModelView):
"""
Return list of enabled filters
"""
join_tables = []
if isinstance(name, basestring):
attr = getattr(self.model, name, None)
model = self.model
for attribute in name.split('.'):
value = getattr(model, attribute)
if (
hasattr(value, 'property')
and hasattr(value.property, 'direction')
):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table):
join_tables.append(table)
attr = value
else:
attr = name
......@@ -423,9 +435,11 @@ class ModelView(BaseModelView):
visible_name)
if flt:
if self._need_join(column.table):
self._filter_joins[column.table.name] = column.table
table = column.table
if join_tables:
self._filter_joins[table.name.name] = join_tables
elif self._need_join(table.name):
self._filter_joins[table.name.name] = [table.name]
filters.extend(flt)
return filters
......@@ -449,9 +463,10 @@ class ModelView(BaseModelView):
if flt:
# If there's relation to other table, do it
if self._need_join(column.table):
self._filter_joins[column.table.name] = column.table
if join_tables:
self._filter_joins[column.table.name] = join_tables
elif self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table]
return flt
def is_valid_filter(self, filter):
......@@ -581,10 +596,10 @@ class ModelView(BaseModelView):
# Figure out join
tbl = flt.column.table.name
join = self._filter_joins.get(tbl)
if join is not None:
query = query.join(join)
joins.add(tbl)
join_tables = self._filter_joins.get(tbl, [])
for table in join_tables:
query = query.join(table)
joins.add(table)
# Apply filter
query = flt.apply(query, value)
......
......@@ -20,22 +20,33 @@ class CustomModelView(ModelView):
def create_models(db):
class Model1(db.Model):
def __init__(self, test1=None, test2=None, test3=None, test4=None):
def __init__(self, test1=None, test2=None, test3=None, test4=None, bool_field=False):
self.test1 = test1
self.test2 = test2
self.test3 = test3
self.test4 = test4
self.bool_field = bool_field
id = db.Column(db.Integer, primary_key=True)
test1 = db.Column(db.String(20))
test2 = db.Column(db.Unicode(20))
test3 = db.Column(db.Text)
test4 = db.Column(db.UnicodeText)
bool_field = db.Column(db.Boolean)
class Model2(db.Model):
def __init__(self, string_field=None, int_field=None, bool_field=None, model1=None):
self.string_field = string_field
self.int_field = int_field
self.bool_field = bool_field
self.model1 = model1
id = db.Column(db.Integer, primary_key=True)
string_field = db.Column(db.String)
int_field = db.Column(db.Integer)
bool_field = db.Column(db.Boolean)
model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
model1 = db.relationship(Model1)
db.create_all()
......@@ -155,7 +166,10 @@ def test_exclude_columns():
column_exclude_list=['test2', 'test4'])
admin.add_view(view)
eq_(view._list_columns, [('test1', 'Test1'), ('test3', 'Test3')])
eq_(
view._list_columns,
[('test1', 'Test1'), ('test3', 'Test3'), ('bool_field', 'Bool Field')]
)
client = app.test_client()
......@@ -202,28 +216,43 @@ def test_column_filters():
eq_(len(view._filters), 4)
eq_(view._filter_dict, {'Test1': [(0, 'equals'),
(1, 'not equal'),
(2, 'contains'),
(3, 'not contains')]})
db.session.add(Model1('model1'))
db.session.add(Model1('model2'))
db.session.add(Model1('model3'))
db.session.add(Model1('model4'))
eq_(view._filter_dict, {
'Test1': [
(0, 'equals'),
(1, 'not equal'),
(2, 'contains'),
(3, 'not contains')
],
})
# Fill DB
model1_obj1 = Model1('model1_obj1', bool_field=True)
model1_obj2 = Model1('model1_obj2')
model1_obj3 = Model1('model1_obj3')
model1_obj4 = Model1('model1_obj4')
model2_obj1 = Model2('model2_obj1', model1=model1_obj1)
model2_obj2 = Model2('model2_obj2', model1=model1_obj1)
model2_obj3 = Model2('model2_obj3')
model2_obj4 = Model2('model2_obj4')
db.session.add_all([
model1_obj1, model1_obj2, model1_obj3, model1_obj4,
model2_obj1, model2_obj2, model2_obj3, model2_obj4,
])
db.session.commit()
client = app.test_client()
rv = client.get('/admin/model1view/?flt0_0=model1')
rv = client.get('/admin/model1view/?flt0_0=model1_obj1')
eq_(rv.status_code, 200)
ok_('model1' in rv.data)
ok_('model2' not in rv.data)
ok_('model1_obj1' in rv.data)
ok_('model1_obj2' not in rv.data)
rv = client.get('/admin/model1view/?flt0_5=model1')
rv = client.get('/admin/model1view/?flt0_5=model1_obj1')
eq_(rv.status_code, 200)
ok_('model1' in rv.data)
ok_('model2' in rv.data)
ok_('model1_obj1' in rv.data)
ok_('model1_obj2' in rv.data)
# Test different filter types
view = CustomModelView(Model2, db.session,
......@@ -234,6 +263,27 @@ def test_column_filters():
(2, 'greater than'), (3, 'smaller than')]})
#Test filters to joined table field
view = CustomModelView(
Model2, db.session,
endpoint='_model2',
column_filters=['model1.bool_field'],
column_list=[
'string_field',
'model1.id',
'model1.bool_field',
]
)
admin.add_view(view)
rv = client.get('/admin/_model2/?flt1_0=1')
eq_(rv.status_code, 200)
ok_('model2_obj1' in rv.data)
ok_('model2_obj2' in rv.data)
ok_('model2_obj3' not in rv.data)
ok_('model2_obj4' not in rv.data)
def test_url_args():
app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment