Commit bd2fb1a3 authored by Serge S. Koval's avatar Serge S. Koval

Merge pull request #778 from pawl/fix_sort_joins

Fix sorting and searching for relation objects
parents fa50542f db2b416e
...@@ -341,6 +341,18 @@ class ModelView(BaseModelView): ...@@ -341,6 +341,18 @@ class ModelView(BaseModelView):
else: else:
attr = name attr = name
# determine joins if Table.column (relation object) is given
if isinstance(name, InstrumentedAttribute):
columns = self._get_columns_for_field(name)
if len(columns) > 1:
raise Exception('Can only handle one column for %s' % name)
column = columns[0]
if self._need_join(column.table):
join_tables.append(column.table)
return join_tables, attr return join_tables, attr
def _need_join(self, table): def _need_join(self, table):
...@@ -441,18 +453,18 @@ class ModelView(BaseModelView): ...@@ -441,18 +453,18 @@ class ModelView(BaseModelView):
for c in self.column_sortable_list: for c in self.column_sortable_list:
if isinstance(c, tuple): if isinstance(c, tuple):
join_tables, column = self._get_field_with_path(c[1]) join_tables, column = self._get_field_with_path(c[1])
column_name = c[0]
result[c[0]] = column elif isinstance(c, InstrumentedAttribute):
join_tables, column = self._get_field_with_path(c)
if join_tables: column_name = str(c)
self._sortable_joins[c[0]] = join_tables
else: else:
join_tables, column = self._get_field_with_path(c) join_tables, column = self._get_field_with_path(c)
column_name = c
result[c] = column result[column_name] = column
if join_tables: if join_tables:
self._sortable_joins[c] = join_tables self._sortable_joins[column_name] = join_tables
return result return result
......
...@@ -1003,10 +1003,12 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1003,10 +1003,12 @@ class BaseModelView(BaseView, ActionsMixin):
""" """
Verify if column is sortable. Verify if column is sortable.
Not case-sensitive.
:param name: :param name:
Column name. Column name.
""" """
return name in self._sortable_columns return name.lower() in (x.lower() for x in self._sortable_columns)
def is_editable(self, name): def is_editable(self, name):
""" """
......
...@@ -285,18 +285,31 @@ def test_complex_searchable_list(): ...@@ -285,18 +285,31 @@ def test_complex_searchable_list():
column_searchable_list=['model1.test1']) column_searchable_list=['model1.test1'])
admin.add_view(view) admin.add_view(view)
m1 = Model1('model1') m1 = Model1('model1-test1-val')
m2 = Model1('model1-test2-val')
db.session.add(m1) db.session.add(m1)
db.session.add(Model2('model2', model1=m1)) db.session.add(m2)
db.session.add(Model2('model3')) db.session.add(Model2('model2-test1-val', model1=m1))
db.session.add(Model2('model2-test2-val', model1=m2))
db.session.commit() db.session.commit()
client = app.test_client() client = app.test_client()
rv = client.get('/admin/model2/?search=model1') # test relation string - 'model1.test1'
rv = client.get('/admin/model2/?search=model1-test1')
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('model1' in data) ok_('model2-test1-val' in data)
ok_('model3' not in data) ok_('model2-test2-val' not in data)
view2 = CustomModelView(Model1, db.session,
column_searchable_list=[Model2.string_field])
admin.add_view(view2)
# test relation object - Model2.string_field
rv = client.get('/admin/model1/?search=model2-test1')
data = rv.data.decode('utf-8')
ok_('model1-test1-val' in data)
ok_('model1-test2-val' not in data)
def test_complex_searchable_list_missing_children(): def test_complex_searchable_list_missing_children():
...@@ -1357,6 +1370,7 @@ def test_complex_sort(): ...@@ -1357,6 +1370,7 @@ def test_complex_sort():
db.session.commit() db.session.commit()
# test sorting on relation string - 'model1.test1'
view = CustomModelView(M2, db.session, view = CustomModelView(M2, db.session,
column_list = ['string_field', 'model1.test1'], column_list = ['string_field', 'model1.test1'],
column_sortable_list = ['model1.test1']) column_sortable_list = ['model1.test1'])
...@@ -1367,6 +1381,19 @@ def test_complex_sort(): ...@@ -1367,6 +1381,19 @@ def test_complex_sort():
rv = client.get('/admin/model2/?sort=1') rv = client.get('/admin/model2/?sort=1')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
# test sorting on relation object - M2.string_field
view2 = CustomModelView(M1, db.session,
column_list = ['model2.string_field'],
column_sortable_list = [M2.string_field])
admin.add_view(view2)
client = app.test_client()
rv = client.get('/admin/model1/?sort=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('Sort by' in data)
def test_default_complex_sort(): def test_default_complex_sort():
app, db, admin = setup() app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment