Commit 78a66cda authored by Serge S. Koval's avatar Serge S. Koval

Fixed #556 Added support for complex sortables

parent 664f4622
...@@ -262,10 +262,12 @@ class ModelView(BaseModelView): ...@@ -262,10 +262,12 @@ class ModelView(BaseModelView):
self.session = session self.session = session
self._search_fields = None self._search_fields = None
self._search_joins = dict() self._search_joins = []
self._filter_joins = dict() self._filter_joins = dict()
self._sortable_joins = dict()
if self.form_choices is None: if self.form_choices is None:
self.form_choices = {} self.form_choices = {}
...@@ -293,6 +295,41 @@ class ModelView(BaseModelView): ...@@ -293,6 +295,41 @@ class ModelView(BaseModelView):
return model._sa_class_manager.mapper.iterate_properties return model._sa_class_manager.mapper.iterate_properties
def _get_columns_for_field(self, field):
if (not field or
not hasattr(field, 'property') or
not hasattr(field.property, 'columns') or
not field.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
return field.property.columns
def _get_field_with_path(self, name):
join_tables = []
if isinstance(name, string_types):
model = self.model
for attribute in name.split('.'):
value = getattr(model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table):
join_tables.append(table)
attr = value
else:
attr = name
return join_tables, attr
def _need_join(self, table):
return table not in self.model._sa_class_manager.mapper.tables
# Scaffolding # Scaffolding
def scaffold_pk(self): def scaffold_pk(self):
""" """
...@@ -370,40 +407,35 @@ class ModelView(BaseModelView): ...@@ -370,40 +407,35 @@ class ModelView(BaseModelView):
return columns return columns
def _get_columns_for_field(self, field): def get_sortable_columns(self):
if (not field or """
not hasattr(field, 'property') or Returns a dictionary of the sortable columns. Key is a model
not hasattr(field.property, 'columns') or field name and value is sort column (for example - attribute).
not field.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
return field.property.columns
def _get_field_with_path(self, name): If `column_sortable_list` is set, will use it. Otherwise, will call
join_tables = [] `scaffold_sortable_columns` to get them from the model.
"""
self._sortable_joins = dict()
if isinstance(name, string_types): if self.column_sortable_list is None:
model = self.model return self.scaffold_sortable_columns()
else:
result = dict()
for attribute in name.split('.'): for c in self.column_sortable_list:
value = getattr(model, attribute) if isinstance(c, tuple):
join_tables, column = self._get_field_with_path(c[1])
if (hasattr(value, 'property') and result[c[0]] = column
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table): if join_tables:
join_tables.append(table) self._sortable_joins[c[0]] = join_tables
else:
attr = value join_tables, column = self._get_field_with_path(c)
else:
attr = name
return join_tables, attr result[c] = column
def _need_join(self, table): return result
return table not in self.model._sa_class_manager.mapper.tables
def init_search(self): def init_search(self):
""" """
...@@ -415,7 +447,9 @@ class ModelView(BaseModelView): ...@@ -415,7 +447,9 @@ class ModelView(BaseModelView):
""" """
if self.column_searchable_list: if self.column_searchable_list:
self._search_fields = [] self._search_fields = []
self._search_joins = dict() self._search_joins = []
joins = set()
for p in self.column_searchable_list: for p in self.column_searchable_list:
join_tables, attr = self._get_field_with_path(p) join_tables, attr = self._get_field_with_path(p)
...@@ -433,9 +467,10 @@ class ModelView(BaseModelView): ...@@ -433,9 +467,10 @@ class ModelView(BaseModelView):
self._search_fields.append(column) self._search_fields.append(column)
# Store joins, avoid duplicates # Store joins, avoid duplicates
if join_tables: for table in join_tables:
for table in join_tables: if table.name not in joins:
self._search_joins[table.name] = table self._search_joins.append(table)
joins.add(table.name)
return bool(self.column_searchable_list) return bool(self.column_searchable_list)
...@@ -621,7 +656,7 @@ class ModelView(BaseModelView): ...@@ -621,7 +656,7 @@ class ModelView(BaseModelView):
""" """
return self.session.query(func.count('*')).select_from(self.model) return self.session.query(func.count('*')).select_from(self.model)
def _order_by(self, query, joins, sort_field, sort_desc): def _order_by(self, query, joins, sort_joins, sort_field, sort_desc):
""" """
Apply order_by to the query Apply order_by to the query
...@@ -635,33 +670,13 @@ class ModelView(BaseModelView): ...@@ -635,33 +670,13 @@ class ModelView(BaseModelView):
Ascending or descending Ascending or descending
""" """
# TODO: Preprocessing for joins # TODO: Preprocessing for joins
# Try to handle it as a string # Handle joins
if isinstance(sort_field, string_types): if sort_joins:
# Create automatic join against a table if column name for table in sort_joins:
# contains dot. if table.name not in joins:
if '.' in sort_field:
parts = sort_field.split('.', 1)
if parts[0] not in joins:
query = query.join(parts[0])
joins.add(parts[0])
elif isinstance(sort_field, InstrumentedAttribute):
# SQLAlchemy 0.8+ uses 'parent' as a name
mapper = getattr(sort_field, 'parent', None)
if mapper is None:
# SQLAlchemy 0.7.x uses parententity
mapper = getattr(sort_field, 'parententity', None)
if mapper is not None:
table = mapper.tables[0]
if self._need_join(table) and table.name not in joins:
query = query.outerjoin(table) query = query.outerjoin(table)
joins.add(table.name) joins.add(table.name)
elif isinstance(sort_field, Column):
pass
else:
raise TypeError('Wrong argument type')
if sort_field is not None: if sort_field is not None:
if sort_desc: if sort_desc:
...@@ -677,10 +692,9 @@ class ModelView(BaseModelView): ...@@ -677,10 +692,9 @@ class ModelView(BaseModelView):
if order is not None: if order is not None:
field, direction = order field, direction = order
if isinstance(field, string_types): join_tables, attr = self._get_field_with_path(field)
field = getattr(self.model, field)
return field, direction return join_tables, field, direction
return None return None
...@@ -712,11 +726,11 @@ class ModelView(BaseModelView): ...@@ -712,11 +726,11 @@ class ModelView(BaseModelView):
if self._search_supported and search: if self._search_supported and search:
# Apply search-related joins # Apply search-related joins
if self._search_joins: if self._search_joins:
for jn in self._search_joins.values(): for table in self._search_joins:
query = query.join(jn) query = query.join(table)
count_query = count_query.join(jn) count_query = count_query.join(table)
joins = set(self._search_joins.keys()) joins.add(table.name)
# Apply terms # Apply terms
terms = search.split(' ') terms = search.split(' ')
...@@ -761,13 +775,16 @@ class ModelView(BaseModelView): ...@@ -761,13 +775,16 @@ class ModelView(BaseModelView):
if sort_column is not None: if sort_column is not None:
if sort_column in self._sortable_columns: if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column] sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_field, sort_desc) query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else: else:
order = self._get_default_order() order = self._get_default_order()
if order: if order:
query, joins = self._order_by(query, joins, order[0], order[1]) sort_joins, sort_field, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
# Pagination # Pagination
if page is not None: if page is not None:
......
...@@ -660,6 +660,29 @@ def test_default_sort(): ...@@ -660,6 +660,29 @@ def test_default_sort():
eq_(data[2].test1, 'c') eq_(data[2].test1, 'c')
def test_default_complex_sort():
app, db, admin = setup()
M1, M2 = create_models(db)
m1 = M1('b')
db.session.add(m1)
db.session.add(M2('c', model1=m1))
m2 = M1('a')
db.session.add(m2)
db.session.add(M2('c', model1=m2))
db.session.commit()
view = CustomModelView(M2, db.session, column_default_sort='model1.test1')
admin.add_view(view)
_, data = view.get_list(0, None, None, None, None)
eq_(len(data), 2)
eq_(data[0].model1.test1, 'a')
eq_(data[1].model1.test1, 'b')
def test_extra_fields(): def test_extra_fields():
app, db, admin = setup() app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment