Commit c8cb7347 authored by Serge S. Koval's avatar Serge S. Koval

Minor code refactoring and fixes.

parent 815fe815
...@@ -83,7 +83,7 @@ class AdminModelConverter(ModelConverter): ...@@ -83,7 +83,7 @@ class AdminModelConverter(ModelConverter):
'query_factory': lambda: self.view.session.query(remote_model) 'query_factory': lambda: self.view.session.query(remote_model)
}) })
if local_column.nullable or prop.direction.name == 'MANYTOMANY': if local_column.nullable:
kwargs['validators'].append(validators.Optional()) kwargs['validators'].append(validators.Optional())
elif prop.direction.name != 'MANYTOMANY': elif prop.direction.name != 'MANYTOMANY':
kwargs['validators'].append(validators.Required()) kwargs['validators'].append(validators.Required())
......
...@@ -134,9 +134,9 @@ class ModelView(BaseModelView): ...@@ -134,9 +134,9 @@ class ModelView(BaseModelView):
self.session = session self.session = session
self._search_fields = None self._search_fields = None
self._search_joins_names = set() self._search_joins = dict()
self._filter_joins_names = set() self._filter_joins = dict()
super(ModelView, self).__init__(model, name, category, endpoint, url) super(ModelView, self).__init__(model, name, category, endpoint, url)
...@@ -260,7 +260,7 @@ class ModelView(BaseModelView): ...@@ -260,7 +260,7 @@ class ModelView(BaseModelView):
""" """
if self.searchable_columns: if self.searchable_columns:
self._search_fields = [] self._search_fields = []
self._search_joins = set() self._search_joins = dict()
for p in self.searchable_columns: for p in self.searchable_columns:
for column in self._get_columns_for_field(p): for column in self._get_columns_for_field(p):
...@@ -274,7 +274,7 @@ class ModelView(BaseModelView): ...@@ -274,7 +274,7 @@ class ModelView(BaseModelView):
# If it belongs to different table - add a join # If it belongs to different table - add a join
if column.table != self.model.__table__: if column.table != self.model.__table__:
self._search_joins.add(column.table) self._search_joins[column.table.name] = column.table
return bool(self.searchable_columns) return bool(self.searchable_columns)
...@@ -316,7 +316,7 @@ class ModelView(BaseModelView): ...@@ -316,7 +316,7 @@ class ModelView(BaseModelView):
visible_name) visible_name)
if flt: if flt:
self._filter_joins_names.add(column.table.name) self._filter_joins[column.table.name] = column.table
filters.extend(flt) filters.extend(flt)
return filters return filters
...@@ -341,7 +341,7 @@ class ModelView(BaseModelView): ...@@ -341,7 +341,7 @@ class ModelView(BaseModelView):
if flt: if flt:
# If there's relation to other table, do it # If there's relation to other table, do it
if column.table != self.model.__table__: if column.table != self.model.__table__:
self._filter_joins_names.add(column.table.name) self._filter_joins[column.table.name] = column.table
return flt return flt
...@@ -414,8 +414,8 @@ class ModelView(BaseModelView): ...@@ -414,8 +414,8 @@ class ModelView(BaseModelView):
if self._search_supported and search: if self._search_supported and search:
# Apply search-related joins # Apply search-related joins
if self._search_joins: if self._search_joins:
query = query.join(*self._search_joins) query = query.join(*self._search_joins.values())
joins |= self._search_joins joins = set(self._search_joins.keys())
# Apply terms # Apply terms
terms = search.split(' ') terms = search.split(' ')
...@@ -431,12 +431,12 @@ class ModelView(BaseModelView): ...@@ -431,12 +431,12 @@ class ModelView(BaseModelView):
# Apply filters # Apply filters
if self._filters: if self._filters:
# Apply search-related joins # Apply search-related joins
if self._filter_joins_names: if self._filter_joins:
new_joins = self._filter_joins_names - joins new_joins = set(self._filter_joins.keys()) - joins
if new_joins: if new_joins:
query = query.join(*new_joins) query = query.join(*[self._filter_joins[jn] for jn in new_joins])
joins |= self._search_joins_names joins |= new_joins
# Apply filters # Apply filters
for flt, value in filters: for flt, value in filters:
......
...@@ -70,10 +70,12 @@ def test_model(): ...@@ -70,10 +70,12 @@ def test_model():
eq_(view.endpoint, 'model1view') eq_(view.endpoint, 'model1view')
eq_(view._primary_key, 'id') eq_(view._primary_key, 'id')
eq_(view._sortable_columns, dict(test1='test1',
test2='test2', ok_('test1' in view._sortable_columns)
test3='test3', ok_('test2' in view._sortable_columns)
test4='test4')) ok_('test3' in view._sortable_columns)
ok_('test4' in view._sortable_columns)
ok_(view._create_form_class is not None) ok_(view._create_form_class is not None)
ok_(view._edit_form_class is not None) ok_(view._edit_form_class is not None)
eq_(view._search_supported, False) eq_(view._search_supported, False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment