Commit cdadb1bd authored by Serge S. Koval's avatar Serge S. Koval

Peewee backend: support for filters in related models

parent 7d0741f4
......@@ -49,7 +49,8 @@ class PostAdmin(peeweemodel.ModelView):
searchable_columns = ('title',)
column_filters = ('title',
'date')
'date',
User.username)
@app.route('/')
......
......@@ -2,19 +2,6 @@ from flask.ext.admin.babel import gettext
from flask.ext.admin.model import filters
from peewee import Q
def parse_like_term(term):
if term.startswith('^'):
stmt = '%s%%' % term[1:]
elif term.startswith('='):
stmt = term[1:]
else:
stmt = '%%%s%%' % term
return stmt
class BasePeeweeFilter(filters.BaseFilter):
"""
......@@ -41,8 +28,7 @@ class BasePeeweeFilter(filters.BaseFilter):
# Common filters
class FilterEqual(BasePeeweeFilter):
def apply(self, query, value):
stmt = '%s__eq' % self.column.name
return query.where(**{stmt: value})
return query.filter(self.column == value)
def operation(self):
return gettext('equals')
......@@ -50,8 +36,7 @@ class FilterEqual(BasePeeweeFilter):
class FilterNotEqual(BasePeeweeFilter):
def apply(self, query, value):
stmt = '%s__neq' % self.column.name
return query.where(**{stmt: value})
return query.filter(self.column != value)
def operation(self):
return gettext('not equal')
......@@ -59,9 +44,7 @@ class FilterNotEqual(BasePeeweeFilter):
class FilterLike(BasePeeweeFilter):
def apply(self, query, value):
stmt = '%s__icontains' % self.column.name
val = parse_like_term(value)
return query.where(**{stmt: val})
return query.filter(self.column ** value)
def operation(self):
return gettext('contains')
......@@ -69,10 +52,7 @@ class FilterLike(BasePeeweeFilter):
class FilterNotLike(BasePeeweeFilter):
def apply(self, query, value):
stmt = '%s__icontains' % self.column.name
val = parse_like_term(value)
node = ~Q(**{stmt: val})
return query.where(node)
return query.filter(~(self.column ** value))
def operation(self):
return gettext('not contains')
......@@ -80,8 +60,7 @@ class FilterNotLike(BasePeeweeFilter):
class FilterGreater(BasePeeweeFilter):
def apply(self, query, value):
stmt = '%s__gt' % self.column.name
return query.where(**{stmt: value})
return query.filter(self.column > value)
def operation(self):
return gettext('greater than')
......@@ -89,8 +68,7 @@ class FilterGreater(BasePeeweeFilter):
class FilterSmaller(BasePeeweeFilter):
def apply(self, query, value):
stmt = '%s__lt' % self.column.name
return query.where(**{stmt: value})
return query.filter(self.column < value)
def operation(self):
return gettext('smaller than')
......@@ -111,7 +89,7 @@ class FilterConverter(filters.BaseFilterConverter):
numeric = (FilterEqual, FilterNotEqual, FilterGreater, FilterSmaller)
def convert(self, type_name, column, name):
print type_name, column, name
#print type_name, column, name
if type_name in self.converters:
return self.converters[type_name](column, name)
......
......@@ -42,6 +42,9 @@ class ModelView(BaseModelView):
self._primary_key = self.scaffold_pk()
self._search_fields = []
self._search_joins = dict()
def _get_model_fields(self, model=None):
if model is None:
model = self.model
......@@ -116,29 +119,6 @@ class ModelView(BaseModelView):
return bool(self._search_fields)
def _find_field(self, model, field, visited, path=None):
def make_path(n):
if path:
return '%s__%s' % (path, n)
else:
return n
for n, p in self._get_model_fields(model):
if p.model == model and p.name == field.name:
return make_path(n)
if type(p) == ForeignKeyField:
if p.to not in visited:
visited.add(p.to)
result = self._find_field(p.to, field, visited,
make_path(n))
if result is not None:
return result
return None
def scaffold_filters(self, name):
if isinstance(name, basestring):
attr = getattr(self.model, name, None)
......@@ -148,20 +128,21 @@ class ModelView(BaseModelView):
if attr is None:
raise Exception('Failed to find field for filter: %s' % name)
if not isinstance(name, basestring):
visible_name = self.get_column_name(attr.name)
# Check if field is in different model
if attr.model != self.model:
visible_name = '%s / %s' % (self.get_column_name(attr.model.__name__),
self.get_column_name(attr.name))
else:
visible_name = self.get_column_name(name)
if not isinstance(name, basestring):
visible_name = self.get_column_name(attr.name)
else:
visible_name = self.get_column_name(name)
type_name = type(attr).__name__
flt = self.filter_converter.convert(type_name,
attr,
visible_name)
if flt:
# TODO: Related table search
pass
return flt
def is_valid_filter(self, filter):
......@@ -179,6 +160,8 @@ class ModelView(BaseModelView):
execute=True):
query = self.model.select()
joins = set()
# Search
if self._search_supported and search:
terms = search.split(' ')
......@@ -204,6 +187,15 @@ class ModelView(BaseModelView):
# Filters
if self._filters:
for flt, value in filters:
f = self._filters[flt]
if f.column.model != self.model:
model_name = f.column.model.__name__
if model_name not in joins:
query = query.join(f.column.model)
joins.add(model_name)
query = self._filters[flt].apply(query, value)
# Get count
......
......@@ -19,6 +19,8 @@ var AdminFilters = function(element, filters_element, adminForm, operations, opt
function removeFilter() {
$(this).parent().remove();
$('button', $root).show();
return false;
}
function addFilter(name, op) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment