Commit 32aea3ca authored by Serge S. Koval's avatar Serge S. Koval

Support default sorting order

parent 42f74231
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
{% endblock %} {% endblock %}
{% block body %} {% block body %}
{% call lib.form_tag() %} {% call lib.form_tag(form) %}
{{ lib.render_form_fields(form, widget_args=form_widget_args) }} {{ lib.render_form_fields(form, widget_args=form_widget_args) }}
<div class="form-buttons"> <div class="form-buttons">
{{ lib.render_form_buttons(return_url, extra()) }} {{ lib.render_form_buttons(return_url, extra()) }}
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
{% endblock %} {% endblock %}
{% block body %} {% block body %}
{% call lib.form_tag() %} {% call lib.form_tag(form) %}
{{ lib.render_form_fields(form, widget_args=form_widget_args) }} {{ lib.render_form_fields(form, widget_args=form_widget_args) }}
<div class="form-buttons"> <div class="form-buttons">
{{ lib.render_form_buttons(return_url) }} {{ lib.render_form_buttons(return_url) }}
......
...@@ -4,6 +4,7 @@ from flask import flash ...@@ -4,6 +4,7 @@ from flask import flash
from flask.ext.admin.babel import gettext, ngettext, lazy_gettext from flask.ext.admin.babel import gettext, ngettext, lazy_gettext
from flask.ext.admin.model import BaseModelView from flask.ext.admin.model import BaseModelView
from flask.ext.admin.model.helpers import get_default_order
import mongoengine import mongoengine
from bson.objectid import ObjectId from bson.objectid import ObjectId
...@@ -300,6 +301,11 @@ class ModelView(BaseModelView): ...@@ -300,6 +301,11 @@ class ModelView(BaseModelView):
# Sorting # Sorting
if sort_column: if sort_column:
query = query.order_by('%s%s' % ('-' if sort_desc else '', sort_column)) query = query.order_by('%s%s' % ('-' if sort_desc else '', sort_column))
else:
order = get_default_order(self)
if order:
query = query.order_by('%s%s' % ('-' if order[1] else '', order[0]))
# Pagination # Pagination
if page is not None: if page is not None:
......
...@@ -5,6 +5,7 @@ from flask import flash ...@@ -5,6 +5,7 @@ from flask import flash
from flask.ext.admin import form from flask.ext.admin import form
from flask.ext.admin.babel import gettext, ngettext, lazy_gettext from flask.ext.admin.babel import gettext, ngettext, lazy_gettext
from flask.ext.admin.model import BaseModelView from flask.ext.admin.model import BaseModelView
from flask.ext.admin.model.helpers import get_default_order
from peewee import PrimaryKeyField, ForeignKeyField, Field, CharField, TextField from peewee import PrimaryKeyField, ForeignKeyField, Field, CharField, TextField
from wtfpeewee.orm import model_form from wtfpeewee.orm import model_form
...@@ -252,6 +253,18 @@ class ModelView(BaseModelView): ...@@ -252,6 +253,18 @@ class ModelView(BaseModelView):
return query return query
def _order_by(self, query, joins, sort_field, sort_desc):
if isinstance(sort_field, basestring):
field = getattr(self.model, sort_field)
query = query.order_by(field.desc() if sort_desc else field.asc())
elif isinstance(sort_field, Field):
if sort_field.model_class != self.model:
query = self._handle_join(query, sort_field, joins)
query = query.order_by(sort_field.desc() if sort_desc else sort_field.asc())
return query, joins
def get_query(self): def get_query(self):
return self.model.select() return self.model.select()
...@@ -299,14 +312,12 @@ class ModelView(BaseModelView): ...@@ -299,14 +312,12 @@ class ModelView(BaseModelView):
if sort_column is not None: if sort_column is not None:
sort_field = self._sortable_columns[sort_column] sort_field = self._sortable_columns[sort_column]
if isinstance(sort_field, basestring): query, joins = self._order_by(query, joins, sort_field, sort_desc)
field = getattr(self.model, sort_field) else:
query = query.order_by(field.desc() if sort_desc else field.asc()) order = get_default_order(self)
elif isinstance(sort_field, Field):
if sort_field.model_class != self.model:
query = self._handle_join(query, sort_field, joins)
query = query.order_by(sort_field.desc() if sort_desc else sort_field.asc()) if order:
query, joins = self._order_by(query, joins, order[0], order[1])
# Pagination # Pagination
if page is not None: if page is not None:
...@@ -315,7 +326,7 @@ class ModelView(BaseModelView): ...@@ -315,7 +326,7 @@ class ModelView(BaseModelView):
query = query.limit(self.page_size) query = query.limit(self.page_size)
if execute: if execute:
query = query.execute() query = list(query.execute())
return count, query return count, query
......
...@@ -8,6 +8,7 @@ from jinja2 import contextfunction ...@@ -8,6 +8,7 @@ from jinja2 import contextfunction
from flask.ext.admin.babel import gettext, ngettext, lazy_gettext from flask.ext.admin.babel import gettext, ngettext, lazy_gettext
from flask.ext.admin.model import BaseModelView from flask.ext.admin.model import BaseModelView
from flask.ext.admin.model.helpers import get_default_order
from flask.ext.admin.actions import action from flask.ext.admin.actions import action
from .filters import BasePyMongoFilter from .filters import BasePyMongoFilter
...@@ -222,6 +223,11 @@ class ModelView(BaseModelView): ...@@ -222,6 +223,11 @@ class ModelView(BaseModelView):
if sort_column: if sort_column:
sort_by = [(sort_column, pymongo.DESCENDING if sort_desc else pymongo.ASCENDING)] sort_by = [(sort_column, pymongo.DESCENDING if sort_desc else pymongo.ASCENDING)]
else:
order = get_default_order(self)
if order:
sort_by = [(order[0], pymongo.DESCENDING if order[1] else pymongo.ASCENDING)]
# Pagination # Pagination
skip = None skip = None
......
...@@ -123,11 +123,6 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -123,11 +123,6 @@ class FilterConverter(filters.BaseFilterConverter):
@filters.convert('Enum', 'ENUM') @filters.convert('Enum', 'ENUM')
def conv_enum(self, column, name, options=None, **kwargs): def conv_enum(self, column, name, options=None, **kwargs):
if not options: if not options:
warnings.warn(
'You can make SQ field with `Enum` type '
'more human readable in the form by using '
'`column_choices` in your `ModelView`'
)
options = [ options = [
(v, v) (v, v)
for v in column.type.enums for v in column.type.enums
......
...@@ -10,6 +10,7 @@ from flask import flash ...@@ -10,6 +10,7 @@ from flask import flash
from flask.ext.admin.tools import ObsoleteAttr from flask.ext.admin.tools import ObsoleteAttr
from flask.ext.admin.babel import gettext, ngettext, lazy_gettext from flask.ext.admin.babel import gettext, ngettext, lazy_gettext
from flask.ext.admin.model import BaseModelView from flask.ext.admin.model import BaseModelView
from flask.ext.admin.model.helpers import get_default_order
from flask.ext.admin.actions import action from flask.ext.admin.actions import action
from flask.ext.admin.contrib.sqlamodel import form, filters, tools from flask.ext.admin.contrib.sqlamodel import form, filters, tools
...@@ -570,7 +571,7 @@ class ModelView(BaseModelView): ...@@ -570,7 +571,7 @@ class ModelView(BaseModelView):
def get_query(self): def get_query(self):
""" """
Return a query for the model type. Return a query for the model type.
If you override this method, don't forget to override `get_count_query` as well. If you override this method, don't forget to override `get_count_query` as well.
""" """
return self.session.query(self.model) return self.session.query(self.model)
...@@ -579,7 +580,57 @@ class ModelView(BaseModelView): ...@@ -579,7 +580,57 @@ class ModelView(BaseModelView):
""" """
Return a the count query for the model type Return a the count query for the model type
""" """
return self.session.query( func.count('*') ).select_from(self.model) return self.session.query(func.count('*')).select_from(self.model)
def _order_by(self, query, joins, sort_field, sort_desc):
"""
Apply order_by to the query
:param query:
Query
:param joins:
Joins set
:param sort_field:
Sort field
:param sort_desc:
Ascending or descending
"""
# TODO: Preprocessing for joins
# Try to handle it as a string
if isinstance(sort_field, basestring):
# Create automatic join against a table if column name
# contains dot.
if '.' in sort_field:
parts = sort_field.split('.', 1)
if parts[0] not in joins:
query = query.join(parts[0])
joins.add(parts[0])
elif isinstance(sort_field, InstrumentedAttribute):
# SQLAlchemy 0.8+ uses 'parent' as a name
mapper = getattr(sort_field, 'parent', None)
if mapper is None:
# SQLAlchemy 0.7.x uses parententity
mapper = getattr(sort_field, 'parententity', None)
if mapper is not None:
table = mapper.tables[0]
if table.name not in joins:
query = query.join(table)
joins.add(table.name)
elif isinstance(sort_field, Column):
pass
else:
raise TypeError('Wrong argument type')
if sort_field is not None:
if sort_desc:
query = query.order_by(desc(sort_field))
else:
query = query.order_by(sort_field)
return query, joins
def get_list(self, page, sort_column, sort_desc, search, filters, execute=True): def get_list(self, page, sort_column, sort_desc, search, filters, execute=True):
""" """
...@@ -659,40 +710,12 @@ class ModelView(BaseModelView): ...@@ -659,40 +710,12 @@ class ModelView(BaseModelView):
if sort_column in self._sortable_columns: if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column] sort_field = self._sortable_columns[sort_column]
# TODO: Preprocessing for joins query, joins = self._order_by(query, joins, sort_field, sort_desc)
# Try to handle it as a string else:
if isinstance(sort_field, basestring): order = get_default_order(self)
# Create automatic join against a table if column name
# contains dot.
if '.' in sort_field:
parts = sort_field.split('.', 1)
if parts[0] not in joins:
query = query.join(parts[0])
joins.add(parts[0])
elif isinstance(sort_field, InstrumentedAttribute):
# SQLAlchemy 0.8+ uses 'parent' as a name
mapper = getattr(sort_field, 'parent', None)
if mapper is None:
# SQLAlchemy 0.7.x uses parententity
mapper = getattr(sort_field, 'parententity', None)
if mapper is not None:
table = mapper.tables[0]
if table.name not in joins:
query = query.join(table)
joins.add(table.name)
elif isinstance(sort_field, Column):
pass
else:
raise TypeError('Wrong argument type')
if sort_field is not None: if order:
if sort_desc: query, joins = self._order_by(query, joins, order[0], order[1])
query = query.order_by(desc(sort_field))
else:
query = query.order_by(sort_field)
# Pagination # Pagination
if page is not None: if page is not None:
......
...@@ -167,6 +167,22 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -167,6 +167,22 @@ class BaseModelView(BaseView, ActionsMixin):
column_sortable_list = ('name', ('user', User.username)) column_sortable_list = ('name', ('user', User.username))
""" """
column_default_sort = None
"""
Default sort column if no sorting is applied.
Example::
class MyModelView(BaseModelView):
column_default_sort = 'user'
You can use tuple to control ascending descending order. In following example, items
will be sorted in descending order::
class MyModelView(BaseModelView):
column_default_sort = ('user', True)
"""
column_searchable_list = ObsoleteAttr('column_searchable_list', column_searchable_list = ObsoleteAttr('column_searchable_list',
'searchable_columns', 'searchable_columns',
None) None)
......
def get_default_order(view):
"""
Get default sort order from model view.
Returns (field, desc) tuple.
:param view:
View instance
"""
if view.column_default_sort:
if isinstance(view.column_default_sort, tuple):
return view.column_default_sort
else:
return (view.column_default_sort, False)
return None
...@@ -56,7 +56,9 @@ ...@@ -56,7 +56,9 @@
<input type="checkbox" name="rowtoggle" class="action-rowtoggle" /> <input type="checkbox" name="rowtoggle" class="action-rowtoggle" />
</th> </th>
{% endif %} {% endif %}
{% block list_row_actions_header %}
<th class="span1">&nbsp;</th> <th class="span1">&nbsp;</th>
{% endblock %}
{% set column = 0 %} {% set column = 0 %}
{% for c, name in list_columns %} {% for c, name in list_columns %}
<th> <th>
...@@ -106,7 +108,7 @@ ...@@ -106,7 +108,7 @@
{%- if admin_view.can_delete -%} {%- if admin_view.can_delete -%}
<form class="icon" method="POST" action="{{ url_for('.delete_view', id=get_pk_value(row), url=return_url) }}"> <form class="icon" method="POST" action="{{ url_for('.delete_view', id=get_pk_value(row), url=return_url) }}">
<button onclick="return confirm('{{ _gettext('You sure you want to delete this item?') }}');"> <button onclick="return confirm('{{ _gettext('You sure you want to delete this item?') }}');">
<i class="icon-remove"></i> <i class="icon-trash"></i>
</button> </button>
</form> </form>
{%- endif -%} {%- endif -%}
......
...@@ -112,3 +112,23 @@ def test_model(): ...@@ -112,3 +112,23 @@ def test_model():
rv = client.post(url) rv = client.post(url)
eq_(rv.status_code, 302) eq_(rv.status_code, 302)
eq_(Model1.objects.count(), 0) eq_(Model1.objects.count(), 0)
def test_default_sort():
app, db, admin = setup()
M1, _ = create_models(db)
M1(test1='c').save()
M1(test1='b').save()
M1(test1='a').save()
eq_(M1.objects.count(), 3)
view = CustomModelView(M1, column_default_sort='test1')
admin.add_view(view)
_, data = view.get_list(0, None, None, None, None)
eq_(data[0].test1, 'a')
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
...@@ -119,3 +119,23 @@ def test_model(): ...@@ -119,3 +119,23 @@ def test_model():
rv = client.post(url) rv = client.post(url)
eq_(rv.status_code, 302) eq_(rv.status_code, 302)
eq_(Model1.select().count(), 0) eq_(Model1.select().count(), 0)
def test_default_sort():
app, db, admin = setup()
M1, _ = create_models(db)
M1('c', 1).save()
M1('b', 2).save()
M1('a', 3).save()
eq_(M1.select().count(), 3)
view = CustomModelView(M1, column_default_sort='test1')
admin.add_view(view)
_, data = view.get_list(0, None, None, None, None)
eq_(data[0].test1, 'a')
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
...@@ -502,3 +502,22 @@ def test_multiple_delete(): ...@@ -502,3 +502,22 @@ def test_multiple_delete():
rv = client.post('/admin/model1view/action/', data=dict(action='delete', rowid=[1,2,3])) rv = client.post('/admin/model1view/action/', data=dict(action='delete', rowid=[1,2,3]))
eq_(rv.status_code, 302) eq_(rv.status_code, 302)
eq_(M1.query.count(), 0) eq_(M1.query.count(), 0)
def test_default_sort():
app, db, admin = setup()
M1, _ = create_models(db)
db.session.add_all([M1('c'), M1('b'), M1('a')])
db.session.commit()
eq_(M1.query.count(), 3)
view = CustomModelView(M1, db.session, column_default_sort='test1')
admin.add_view(view)
_, data = view.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'a')
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment