Commit cdadb1bd authored by Serge S. Koval's avatar Serge S. Koval

Peewee backend: support for filters in related models

parent 7d0741f4
...@@ -49,7 +49,8 @@ class PostAdmin(peeweemodel.ModelView): ...@@ -49,7 +49,8 @@ class PostAdmin(peeweemodel.ModelView):
searchable_columns = ('title',) searchable_columns = ('title',)
column_filters = ('title', column_filters = ('title',
'date') 'date',
User.username)
@app.route('/') @app.route('/')
......
...@@ -2,19 +2,6 @@ from flask.ext.admin.babel import gettext ...@@ -2,19 +2,6 @@ from flask.ext.admin.babel import gettext
from flask.ext.admin.model import filters from flask.ext.admin.model import filters
from peewee import Q
def parse_like_term(term):
if term.startswith('^'):
stmt = '%s%%' % term[1:]
elif term.startswith('='):
stmt = term[1:]
else:
stmt = '%%%s%%' % term
return stmt
class BasePeeweeFilter(filters.BaseFilter): class BasePeeweeFilter(filters.BaseFilter):
""" """
...@@ -41,8 +28,7 @@ class BasePeeweeFilter(filters.BaseFilter): ...@@ -41,8 +28,7 @@ class BasePeeweeFilter(filters.BaseFilter):
# Common filters # Common filters
class FilterEqual(BasePeeweeFilter): class FilterEqual(BasePeeweeFilter):
def apply(self, query, value): def apply(self, query, value):
stmt = '%s__eq' % self.column.name return query.filter(self.column == value)
return query.where(**{stmt: value})
def operation(self): def operation(self):
return gettext('equals') return gettext('equals')
...@@ -50,8 +36,7 @@ class FilterEqual(BasePeeweeFilter): ...@@ -50,8 +36,7 @@ class FilterEqual(BasePeeweeFilter):
class FilterNotEqual(BasePeeweeFilter): class FilterNotEqual(BasePeeweeFilter):
def apply(self, query, value): def apply(self, query, value):
stmt = '%s__neq' % self.column.name return query.filter(self.column != value)
return query.where(**{stmt: value})
def operation(self): def operation(self):
return gettext('not equal') return gettext('not equal')
...@@ -59,9 +44,7 @@ class FilterNotEqual(BasePeeweeFilter): ...@@ -59,9 +44,7 @@ class FilterNotEqual(BasePeeweeFilter):
class FilterLike(BasePeeweeFilter): class FilterLike(BasePeeweeFilter):
def apply(self, query, value): def apply(self, query, value):
stmt = '%s__icontains' % self.column.name return query.filter(self.column ** value)
val = parse_like_term(value)
return query.where(**{stmt: val})
def operation(self): def operation(self):
return gettext('contains') return gettext('contains')
...@@ -69,10 +52,7 @@ class FilterLike(BasePeeweeFilter): ...@@ -69,10 +52,7 @@ class FilterLike(BasePeeweeFilter):
class FilterNotLike(BasePeeweeFilter): class FilterNotLike(BasePeeweeFilter):
def apply(self, query, value): def apply(self, query, value):
stmt = '%s__icontains' % self.column.name return query.filter(~(self.column ** value))
val = parse_like_term(value)
node = ~Q(**{stmt: val})
return query.where(node)
def operation(self): def operation(self):
return gettext('not contains') return gettext('not contains')
...@@ -80,8 +60,7 @@ class FilterNotLike(BasePeeweeFilter): ...@@ -80,8 +60,7 @@ class FilterNotLike(BasePeeweeFilter):
class FilterGreater(BasePeeweeFilter): class FilterGreater(BasePeeweeFilter):
def apply(self, query, value): def apply(self, query, value):
stmt = '%s__gt' % self.column.name return query.filter(self.column > value)
return query.where(**{stmt: value})
def operation(self): def operation(self):
return gettext('greater than') return gettext('greater than')
...@@ -89,8 +68,7 @@ class FilterGreater(BasePeeweeFilter): ...@@ -89,8 +68,7 @@ class FilterGreater(BasePeeweeFilter):
class FilterSmaller(BasePeeweeFilter): class FilterSmaller(BasePeeweeFilter):
def apply(self, query, value): def apply(self, query, value):
stmt = '%s__lt' % self.column.name return query.filter(self.column < value)
return query.where(**{stmt: value})
def operation(self): def operation(self):
return gettext('smaller than') return gettext('smaller than')
...@@ -111,7 +89,7 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -111,7 +89,7 @@ class FilterConverter(filters.BaseFilterConverter):
numeric = (FilterEqual, FilterNotEqual, FilterGreater, FilterSmaller) numeric = (FilterEqual, FilterNotEqual, FilterGreater, FilterSmaller)
def convert(self, type_name, column, name): def convert(self, type_name, column, name):
print type_name, column, name #print type_name, column, name
if type_name in self.converters: if type_name in self.converters:
return self.converters[type_name](column, name) return self.converters[type_name](column, name)
......
...@@ -42,6 +42,9 @@ class ModelView(BaseModelView): ...@@ -42,6 +42,9 @@ class ModelView(BaseModelView):
self._primary_key = self.scaffold_pk() self._primary_key = self.scaffold_pk()
self._search_fields = []
self._search_joins = dict()
def _get_model_fields(self, model=None): def _get_model_fields(self, model=None):
if model is None: if model is None:
model = self.model model = self.model
...@@ -116,29 +119,6 @@ class ModelView(BaseModelView): ...@@ -116,29 +119,6 @@ class ModelView(BaseModelView):
return bool(self._search_fields) return bool(self._search_fields)
def _find_field(self, model, field, visited, path=None):
def make_path(n):
if path:
return '%s__%s' % (path, n)
else:
return n
for n, p in self._get_model_fields(model):
if p.model == model and p.name == field.name:
return make_path(n)
if type(p) == ForeignKeyField:
if p.to not in visited:
visited.add(p.to)
result = self._find_field(p.to, field, visited,
make_path(n))
if result is not None:
return result
return None
def scaffold_filters(self, name): def scaffold_filters(self, name):
if isinstance(name, basestring): if isinstance(name, basestring):
attr = getattr(self.model, name, None) attr = getattr(self.model, name, None)
...@@ -148,6 +128,11 @@ class ModelView(BaseModelView): ...@@ -148,6 +128,11 @@ class ModelView(BaseModelView):
if attr is None: if attr is None:
raise Exception('Failed to find field for filter: %s' % name) raise Exception('Failed to find field for filter: %s' % name)
# Check if field is in different model
if attr.model != self.model:
visible_name = '%s / %s' % (self.get_column_name(attr.model.__name__),
self.get_column_name(attr.name))
else:
if not isinstance(name, basestring): if not isinstance(name, basestring):
visible_name = self.get_column_name(attr.name) visible_name = self.get_column_name(attr.name)
else: else:
...@@ -158,10 +143,6 @@ class ModelView(BaseModelView): ...@@ -158,10 +143,6 @@ class ModelView(BaseModelView):
attr, attr,
visible_name) visible_name)
if flt:
# TODO: Related table search
pass
return flt return flt
def is_valid_filter(self, filter): def is_valid_filter(self, filter):
...@@ -179,6 +160,8 @@ class ModelView(BaseModelView): ...@@ -179,6 +160,8 @@ class ModelView(BaseModelView):
execute=True): execute=True):
query = self.model.select() query = self.model.select()
joins = set()
# Search # Search
if self._search_supported and search: if self._search_supported and search:
terms = search.split(' ') terms = search.split(' ')
...@@ -204,6 +187,15 @@ class ModelView(BaseModelView): ...@@ -204,6 +187,15 @@ class ModelView(BaseModelView):
# Filters # Filters
if self._filters: if self._filters:
for flt, value in filters: for flt, value in filters:
f = self._filters[flt]
if f.column.model != self.model:
model_name = f.column.model.__name__
if model_name not in joins:
query = query.join(f.column.model)
joins.add(model_name)
query = self._filters[flt].apply(query, value) query = self._filters[flt].apply(query, value)
# Get count # Get count
......
...@@ -19,6 +19,8 @@ var AdminFilters = function(element, filters_element, adminForm, operations, opt ...@@ -19,6 +19,8 @@ var AdminFilters = function(element, filters_element, adminForm, operations, opt
function removeFilter() { function removeFilter() {
$(this).parent().remove(); $(this).parent().remove();
$('button', $root).show(); $('button', $root).show();
return false;
} }
function addFilter(name, op) { function addFilter(name, op) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment