Commit 904178eb authored by Serge S. Koval's avatar Serge S. Koval

#139 fixes. Additional unit tests

parent 76124b50
...@@ -394,19 +394,21 @@ class ModelView(BaseModelView): ...@@ -394,19 +394,21 @@ class ModelView(BaseModelView):
""" """
Return list of enabled filters Return list of enabled filters
""" """
join_tables = [] join_tables = []
if isinstance(name, basestring): if isinstance(name, basestring):
model = self.model model = self.model
for attribute in name.split('.'): for attribute in name.split('.'):
value = getattr(model, attribute) value = getattr(model, attribute)
if ( if (hasattr(value, 'property') and
hasattr(value, 'property') hasattr(value.property, 'direction')):
and hasattr(value.property, 'direction')
):
model = value.property.mapper.class_ model = value.property.mapper.class_
table = model.__table__ table = model.__table__
if self._need_join(table): if self._need_join(table):
join_tables.append(table) join_tables.append(table)
attr = value attr = value
else: else:
attr = name attr = name
...@@ -436,10 +438,11 @@ class ModelView(BaseModelView): ...@@ -436,10 +438,11 @@ class ModelView(BaseModelView):
if flt: if flt:
table = column.table table = column.table
if join_tables: if join_tables:
self._filter_joins[table.name.name] = join_tables self._filter_joins[table.name] = join_tables
elif self._need_join(table.name): elif self._need_join(table.name):
self._filter_joins[table.name.name] = [table.name] self._filter_joins[table.name] = [table.name]
filters.extend(flt) filters.extend(flt)
return filters return filters
...@@ -451,10 +454,14 @@ class ModelView(BaseModelView): ...@@ -451,10 +454,14 @@ class ModelView(BaseModelView):
column = columns[0] column = columns[0]
if not isinstance(name, basestring): if self._need_join(column.table):
visible_name = self.get_column_name(name.property.key) visible_name = '%s / %s' % (self.get_column_name(column.table.name),
self.get_column_name(column.name))
else: else:
visible_name = self.get_column_name(name) if not isinstance(name, basestring):
visible_name = self.get_column_name(name.property.key)
else:
visible_name = self.get_column_name(name)
type_name = type(column.type).__name__ type_name = type(column.type).__name__
flt = self.filter_converter.convert(type_name, flt = self.filter_converter.convert(type_name,
...@@ -467,6 +474,7 @@ class ModelView(BaseModelView): ...@@ -467,6 +474,7 @@ class ModelView(BaseModelView):
self._filter_joins[column.table.name] = join_tables self._filter_joins[column.table.name] = join_tables
elif self._need_join(column.table): elif self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table] self._filter_joins[column.table.name] = [column.table]
return flt return flt
def is_valid_filter(self, filter): def is_valid_filter(self, filter):
...@@ -594,12 +602,15 @@ class ModelView(BaseModelView): ...@@ -594,12 +602,15 @@ class ModelView(BaseModelView):
for idx, value in filters: for idx, value in filters:
flt = self._filters[idx] flt = self._filters[idx]
# Figure out join # Figure out joins
tbl = flt.column.table.name tbl = flt.column.table.name
join_tables = self._filter_joins.get(tbl, []) join_tables = self._filter_joins.get(tbl, [])
for table in join_tables: for table in join_tables:
query = query.join(table) if table.name not in joins:
joins.add(table) query = query.join(table)
joins.add(table)
# Apply filter # Apply filter
query = flt.apply(query, value) query = flt.apply(query, value)
...@@ -616,6 +627,7 @@ class ModelView(BaseModelView): ...@@ -616,6 +627,7 @@ class ModelView(BaseModelView):
if sort_column in self._sortable_columns: if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column] sort_field = self._sortable_columns[sort_column]
# TODO: Preprocessing for joins
# Try to handle it as a string # Try to handle it as a string
if isinstance(sort_field, basestring): if isinstance(sort_field, basestring):
# Create automatic join against a table if column name # Create automatic join against a table if column name
......
...@@ -29,7 +29,10 @@ class BaseFilter(object): ...@@ -29,7 +29,10 @@ class BaseFilter(object):
:param view: :param view:
Associated administrative view class. Associated administrative view class.
""" """
return self.options if self.options:
return [(v, unicode(n)) for v, n in self.options]
return None
def validate(self, value): def validate(self, value):
""" """
......
...@@ -94,8 +94,8 @@ def test_model(): ...@@ -94,8 +94,8 @@ def test_model():
model = Model1.select().get() model = Model1.select().get()
eq_(model.test1, 'test1large') eq_(model.test1, 'test1large')
eq_(model.test2, 'test2') eq_(model.test2, 'test2')
eq_(model.test3, None) ok_(model.test3 is None or model.test3 == '')
eq_(model.test4, None) ok_(model.test4 is None or model.test4 == '')
rv = client.get('/admin/model1view/') rv = client.get('/admin/model1view/')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
...@@ -112,8 +112,8 @@ def test_model(): ...@@ -112,8 +112,8 @@ def test_model():
model = Model1.select().get() model = Model1.select().get()
eq_(model.test1, 'test1small') eq_(model.test1, 'test1small')
eq_(model.test2, 'test2large') eq_(model.test2, 'test2large')
eq_(model.test3, None) ok_(model.test3 is None or model.test3 == '')
eq_(model.test4, None) ok_(model.test4 is None or model.test4 == '')
url = '/admin/model1view/delete/?id=%s' % model.id url = '/admin/model1view/delete/?id=%s' % model.id
rv = client.post(url) rv = client.post(url)
......
...@@ -223,9 +223,53 @@ def test_column_filters(): ...@@ -223,9 +223,53 @@ def test_column_filters():
(2, 'contains'), (2, 'contains'),
(3, 'not contains') (3, 'not contains')
], ],
})
# Test filter that references property
view = CustomModelView(Model2, db.session,
column_filters=['model1'])
eq_(view._filter_dict, {
'Model1 / Test1': [
(0, 'equals'),
(1, 'not equal'),
(2, 'contains'),
(3, 'not contains')
],
'Model1 / Test2': [
(4, 'equals'),
(5, 'not equal'),
(6, 'contains'),
(7, 'not contains')
],
'Model1 / Test3': [
(8, 'equals'),
(9, 'not equal'),
(10, 'contains'),
(11, 'not contains')
],
'Model1 / Test4': [
(12, 'equals'),
(13, 'not equal'),
(14, 'contains'),
(15, 'not contains')
],
'Model1 / Bool Field': [
(16, 'equals'),
(17, 'not equal'),
],
}) })
# Test filter with a dot
view = CustomModelView(Model2, db.session,
column_filters=['model1.bool_field'])
eq_(view._filter_dict, {
'Model1 / Bool Field': [
(0, 'equals'),
(1, 'not equal'),
],
})
# Fill DB # Fill DB
model1_obj1 = Model1('model1_obj1', bool_field=True) model1_obj1 = Model1('model1_obj1', bool_field=True)
model1_obj2 = Model1('model1_obj2') model1_obj2 = Model1('model1_obj2')
...@@ -429,6 +473,7 @@ def test_on_model_change_delete(): ...@@ -429,6 +473,7 @@ def test_on_model_change_delete():
client.post(url) client.post(url)
ok_(view.deleted) ok_(view.deleted)
def test_multiple_delete(): def test_multiple_delete():
app, db, admin = setup() app, db, admin = setup()
M1, _ = create_models(db) M1, _ = create_models(db)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment