Commit 78a66cda authored by Serge S. Koval's avatar Serge S. Koval

Fixed #556 Added support for complex sortables

parent 664f4622
......@@ -262,10 +262,12 @@ class ModelView(BaseModelView):
self.session = session
self._search_fields = None
self._search_joins = dict()
self._search_joins = []
self._filter_joins = dict()
self._sortable_joins = dict()
if self.form_choices is None:
self.form_choices = {}
......@@ -293,6 +295,41 @@ class ModelView(BaseModelView):
return model._sa_class_manager.mapper.iterate_properties
def _get_columns_for_field(self, field):
if (not field or
not hasattr(field, 'property') or
not hasattr(field.property, 'columns') or
not field.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
return field.property.columns
def _get_field_with_path(self, name):
join_tables = []
if isinstance(name, string_types):
model = self.model
for attribute in name.split('.'):
value = getattr(model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table):
join_tables.append(table)
attr = value
else:
attr = name
return join_tables, attr
def _need_join(self, table):
return table not in self.model._sa_class_manager.mapper.tables
# Scaffolding
def scaffold_pk(self):
"""
......@@ -370,40 +407,35 @@ class ModelView(BaseModelView):
return columns
def _get_columns_for_field(self, field):
if (not field or
not hasattr(field, 'property') or
not hasattr(field.property, 'columns') or
not field.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
return field.property.columns
def _get_field_with_path(self, name):
join_tables = []
def get_sortable_columns(self):
"""
Returns a dictionary of the sortable columns. Key is a model
field name and value is sort column (for example - attribute).
if isinstance(name, string_types):
model = self.model
If `column_sortable_list` is set, will use it. Otherwise, will call
`scaffold_sortable_columns` to get them from the model.
"""
self._sortable_joins = dict()
for attribute in name.split('.'):
value = getattr(model, attribute)
if self.column_sortable_list is None:
return self.scaffold_sortable_columns()
else:
result = dict()
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
for c in self.column_sortable_list:
if isinstance(c, tuple):
join_tables, column = self._get_field_with_path(c[1])
if self._need_join(table):
join_tables.append(table)
result[c[0]] = column
attr = value
if join_tables:
self._sortable_joins[c[0]] = join_tables
else:
attr = name
join_tables, column = self._get_field_with_path(c)
return join_tables, attr
result[c] = column
def _need_join(self, table):
return table not in self.model._sa_class_manager.mapper.tables
return result
def init_search(self):
"""
......@@ -415,7 +447,9 @@ class ModelView(BaseModelView):
"""
if self.column_searchable_list:
self._search_fields = []
self._search_joins = dict()
self._search_joins = []
joins = set()
for p in self.column_searchable_list:
join_tables, attr = self._get_field_with_path(p)
......@@ -433,9 +467,10 @@ class ModelView(BaseModelView):
self._search_fields.append(column)
# Store joins, avoid duplicates
if join_tables:
for table in join_tables:
self._search_joins[table.name] = table
if table.name not in joins:
self._search_joins.append(table)
joins.add(table.name)
return bool(self.column_searchable_list)
......@@ -621,7 +656,7 @@ class ModelView(BaseModelView):
"""
return self.session.query(func.count('*')).select_from(self.model)
def _order_by(self, query, joins, sort_field, sort_desc):
def _order_by(self, query, joins, sort_joins, sort_field, sort_desc):
"""
Apply order_by to the query
......@@ -635,33 +670,13 @@ class ModelView(BaseModelView):
Ascending or descending
"""
# TODO: Preprocessing for joins
# Try to handle it as a string
if isinstance(sort_field, string_types):
# Create automatic join against a table if column name
# contains dot.
if '.' in sort_field:
parts = sort_field.split('.', 1)
if parts[0] not in joins:
query = query.join(parts[0])
joins.add(parts[0])
elif isinstance(sort_field, InstrumentedAttribute):
# SQLAlchemy 0.8+ uses 'parent' as a name
mapper = getattr(sort_field, 'parent', None)
if mapper is None:
# SQLAlchemy 0.7.x uses parententity
mapper = getattr(sort_field, 'parententity', None)
if mapper is not None:
table = mapper.tables[0]
if self._need_join(table) and table.name not in joins:
# Handle joins
if sort_joins:
for table in sort_joins:
if table.name not in joins:
query = query.outerjoin(table)
joins.add(table.name)
elif isinstance(sort_field, Column):
pass
else:
raise TypeError('Wrong argument type')
if sort_field is not None:
if sort_desc:
......@@ -677,10 +692,9 @@ class ModelView(BaseModelView):
if order is not None:
field, direction = order
if isinstance(field, string_types):
field = getattr(self.model, field)
join_tables, attr = self._get_field_with_path(field)
return field, direction
return join_tables, field, direction
return None
......@@ -712,11 +726,11 @@ class ModelView(BaseModelView):
if self._search_supported and search:
# Apply search-related joins
if self._search_joins:
for jn in self._search_joins.values():
query = query.join(jn)
count_query = count_query.join(jn)
for table in self._search_joins:
query = query.join(table)
count_query = count_query.join(table)
joins = set(self._search_joins.keys())
joins.add(table.name)
# Apply terms
terms = search.split(' ')
......@@ -761,13 +775,16 @@ class ModelView(BaseModelView):
if sort_column is not None:
if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_field, sort_desc)
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
query, joins = self._order_by(query, joins, order[0], order[1])
sort_joins, sort_field, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
# Pagination
if page is not None:
......
......@@ -660,6 +660,29 @@ def test_default_sort():
eq_(data[2].test1, 'c')
def test_default_complex_sort():
app, db, admin = setup()
M1, M2 = create_models(db)
m1 = M1('b')
db.session.add(m1)
db.session.add(M2('c', model1=m1))
m2 = M1('a')
db.session.add(m2)
db.session.add(M2('c', model1=m2))
db.session.commit()
view = CustomModelView(M2, db.session, column_default_sort='model1.test1')
admin.add_view(view)
_, data = view.get_list(0, None, None, None, None)
eq_(len(data), 2)
eq_(data[0].model1.test1, 'a')
eq_(data[1].model1.test1, 'b')
def test_extra_fields():
app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment