Commit b59ba9dc authored by PJ Janse van Rensburg's avatar PJ Janse van Rensburg

Merge branch 'column_defaul_sort'

parents cdaa6260 44bb3d66
......@@ -520,7 +520,9 @@ class ModelView(BaseModelView):
order = self._get_default_order()
if order:
query = query.order_by('%s%s' % ('-' if order[1] else '', order[0]))
keys = ['%s%s' % ('-' if desc else '', col)
for (col, desc) in order]
query = query.order_by(*keys)
# Pagination
if page_size is None:
......
......@@ -234,24 +234,20 @@ class ModelView(BaseModelView):
raise Exception('Failed to find field for filter: %s' % name)
# Check if field is in different model
model_class = None
try:
if attr.model_class != self.model:
visible_name = '%s / %s' % (self.get_column_name(attr.model_class.__name__),
self.get_column_name(attr.name))
else:
if not isinstance(name, string_types):
visible_name = self.get_column_name(attr.name)
else:
visible_name = self.get_column_name(name)
model_class = attr.model_class
except AttributeError:
if attr.model != self.model:
visible_name = '%s / %s' % (self.get_column_name(attr.model.__name__),
self.get_column_name(attr.name))
model_class = attr.model
if model_class != self.model:
visible_name = '%s / %s' % (self.get_column_name(model_class.__name__),
self.get_column_name(attr.name))
else:
if not isinstance(name, string_types):
visible_name = self.get_column_name(attr.name)
else:
if not isinstance(name, string_types):
visible_name = self.get_column_name(attr.name)
else:
visible_name = self.get_column_name(name)
visible_name = self.get_column_name(name)
type_name = type(attr).__name__
flt = self.filter_converter.convert(type_name,
......@@ -317,38 +313,42 @@ class ModelView(BaseModelView):
return create_ajax_loader(self.model, name, name, options)
def _handle_join(self, query, field, joins):
model_class = None
try:
if field.model_class != self.model:
model_name = field.model_class.__name__
if model_name not in joins:
query = query.join(field.model_class, JOIN.LEFT_OUTER)
joins.add(model_name)
model_class = field.model_class
except AttributeError:
if field.model != self.model:
model_name = field.model.__name__
if model_name not in joins:
query = query.join(field.model, JOIN.LEFT_OUTER)
joins.add(model_name)
model_class = field.model
if model_class != self.model:
model_name = model_class.__name__
if model_name not in joins:
query = query.join(model_class, JOIN.LEFT_OUTER)
joins.add(model_name)
return query
def _order_by(self, query, joins, sort_field, sort_desc):
def _order_by(self, query, joins, order):
clauses = []
for sort_field, sort_desc in order:
query, joins, clause = self._sort_clause(
query, joins, sort_field, sort_desc)
clauses.append(clause)
query = query.order_by(*clauses)
return query, joins
def _sort_clause(self, query, joins, sort_field, sort_desc):
if isinstance(sort_field, string_types):
field = getattr(self.model, sort_field)
query = query.order_by(field.desc() if sort_desc else field.asc())
elif isinstance(sort_field, Field):
model_class = None
try:
if sort_field.model_class != self.model:
query = self._handle_join(query, sort_field, joins)
model_class = sort_field.model_class
except AttributeError:
if sort_field.model != self.model:
query = self._handle_join(query, sort_field, joins)
query = query.order_by(sort_field.desc() if sort_desc else sort_field.asc())
return query, joins
model_class = sort_field.model
if model_class != self.model:
query = self._handle_join(query, sort_field, joins)
field = sort_field
clause = field.desc() if sort_desc else field.asc()
return query, joins, clause
def get_query(self):
return self.model.select()
......@@ -417,13 +417,12 @@ class ModelView(BaseModelView):
# Apply sorting
if sort_column is not None:
sort_field = self._sortable_columns[sort_column]
query, joins = self._order_by(query, joins, sort_field, sort_desc)
order = [(sort_field, sort_desc)]
query, joins = self._order_by(query, joins, order)
else:
order = self._get_default_order()
if order:
query, joins = self._order_by(query, joins, order[0], order[1])
query, joins = self._order_by(query, joins, order)
# Pagination
if page_size is None:
......
......@@ -262,7 +262,8 @@ class ModelView(BaseModelView):
order = self._get_default_order()
if order:
sort_by = [(order[0], pymongo.DESCENDING if order[1] else pymongo.ASCENDING)]
sort_by = [(col, pymongo.DESCENDING if desc else pymongo.ASCENDING)
for (col, desc) in order]
# Pagination
if page_size is None:
......
......@@ -850,15 +850,9 @@ class ModelView(BaseModelView):
def _get_default_order(self):
order = super(ModelView, self)._get_default_order()
if order is not None:
field, direction = order
for field, direction in (order or []):
attr, joins = tools.get_field_with_path(self.model, field)
return attr, joins, direction
return None
yield attr, joins, direction
def _apply_sorting(self, query, joins, sort_column, sort_desc):
if sort_column is not None:
......@@ -869,10 +863,7 @@ class ModelView(BaseModelView):
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_field, sort_joins, sort_desc = order
for sort_field, sort_joins, sort_desc in order:
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
return query, joins
......
......@@ -403,6 +403,12 @@ class BaseModelView(BaseView, ActionsMixin):
class MyModelView(BaseModelView):
column_default_sort = ('user', True)
If you want to sort by more than one column,
you can pass a list of tuples::
class MyModelView(BaseModelView):
column_default_sort = [('name', True), ('last_name', True)]
"""
column_searchable_list = ObsoleteAttr('column_searchable_list',
......@@ -1463,10 +1469,12 @@ class BaseModelView(BaseView, ActionsMixin):
Return default sort order
"""
if self.column_default_sort:
if isinstance(self.column_default_sort, tuple):
if isinstance(self.column_default_sort, list):
return self.column_default_sort
if isinstance(self.column_default_sort, tuple):
return [self.column_default_sort]
else:
return self.column_default_sort, False
return [(self.column_default_sort, False)]
return None
......
......@@ -685,9 +685,9 @@ def test_default_sort():
app, db, admin = setup()
M1, _ = create_models(db)
M1(test1='c').save()
M1(test1='b').save()
M1(test1='a').save()
M1(test1='c', test2='x').save()
M1(test1='b', test2='x').save()
M1(test1='a', test2='y').save()
eq_(M1.objects.count(), 3)
......@@ -700,6 +700,18 @@ def test_default_sort():
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
# test default sort with multiple columns
order = [('test2', False), ('test1', False)]
view2 = CustomModelView(M1, column_default_sort=order, endpoint='m1_2')
admin.add_view(view2)
_, data = view2.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'b')
eq_(data[1].test1, 'c')
eq_(data[2].test1, 'a')
def test_extra_fields():
app, db, admin = setup()
......
......@@ -870,8 +870,8 @@ def test_default_sort():
M1, _ = create_models(db)
M1('c', 1).save()
M1('b', 2).save()
M1('a', 3).save()
M1('b', 1).save()
M1('a', 2).save()
eq_(M1.select().count(), 3)
......@@ -884,6 +884,18 @@ def test_default_sort():
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
# test default sort with multiple columns
order = [('test2', False), ('test1', False)]
view2 = CustomModelView(M1, column_default_sort=order, endpoint='m1_2')
admin.add_view(view2)
_, data = view2.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'b')
eq_(data[1].test1, 'c')
eq_(data[2].test1, 'a')
def test_extra_fields():
app, db, admin = setup()
......
......@@ -1676,7 +1676,7 @@ def test_default_sort():
app, db, admin = setup()
M1, _ = create_models(db)
db.session.add_all([M1('c'), M1('b'), M1('a')])
db.session.add_all([M1('c', 'x'), M1('b', 'x'), M1('a', 'y')])
db.session.commit()
eq_(M1.query.count(), 3)
......@@ -1715,6 +1715,18 @@ def test_default_sort():
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
# test default sort with multiple columns
order = [('test2', False), ('test1', False)]
view4 = CustomModelView(M1, db.session, column_default_sort=order, endpoint='m1_4')
admin.add_view(view4)
_, data = view4.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'b')
eq_(data[1].test1, 'c')
eq_(data[2].test1, 'a')
def test_complex_sort():
app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment