Commit b59ba9dc authored by PJ Janse van Rensburg's avatar PJ Janse van Rensburg

Merge branch 'column_defaul_sort'

parents cdaa6260 44bb3d66
...@@ -520,7 +520,9 @@ class ModelView(BaseModelView): ...@@ -520,7 +520,9 @@ class ModelView(BaseModelView):
order = self._get_default_order() order = self._get_default_order()
if order: if order:
query = query.order_by('%s%s' % ('-' if order[1] else '', order[0])) keys = ['%s%s' % ('-' if desc else '', col)
for (col, desc) in order]
query = query.order_by(*keys)
# Pagination # Pagination
if page_size is None: if page_size is None:
......
...@@ -234,24 +234,20 @@ class ModelView(BaseModelView): ...@@ -234,24 +234,20 @@ class ModelView(BaseModelView):
raise Exception('Failed to find field for filter: %s' % name) raise Exception('Failed to find field for filter: %s' % name)
# Check if field is in different model # Check if field is in different model
model_class = None
try: try:
if attr.model_class != self.model: model_class = attr.model_class
visible_name = '%s / %s' % (self.get_column_name(attr.model_class.__name__),
self.get_column_name(attr.name))
else:
if not isinstance(name, string_types):
visible_name = self.get_column_name(attr.name)
else:
visible_name = self.get_column_name(name)
except AttributeError: except AttributeError:
if attr.model != self.model: model_class = attr.model
visible_name = '%s / %s' % (self.get_column_name(attr.model.__name__),
self.get_column_name(attr.name)) if model_class != self.model:
visible_name = '%s / %s' % (self.get_column_name(model_class.__name__),
self.get_column_name(attr.name))
else:
if not isinstance(name, string_types):
visible_name = self.get_column_name(attr.name)
else: else:
if not isinstance(name, string_types): visible_name = self.get_column_name(name)
visible_name = self.get_column_name(attr.name)
else:
visible_name = self.get_column_name(name)
type_name = type(attr).__name__ type_name = type(attr).__name__
flt = self.filter_converter.convert(type_name, flt = self.filter_converter.convert(type_name,
...@@ -317,38 +313,42 @@ class ModelView(BaseModelView): ...@@ -317,38 +313,42 @@ class ModelView(BaseModelView):
return create_ajax_loader(self.model, name, name, options) return create_ajax_loader(self.model, name, name, options)
def _handle_join(self, query, field, joins): def _handle_join(self, query, field, joins):
model_class = None
try: try:
if field.model_class != self.model: model_class = field.model_class
model_name = field.model_class.__name__
if model_name not in joins:
query = query.join(field.model_class, JOIN.LEFT_OUTER)
joins.add(model_name)
except AttributeError: except AttributeError:
if field.model != self.model: model_class = field.model
model_name = field.model.__name__ if model_class != self.model:
model_name = model_class.__name__
if model_name not in joins:
query = query.join(field.model, JOIN.LEFT_OUTER)
joins.add(model_name)
if model_name not in joins:
query = query.join(model_class, JOIN.LEFT_OUTER)
joins.add(model_name)
return query return query
def _order_by(self, query, joins, sort_field, sort_desc): def _order_by(self, query, joins, order):
clauses = []
for sort_field, sort_desc in order:
query, joins, clause = self._sort_clause(
query, joins, sort_field, sort_desc)
clauses.append(clause)
query = query.order_by(*clauses)
return query, joins
def _sort_clause(self, query, joins, sort_field, sort_desc):
if isinstance(sort_field, string_types): if isinstance(sort_field, string_types):
field = getattr(self.model, sort_field) field = getattr(self.model, sort_field)
query = query.order_by(field.desc() if sort_desc else field.asc())
elif isinstance(sort_field, Field): elif isinstance(sort_field, Field):
model_class = None
try: try:
if sort_field.model_class != self.model: model_class = sort_field.model_class
query = self._handle_join(query, sort_field, joins)
except AttributeError: except AttributeError:
if sort_field.model != self.model: model_class = sort_field.model
query = self._handle_join(query, sort_field, joins) if model_class != self.model:
query = self._handle_join(query, sort_field, joins)
query = query.order_by(sort_field.desc() if sort_desc else sort_field.asc()) field = sort_field
clause = field.desc() if sort_desc else field.asc()
return query, joins return query, joins, clause
def get_query(self): def get_query(self):
return self.model.select() return self.model.select()
...@@ -417,13 +417,12 @@ class ModelView(BaseModelView): ...@@ -417,13 +417,12 @@ class ModelView(BaseModelView):
# Apply sorting # Apply sorting
if sort_column is not None: if sort_column is not None:
sort_field = self._sortable_columns[sort_column] sort_field = self._sortable_columns[sort_column]
order = [(sort_field, sort_desc)]
query, joins = self._order_by(query, joins, sort_field, sort_desc) query, joins = self._order_by(query, joins, order)
else: else:
order = self._get_default_order() order = self._get_default_order()
if order: if order:
query, joins = self._order_by(query, joins, order[0], order[1]) query, joins = self._order_by(query, joins, order)
# Pagination # Pagination
if page_size is None: if page_size is None:
......
...@@ -262,7 +262,8 @@ class ModelView(BaseModelView): ...@@ -262,7 +262,8 @@ class ModelView(BaseModelView):
order = self._get_default_order() order = self._get_default_order()
if order: if order:
sort_by = [(order[0], pymongo.DESCENDING if order[1] else pymongo.ASCENDING)] sort_by = [(col, pymongo.DESCENDING if desc else pymongo.ASCENDING)
for (col, desc) in order]
# Pagination # Pagination
if page_size is None: if page_size is None:
......
...@@ -850,15 +850,9 @@ class ModelView(BaseModelView): ...@@ -850,15 +850,9 @@ class ModelView(BaseModelView):
def _get_default_order(self): def _get_default_order(self):
order = super(ModelView, self)._get_default_order() order = super(ModelView, self)._get_default_order()
for field, direction in (order or []):
if order is not None:
field, direction = order
attr, joins = tools.get_field_with_path(self.model, field) attr, joins = tools.get_field_with_path(self.model, field)
yield attr, joins, direction
return attr, joins, direction
return None
def _apply_sorting(self, query, joins, sort_column, sort_desc): def _apply_sorting(self, query, joins, sort_column, sort_desc):
if sort_column is not None: if sort_column is not None:
...@@ -869,10 +863,7 @@ class ModelView(BaseModelView): ...@@ -869,10 +863,7 @@ class ModelView(BaseModelView):
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc) query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else: else:
order = self._get_default_order() order = self._get_default_order()
for sort_field, sort_joins, sort_desc in order:
if order:
sort_field, sort_joins, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc) query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
return query, joins return query, joins
......
...@@ -403,6 +403,12 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -403,6 +403,12 @@ class BaseModelView(BaseView, ActionsMixin):
class MyModelView(BaseModelView): class MyModelView(BaseModelView):
column_default_sort = ('user', True) column_default_sort = ('user', True)
If you want to sort by more than one column,
you can pass a list of tuples::
class MyModelView(BaseModelView):
column_default_sort = [('name', True), ('last_name', True)]
""" """
column_searchable_list = ObsoleteAttr('column_searchable_list', column_searchable_list = ObsoleteAttr('column_searchable_list',
...@@ -1463,10 +1469,12 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1463,10 +1469,12 @@ class BaseModelView(BaseView, ActionsMixin):
Return default sort order Return default sort order
""" """
if self.column_default_sort: if self.column_default_sort:
if isinstance(self.column_default_sort, tuple): if isinstance(self.column_default_sort, list):
return self.column_default_sort return self.column_default_sort
if isinstance(self.column_default_sort, tuple):
return [self.column_default_sort]
else: else:
return self.column_default_sort, False return [(self.column_default_sort, False)]
return None return None
......
...@@ -685,9 +685,9 @@ def test_default_sort(): ...@@ -685,9 +685,9 @@ def test_default_sort():
app, db, admin = setup() app, db, admin = setup()
M1, _ = create_models(db) M1, _ = create_models(db)
M1(test1='c').save() M1(test1='c', test2='x').save()
M1(test1='b').save() M1(test1='b', test2='x').save()
M1(test1='a').save() M1(test1='a', test2='y').save()
eq_(M1.objects.count(), 3) eq_(M1.objects.count(), 3)
...@@ -700,6 +700,18 @@ def test_default_sort(): ...@@ -700,6 +700,18 @@ def test_default_sort():
eq_(data[1].test1, 'b') eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c') eq_(data[2].test1, 'c')
# test default sort with multiple columns
order = [('test2', False), ('test1', False)]
view2 = CustomModelView(M1, column_default_sort=order, endpoint='m1_2')
admin.add_view(view2)
_, data = view2.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'b')
eq_(data[1].test1, 'c')
eq_(data[2].test1, 'a')
def test_extra_fields(): def test_extra_fields():
app, db, admin = setup() app, db, admin = setup()
......
...@@ -870,8 +870,8 @@ def test_default_sort(): ...@@ -870,8 +870,8 @@ def test_default_sort():
M1, _ = create_models(db) M1, _ = create_models(db)
M1('c', 1).save() M1('c', 1).save()
M1('b', 2).save() M1('b', 1).save()
M1('a', 3).save() M1('a', 2).save()
eq_(M1.select().count(), 3) eq_(M1.select().count(), 3)
...@@ -884,6 +884,18 @@ def test_default_sort(): ...@@ -884,6 +884,18 @@ def test_default_sort():
eq_(data[1].test1, 'b') eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c') eq_(data[2].test1, 'c')
# test default sort with multiple columns
order = [('test2', False), ('test1', False)]
view2 = CustomModelView(M1, column_default_sort=order, endpoint='m1_2')
admin.add_view(view2)
_, data = view2.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'b')
eq_(data[1].test1, 'c')
eq_(data[2].test1, 'a')
def test_extra_fields(): def test_extra_fields():
app, db, admin = setup() app, db, admin = setup()
......
...@@ -1676,7 +1676,7 @@ def test_default_sort(): ...@@ -1676,7 +1676,7 @@ def test_default_sort():
app, db, admin = setup() app, db, admin = setup()
M1, _ = create_models(db) M1, _ = create_models(db)
db.session.add_all([M1('c'), M1('b'), M1('a')]) db.session.add_all([M1('c', 'x'), M1('b', 'x'), M1('a', 'y')])
db.session.commit() db.session.commit()
eq_(M1.query.count(), 3) eq_(M1.query.count(), 3)
...@@ -1715,6 +1715,18 @@ def test_default_sort(): ...@@ -1715,6 +1715,18 @@ def test_default_sort():
eq_(data[1].test1, 'b') eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c') eq_(data[2].test1, 'c')
# test default sort with multiple columns
order = [('test2', False), ('test1', False)]
view4 = CustomModelView(M1, db.session, column_default_sort=order, endpoint='m1_4')
admin.add_view(view4)
_, data = view4.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'b')
eq_(data[1].test1, 'c')
eq_(data[2].test1, 'a')
def test_complex_sort(): def test_complex_sort():
app, db, admin = setup() app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment