Commit db3ee4b0 authored by Serge S. Koval's avatar Serge S. Koval

Fixed #846, #808. Refactored JOIN logic so it now works with relationship...

Fixed #846, #808. Refactored JOIN logic so it now works with relationship properties instead of tables.

This is partially incompatible change - you will have to update your custom filters. `apply` method now
accepts additional `alias` argument. `alias` should be used to resolve column in query with specific JOIN path.

Long story short, old code which was looking like this:

```python
class MyFilter(SQLABaseFilter):
    def apply(self, query, value):
        return query.filter(self.column == value)
```

Should look like this:

```python
class MyFilter(SQLABaseFilter):
    def apply(self, query, value, alias=None):
        return query.filter(self.get_column(alias) == value)
```
parent 3935d7db
......@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView):
# List of columns that can be sorted. For 'user' column, use User.username as
# a column.
column_sortable_list = ('title', ('user', User.username), 'date')
column_sortable_list = ('title', ('user', 'user.username'), 'date')
# Rename 'title' columns to 'Post Title' in list view
column_labels = dict(title='Post Title')
......
import warnings
import time
import datetime
from flask_admin.babel import lazy_gettext
from flask_admin.model import filters
from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_
class BaseSQLAFilter(filters.BaseFilter):
"""
Base SQLAlchemy filter.
......@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter):
self.column = column
def get_column(self, alias):
return self.column if alias is None else getattr(alias, self.column.key)
def apply(self, query, value, alias=None):
return super(self, BaseSQLAFilter).apply(query, value)
# Common filters
class FilterEqual(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column == value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) == value)
def operation(self):
return lazy_gettext('equals')
class FilterNotEqual(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column != value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) != value)
def operation(self):
return lazy_gettext('not equal')
class FilterLike(BaseSQLAFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value)
return query.filter(self.column.ilike(stmt))
return query.filter(self.get_column(alias).ilike(stmt))
def operation(self):
return lazy_gettext('contains')
class FilterNotLike(BaseSQLAFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value)
return query.filter(~self.column.ilike(stmt))
return query.filter(~self.get_column(alias).ilike(stmt))
def operation(self):
return lazy_gettext('not contains')
class FilterGreater(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column > value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) > value)
def operation(self):
return lazy_gettext('greater than')
class FilterSmaller(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column < value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) < value)
def operation(self):
return lazy_gettext('smaller than')
class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
if value == '1':
return query.filter(self.column == None)
return query.filter(self.get_column(alias) == None)
else:
return query.filter(self.column != None)
return query.filter(self.get_column(alias) != None)
def operation(self):
return lazy_gettext('empty')
......@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter):
def clean(self, value):
return [v.strip() for v in value.split(',') if v.strip()]
def apply(self, query, value):
return query.filter(self.column.in_(value))
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias).in_(value))
def operation(self):
return lazy_gettext('in list')
class FilterNotInList(FilterInList):
def apply(self, query, value):
def apply(self, query, value, alias=None):
# NOT IN can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(~self.column.in_(value), self.column == None))
column = self.get_column(alias)
return query.filter(or_(~column.in_(value), column == None))
def operation(self):
return lazy_gettext('not in list')
......@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
options,
data_type='daterangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class DateNotBetweenFilter(DateBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
# ~between() isn't possible until sqlalchemy 1.0.0
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
options,
data_type='datetimerangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
options,
data_type='timerangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class TimeNotBetweenFilter(TimeBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......
This diff is collapsed.
......@@ -260,10 +260,11 @@ def test_column_searchable_list():
eq_(view._search_supported, True)
eq_(len(view._search_fields), 2)
ok_(isinstance(view._search_fields[0], db.Column))
ok_(isinstance(view._search_fields[1], db.Column))
eq_(view._search_fields[0].name, 'string_field')
eq_(view._search_fields[1].name, 'int_field')
ok_(isinstance(view._search_fields[0][0], db.Column))
ok_(isinstance(view._search_fields[1][0], db.Column))
eq_(view._search_fields[0][0].name, 'string_field')
eq_(view._search_fields[1][0].name, 'int_field')
db.session.add(Model2('model1-test', 5000))
db.session.add(Model2('model2-test', 9000))
......@@ -418,6 +419,8 @@ def test_column_filters():
)
admin.add_view(view)
client = app.test_client()
eq_(len(view._filters), 7)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']],
......@@ -516,21 +519,23 @@ def test_column_filters():
fill_db(db, Model1, Model2)
client = app.test_client()
# Test equals
rv = client.get('/admin/model1/?flt0_0=test1_val_1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
# the filter value is always in "data"
# need to check a different column than test1 for the expected row
ok_('test2_val_1' in data)
ok_('test1_val_2' not in data)
# Test NOT IN filter
rv = client.get('/admin/model1/?flt0_6=test1_val_1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test2_val_1' not in data)
ok_('test1_val_2' in data)
ok_('test2_val_1' not in data)
# Test string filter
view = CustomModelView(Model1, db.session,
......@@ -1104,9 +1109,11 @@ def test_column_filters():
rv = client.get('/admin/_relation_test/?flt1_0=test1_val_1')
data = rv.data.decode('utf-8')
ok_('test1_val_1' in data)
ok_('test1_val_2' not in data)
def test_url_args():
app, db, admin = setup()
......@@ -1699,3 +1706,105 @@ def test_simple_list_pager():
count, data = view.get_list(0, None, None, None, None)
assert_true(count is None)
def test_advanced_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
model1 = db.relationship(Model1, backref='model2')
class Model3(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model2_id = db.Column(db.Integer, db.ForeignKey(Model2.id))
model2 = db.relationship(Model2, backref='model3')
view1 = CustomModelView(Model1, db.session)
admin.add_view(view1)
view2 = CustomModelView(Model2, db.session)
admin.add_view(view2)
view3 = CustomModelView(Model3, db.session)
admin.add_view(view3)
# Test joins
attr, path = view2._get_field_with_path('model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model2.model1])
attr, path = view1._get_field_with_path('model2.val2')
eq_(attr, Model2.val2)
eq_(id(path[0]), id(Model1.model2))
attr, path = view3._get_field_with_path('model2.model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model3.model2, Model2.model1])
# Test how joins are applied
query = view3.get_query()
joins = {}
q1, joins, alias = view3._apply_path_joins(query, joins, path)
ok_((True, Model3.model2) in joins)
ok_((True, Model2.model1) in joins)
ok_(alias is not None)
# Check if another join would use same path
attr, path = view2._get_field_with_path('model1.test')
q2, joins, alias = view2._apply_path_joins(query, joins, path)
eq_(len(joins), 2)
for p in q2._join_entities:
ok_(p in q1._join_entities)
ok_(alias is not None)
# Check if normal properties are supported by _get_field_with_path
attr, path = view2._get_field_with_path(Model1.test)
eq_(attr, Model1.test)
eq_(path, [Model1.__table__])
q3, joins, alias = view2._apply_path_joins(view2.get_query(), joins, path)
eq_(len(joins), 3)
ok_(alias is None)
def test_multipath_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
first_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
first = db.relationship(Model1, backref='first', foreign_keys=[first_id])
second_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
second = db.relationship(Model1, backref='second', foreign_keys=[second_id])
db.create_all()
view = CustomModelView(Model2, db.session, filters=['first.test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model2/')
eq_(rv.status_code, 200)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment