Commit c8cb7347 authored by Serge S. Koval's avatar Serge S. Koval

Minor code refactoring and fixes.

parent 815fe815
......@@ -83,7 +83,7 @@ class AdminModelConverter(ModelConverter):
'query_factory': lambda: self.view.session.query(remote_model)
})
if local_column.nullable or prop.direction.name == 'MANYTOMANY':
if local_column.nullable:
kwargs['validators'].append(validators.Optional())
elif prop.direction.name != 'MANYTOMANY':
kwargs['validators'].append(validators.Required())
......
......@@ -134,9 +134,9 @@ class ModelView(BaseModelView):
self.session = session
self._search_fields = None
self._search_joins_names = set()
self._search_joins = dict()
self._filter_joins_names = set()
self._filter_joins = dict()
super(ModelView, self).__init__(model, name, category, endpoint, url)
......@@ -260,7 +260,7 @@ class ModelView(BaseModelView):
"""
if self.searchable_columns:
self._search_fields = []
self._search_joins = set()
self._search_joins = dict()
for p in self.searchable_columns:
for column in self._get_columns_for_field(p):
......@@ -274,7 +274,7 @@ class ModelView(BaseModelView):
# If it belongs to different table - add a join
if column.table != self.model.__table__:
self._search_joins.add(column.table)
self._search_joins[column.table.name] = column.table
return bool(self.searchable_columns)
......@@ -316,7 +316,7 @@ class ModelView(BaseModelView):
visible_name)
if flt:
self._filter_joins_names.add(column.table.name)
self._filter_joins[column.table.name] = column.table
filters.extend(flt)
return filters
......@@ -341,7 +341,7 @@ class ModelView(BaseModelView):
if flt:
# If there's relation to other table, do it
if column.table != self.model.__table__:
self._filter_joins_names.add(column.table.name)
self._filter_joins[column.table.name] = column.table
return flt
......@@ -414,8 +414,8 @@ class ModelView(BaseModelView):
if self._search_supported and search:
# Apply search-related joins
if self._search_joins:
query = query.join(*self._search_joins)
joins |= self._search_joins
query = query.join(*self._search_joins.values())
joins = set(self._search_joins.keys())
# Apply terms
terms = search.split(' ')
......@@ -431,12 +431,12 @@ class ModelView(BaseModelView):
# Apply filters
if self._filters:
# Apply search-related joins
if self._filter_joins_names:
new_joins = self._filter_joins_names - joins
if self._filter_joins:
new_joins = set(self._filter_joins.keys()) - joins
if new_joins:
query = query.join(*new_joins)
joins |= self._search_joins_names
query = query.join(*[self._filter_joins[jn] for jn in new_joins])
joins |= new_joins
# Apply filters
for flt, value in filters:
......
......@@ -70,10 +70,12 @@ def test_model():
eq_(view.endpoint, 'model1view')
eq_(view._primary_key, 'id')
eq_(view._sortable_columns, dict(test1='test1',
test2='test2',
test3='test3',
test4='test4'))
ok_('test1' in view._sortable_columns)
ok_('test2' in view._sortable_columns)
ok_('test3' in view._sortable_columns)
ok_('test4' in view._sortable_columns)
ok_(view._create_form_class is not None)
ok_(view._edit_form_class is not None)
eq_(view._search_supported, False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment