Commit 5934d31d authored by PJ Janse van Rensburg's avatar PJ Janse van Rensburg

Merge branch 'issue_675' of https://github.com/jbochi/flask-admin into column_defaul_sort

parents cdaa6260 0ece5c84
......@@ -520,7 +520,9 @@ class ModelView(BaseModelView):
order = self._get_default_order()
if order:
query = query.order_by('%s%s' % ('-' if order[1] else '', order[0]))
keys = ['%s%s' % ('-' if desc else '', col)
for (col, desc) in order]
query = query.order_by(*keys)
# Pagination
if page_size is None:
......
......@@ -334,21 +334,24 @@ class ModelView(BaseModelView):
return query
def _order_by(self, query, joins, sort_field, sort_desc):
def _order_by(self, query, joins, order):
clauses = []
for sort_field, sort_desc in order:
query, joins, clause = self._sort_clause(
query, joins, sort_field, sort_desc)
clauses.append(clause)
query = query.order_by(*clauses)
return query, joins
def _sort_clause(self, query, joins, sort_field, sort_desc):
if isinstance(sort_field, string_types):
field = getattr(self.model, sort_field)
query = query.order_by(field.desc() if sort_desc else field.asc())
elif isinstance(sort_field, Field):
try:
if sort_field.model_class != self.model:
query = self._handle_join(query, sort_field, joins)
except AttributeError:
if sort_field.model != self.model:
query = self._handle_join(query, sort_field, joins)
query = query.order_by(sort_field.desc() if sort_desc else sort_field.asc())
return query, joins
if sort_field.model_class != self.model:
query = self._handle_join(query, sort_field, joins)
field = sort_field
clause = field.desc() if sort_desc else field.asc()
return query, joins, clause
def get_query(self):
return self.model.select()
......@@ -417,13 +420,12 @@ class ModelView(BaseModelView):
# Apply sorting
if sort_column is not None:
sort_field = self._sortable_columns[sort_column]
query, joins = self._order_by(query, joins, sort_field, sort_desc)
order = [(sort_field, sort_desc)]
query, joins = self._order_by(query, joins, order)
else:
order = self._get_default_order()
if order:
query, joins = self._order_by(query, joins, order[0], order[1])
query, joins = self._order_by(query, joins, order)
# Pagination
if page_size is None:
......
......@@ -262,7 +262,8 @@ class ModelView(BaseModelView):
order = self._get_default_order()
if order:
sort_by = [(order[0], pymongo.DESCENDING if order[1] else pymongo.ASCENDING)]
sort_by = [(col, pymongo.DESCENDING if desc else pymongo.ASCENDING)
for (col, desc) in order]
# Pagination
if page_size is None:
......
......@@ -850,15 +850,9 @@ class ModelView(BaseModelView):
def _get_default_order(self):
order = super(ModelView, self)._get_default_order()
if order is not None:
field, direction = order
for field, direction in (order or []):
attr, joins = tools.get_field_with_path(self.model, field)
return attr, joins, direction
return None
yield attr, joins, direction
def _apply_sorting(self, query, joins, sort_column, sort_desc):
if sort_column is not None:
......@@ -869,10 +863,7 @@ class ModelView(BaseModelView):
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_field, sort_joins, sort_desc = order
for sort_field, sort_joins, sort_desc in order:
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
return query, joins
......
......@@ -403,6 +403,12 @@ class BaseModelView(BaseView, ActionsMixin):
class MyModelView(BaseModelView):
column_default_sort = ('user', True)
If you want to sort by more than one column,
you can pass a list of tuples::
class MyModelView(BaseModelView):
column_default_sort = [('name', True), ('last_name', True)]
"""
column_searchable_list = ObsoleteAttr('column_searchable_list',
......@@ -1463,10 +1469,12 @@ class BaseModelView(BaseView, ActionsMixin):
Return default sort order
"""
if self.column_default_sort:
if isinstance(self.column_default_sort, tuple):
if isinstance(self.column_default_sort, list):
return self.column_default_sort
if isinstance(self.column_default_sort, tuple):
return [self.column_default_sort]
else:
return self.column_default_sort, False
return [(self.column_default_sort, False)]
return None
......
......@@ -685,9 +685,9 @@ def test_default_sort():
app, db, admin = setup()
M1, _ = create_models(db)
M1(test1='c').save()
M1(test1='b').save()
M1(test1='a').save()
M1(test1='c', test2='x').save()
M1(test1='b', test2='x').save()
M1(test1='a', test2='y').save()
eq_(M1.objects.count(), 3)
......@@ -700,6 +700,18 @@ def test_default_sort():
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
# test default sort with multiple columns
order = [('test2', False), ('test1', False)]
view2 = CustomModelView(M1, column_default_sort=order, endpoint='m1_2')
admin.add_view(view2)
_, data = view2.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'b')
eq_(data[1].test1, 'c')
eq_(data[2].test1, 'a')
def test_extra_fields():
app, db, admin = setup()
......
......@@ -870,8 +870,8 @@ def test_default_sort():
M1, _ = create_models(db)
M1('c', 1).save()
M1('b', 2).save()
M1('a', 3).save()
M1('b', 1).save()
M1('a', 2).save()
eq_(M1.select().count(), 3)
......@@ -884,6 +884,18 @@ def test_default_sort():
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
# test default sort with multiple columns
order = [('test2', False), ('test1', False)]
view2 = CustomModelView(M1, column_default_sort=order, endpoint='m1_2')
admin.add_view(view2)
_, data = view2.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'b')
eq_(data[1].test1, 'c')
eq_(data[2].test1, 'a')
def test_extra_fields():
app, db, admin = setup()
......
......@@ -1676,7 +1676,7 @@ def test_default_sort():
app, db, admin = setup()
M1, _ = create_models(db)
db.session.add_all([M1('c'), M1('b'), M1('a')])
db.session.add_all([M1('c', 'x'), M1('b', 'x'), M1('a', 'y')])
db.session.commit()
eq_(M1.query.count(), 3)
......@@ -1715,6 +1715,18 @@ def test_default_sort():
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
# test default sort with multiple columns
order = [('test2', False), ('test1', False)]
view4 = CustomModelView(M1, db.session, column_default_sort=order, endpoint='m1_4')
admin.add_view(view4)
_, data = view4.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'b')
eq_(data[1].test1, 'c')
eq_(data[2].test1, 'a')
def test_complex_sort():
app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment