Commit db3ee4b0 authored by Serge S. Koval's avatar Serge S. Koval

Fixed #846, #808. Refactored JOIN logic so it now works with relationship...

Fixed #846, #808. Refactored JOIN logic so it now works with relationship properties instead of tables.

This is partially incompatible change - you will have to update your custom filters. `apply` method now
accepts additional `alias` argument. `alias` should be used to resolve column in query with specific JOIN path.

Long story short, old code which was looking like this:

```python
class MyFilter(SQLABaseFilter):
    def apply(self, query, value):
        return query.filter(self.column == value)
```

Should look like this:

```python
class MyFilter(SQLABaseFilter):
    def apply(self, query, value, alias=None):
        return query.filter(self.get_column(alias) == value)
```
parent 3935d7db
......@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView):
# List of columns that can be sorted. For 'user' column, use User.username as
# a column.
column_sortable_list = ('title', ('user', User.username), 'date')
column_sortable_list = ('title', ('user', 'user.username'), 'date')
# Rename 'title' columns to 'Post Title' in list view
column_labels = dict(title='Post Title')
......
import warnings
import time
import datetime
from flask_admin.babel import lazy_gettext
from flask_admin.model import filters
from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_
class BaseSQLAFilter(filters.BaseFilter):
"""
Base SQLAlchemy filter.
......@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter):
self.column = column
def get_column(self, alias):
return self.column if alias is None else getattr(alias, self.column.key)
def apply(self, query, value, alias=None):
return super(self, BaseSQLAFilter).apply(query, value)
# Common filters
class FilterEqual(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column == value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) == value)
def operation(self):
return lazy_gettext('equals')
class FilterNotEqual(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column != value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) != value)
def operation(self):
return lazy_gettext('not equal')
class FilterLike(BaseSQLAFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value)
return query.filter(self.column.ilike(stmt))
return query.filter(self.get_column(alias).ilike(stmt))
def operation(self):
return lazy_gettext('contains')
class FilterNotLike(BaseSQLAFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value)
return query.filter(~self.column.ilike(stmt))
return query.filter(~self.get_column(alias).ilike(stmt))
def operation(self):
return lazy_gettext('not contains')
class FilterGreater(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column > value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) > value)
def operation(self):
return lazy_gettext('greater than')
class FilterSmaller(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column < value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) < value)
def operation(self):
return lazy_gettext('smaller than')
class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
if value == '1':
return query.filter(self.column == None)
return query.filter(self.get_column(alias) == None)
else:
return query.filter(self.column != None)
return query.filter(self.get_column(alias) != None)
def operation(self):
return lazy_gettext('empty')
......@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter):
def clean(self, value):
return [v.strip() for v in value.split(',') if v.strip()]
def apply(self, query, value):
return query.filter(self.column.in_(value))
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias).in_(value))
def operation(self):
return lazy_gettext('in list')
class FilterNotInList(FilterInList):
def apply(self, query, value):
def apply(self, query, value, alias=None):
# NOT IN can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(~self.column.in_(value), self.column == None))
column = self.get_column(alias)
return query.filter(or_(~column.in_(value), column == None))
def operation(self):
return lazy_gettext('not in list')
......@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
options,
data_type='daterangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class DateNotBetweenFilter(DateBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
# ~between() isn't possible until sqlalchemy 1.0.0
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
options,
data_type='datetimerangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
options,
data_type='timerangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class TimeNotBetweenFilter(TimeBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......
import logging
import warnings
import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload, aliased
from sqlalchemy.sql.expression import desc
from sqlalchemy import Boolean, func, or_
from sqlalchemy import Boolean, Table, func, or_
from sqlalchemy.exc import IntegrityError
from flask import flash
......@@ -276,7 +277,6 @@ class ModelView(BaseModelView):
self.session = session
self._search_fields = None
self._search_joins = []
self._filter_joins = dict()
......@@ -322,43 +322,92 @@ class ModelView(BaseModelView):
return field.property.columns
def _get_field_with_path(self, name):
join_tables = []
"""
Resolve property by name and figure out its join path.
if isinstance(name, string_types):
model = self.model
Join path might contain both properties and tables.
"""
path = []
model = self.model
# For strings, resolve path
if isinstance(name, string_types):
for attribute in name.split('.'):
value = getattr(model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table):
join_tables.append(table)
path.append(value)
attr = value
else:
attr = name
# determine joins if Table.column (relation object) is given
if isinstance(name, InstrumentedAttribute):
columns = self._get_columns_for_field(name)
# Determine joins if table.column (relation object) is provided
if isinstance(attr, InstrumentedAttribute):
columns = self._get_columns_for_field(attr)
if len(columns) > 1:
raise Exception('Can only handle one column for %s' % name)
column = columns[0]
# TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
if self._need_join(column.table):
join_tables.append(column.table)
path.append(column.table)
return join_tables, attr
return attr, path
def _need_join(self, table):
"""
Check if join to a table is necessary.
"""
return table not in self.model._sa_class_manager.mapper.tables
def _apply_path_joins(self, query, joins, path, inner_join=True):
"""
Apply join path to the query.
:param query:
Query to add joins to
:param joins:
List of current joins. Used to avoid joining on same relationship more than once
:param path:
Path to be joined
:param fn:
Join function
"""
last = None
if path:
for item in path:
key = (inner_join, item)
alias = joins.get(key)
if key not in joins:
if not isinstance(item, Table):
alias = aliased(item.property.mapper.class_)
fn = query.join if inner_join else query.outerjoin
if last is None:
query = fn(item) if alias is None else fn(alias, item)
else:
prop = getattr(last, item.key)
query = fn(prop) if alias is None else fn(alias, prop)
joins[key] = alias
last = alias
return query, joins, last
# Scaffolding
def scaffold_pk(self):
"""
......@@ -453,19 +502,19 @@ class ModelView(BaseModelView):
for c in self.column_sortable_list:
if isinstance(c, tuple):
join_tables, column = self._get_field_with_path(c[1])
column, path = self._get_field_with_path(c[1])
column_name = c[0]
elif isinstance(c, InstrumentedAttribute):
join_tables, column = self._get_field_with_path(c)
column, path = self._get_field_with_path(c)
column_name = str(c)
else:
join_tables, column = self._get_field_with_path(c)
column, path = self._get_field_with_path(c)
column_name = c
result[column_name] = column
if join_tables:
self._sortable_joins[column_name] = join_tables
if path:
self._sortable_joins[column_name] = path
return result
......@@ -479,26 +528,15 @@ class ModelView(BaseModelView):
"""
if self.column_searchable_list:
self._search_fields = []
self._search_joins = []
joins = set()
for p in self.column_searchable_list:
join_tables, attr = self._get_field_with_path(p)
attr, joins = self._get_field_with_path(p)
if not attr:
raise Exception('Failed to find field for search field: %s' % p)
for column in self._get_columns_for_field(attr):
column_type = type(column.type).__name__
self._search_fields.append(column)
# Store joins, avoid duplicates
for table in join_tables:
if table.name not in joins:
self._search_joins.append(table)
joins.add(table.name)
self._search_fields.append((column, joins))
return bool(self.column_searchable_list)
......@@ -507,7 +545,7 @@ class ModelView(BaseModelView):
Return list of enabled filters
"""
join_tables, attr = self._get_field_with_path(name)
attr, joins = self._get_field_with_path(name)
if attr is None:
raise Exception('Failed to find field for filter: %s' % name)
......@@ -535,10 +573,11 @@ class ModelView(BaseModelView):
if flt:
table = column.table
if join_tables:
self._filter_joins[table.name] = join_tables
if joins:
self._filter_joins[column] = joins
elif self._need_join(table):
self._filter_joins[table.name] = [table]
self._filter_joins[column] = [table]
filters.extend(flt)
return filters
......@@ -563,9 +602,6 @@ class ModelView(BaseModelView):
type_name = type(column.type).__name__
if join_tables:
self._filter_joins[column.table.name] = join_tables
flt = self.filter_converter.convert(
type_name,
column,
......@@ -573,8 +609,10 @@ class ModelView(BaseModelView):
options=self.column_choices.get(name),
)
if flt and not join_tables and self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table]
if joins:
self._filter_joins[column] = joins
elif self._need_join(column.table):
self._filter_joins[column] = [column.table]
return flt
......@@ -583,7 +621,7 @@ class ModelView(BaseModelView):
column = filter.column
if self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table]
self._filter_joins[column] = [column.table]
return filter
......@@ -707,27 +745,25 @@ class ModelView(BaseModelView):
:param query:
Query
:param joins:
Joins set
:pram joins:
Current joins
:param sort_joins:
Sort joins (properties or tables)
:param sort_field:
Sort field
:param sort_desc:
Ascending or descending
"""
# TODO: Preprocessing for joins
# Handle joins
if sort_joins:
for table in sort_joins:
if table.name not in joins:
query = query.outerjoin(table)
if sort_field is not None:
# Handle joins
query, joins, alias = self._apply_path_joins(query, joins, sort_joins, inner_join=False)
joins.add(table.name)
column = sort_field if alias is None else getattr(alias, sort_field.key)
if sort_field is not None:
if sort_desc:
query = query.order_by(desc(sort_field))
query = query.order_by(desc(column))
else:
query = query.order_by(sort_field)
query = query.order_by(column)
return query, joins
......@@ -737,12 +773,112 @@ class ModelView(BaseModelView):
if order is not None:
field, direction = order
join_tables, attr = self._get_field_with_path(field)
attr, joins = self._get_field_with_path(field)
return join_tables, attr, direction
return attr, joins, direction
return None
def _apply_sorting(self, query, joins, sort_column, sort_desc):
if sort_column is not None:
if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_field, sort_joins, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
return query, joins
def _apply_search(self, query, count_query, joins, count_joins, search):
"""
Apply search to a query.
"""
terms = search.split(' ')
for term in terms:
if not term:
continue
stmt = tools.parse_like_term(term)
filter_stmt = []
count_filter_stmt = []
for field, path in self._search_fields:
query, joins, alias = self._apply_path_joins(query, joins, path, inner_join=False)
count_alias = None
if count_query is not None:
count_query, count_joins, count_alias = self._apply_path_joins(count_query,
count_joins,
path,
inner_join=False)
column = field if alias is None else getattr(alias, field.key)
filter_stmt.append(column.ilike(stmt))
if count_filter_stmt is not None:
column = field if count_alias is None else getattr(count_alias, field.key)
count_filter_stmt.append(column.ilike(stmt))
query = query.filter(or_(*filter_stmt))
if count_query is not None:
count_query = count_query.filter(or_(*count_filter_stmt))
return query, count_query, joins, count_joins
def _apply_filters(self, query, count_query, joins, count_joins, filters):
for idx, flt_name, value in filters:
flt = self._filters[idx]
alias = None
count_alias = None
# Figure out joins
if isinstance(flt, sqla_filters.BaseSQLAFilter):
path = self._filter_joins.get(flt.column, [])
query, joins, alias = self._apply_path_joins(query, joins, path, inner_join=False)
if count_query is not None:
count_query, count_joins, count_alias = self._apply_path_joins(
count_query,
count_joins,
path,
inner_join=False)
# Clean value .clean() and apply the filter
clean_value = flt.clean(value)
try:
query = flt.apply(query, clean_value, alias)
except TypeError:
spec = inspect.getargspec(flt.apply)
if len(spec.args) == 2:
warnings.warn('Please update your custom filter %s to include additional `alias` parameter.' % repr(flt))
else:
raise
query = flt.apply(query, clean_value)
if count_query is not None:
try:
count_query = flt.apply(count_query, clean_value, count_alias)
except TypeError:
count_query = flt.apply(count_query, clean_value)
return query, count_query, joins, count_joins
def get_list(self, page, sort_column, sort_desc, search, filters, execute=True):
"""
Return models from the database.
......@@ -761,8 +897,9 @@ class ModelView(BaseModelView):
List of filter tuples
"""
# Will contain names of joined tables to avoid duplicate joins
joins = set()
# Will contain join paths with optional aliased object
joins = {}
count_joins = {}
query = self.get_query()
count_query = self.get_count_query() if not self.simple_list_pager else None
......@@ -772,60 +909,23 @@ class ModelView(BaseModelView):
if hasattr(query, '_join_entities'):
for entity in query._join_entities:
for table in entity.tables:
joins.add(table.name)
joins[table] = None
# Apply search criteria
if self._search_supported and search:
# Apply search-related joins
if self._search_joins:
for table in self._search_joins:
if table.name not in joins:
query = query.outerjoin(table)
if count_query is not None:
count_query = count_query.outerjoin(table)
joins.add(table.name)
# Apply terms
terms = search.split(' ')
for term in terms:
if not term:
continue
stmt = tools.parse_like_term(term)
filter_stmt = [c.ilike(stmt) for c in self._search_fields]
query = query.filter(or_(*filter_stmt))
if count_query is not None:
count_query = count_query.filter(or_(*filter_stmt))
query, count_query, joins, count_joins = self._apply_search(query,
count_query,
joins,
count_joins,
search)
# Apply filters
if filters and self._filters:
for idx, flt_name, value in filters:
flt = self._filters[idx]
# Figure out joins
if isinstance(flt, sqla_filters.BaseSQLAFilter):
tbl = flt.column.table.name
join_tables = self._filter_joins.get(tbl, [])
for table in join_tables:
if table.name not in joins:
query = query.join(table)
if count_query is not None:
count_query = count_query.join(table)
joins.add(table.name)
# turn into python format with .clean() and apply filter
query = flt.apply(query, flt.clean(value))
if count_query is not None:
count_query = flt.apply(count_query, flt.clean(value))
query, count_query, joins, count_joins = self._apply_filters(query,
count_query,
joins,
count_joins,
filters)
# Calculate number of rows if necessary
count = count_query.scalar() if count_query else None
......@@ -835,19 +935,7 @@ class ModelView(BaseModelView):
query = query.options(joinedload(j))
# Sorting
if sort_column is not None:
if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_joins, sort_field, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
query, joins = self._apply_sorting(query, joins, sort_column, sort_desc)
# Pagination
if page is not None:
......
......@@ -260,10 +260,11 @@ def test_column_searchable_list():
eq_(view._search_supported, True)
eq_(len(view._search_fields), 2)
ok_(isinstance(view._search_fields[0], db.Column))
ok_(isinstance(view._search_fields[1], db.Column))
eq_(view._search_fields[0].name, 'string_field')
eq_(view._search_fields[1].name, 'int_field')
ok_(isinstance(view._search_fields[0][0], db.Column))
ok_(isinstance(view._search_fields[1][0], db.Column))
eq_(view._search_fields[0][0].name, 'string_field')
eq_(view._search_fields[1][0].name, 'int_field')
db.session.add(Model2('model1-test', 5000))
db.session.add(Model2('model2-test', 9000))
......@@ -418,6 +419,8 @@ def test_column_filters():
)
admin.add_view(view)
client = app.test_client()
eq_(len(view._filters), 7)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']],
......@@ -516,21 +519,23 @@ def test_column_filters():
fill_db(db, Model1, Model2)
client = app.test_client()
# Test equals
rv = client.get('/admin/model1/?flt0_0=test1_val_1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
# the filter value is always in "data"
# need to check a different column than test1 for the expected row
ok_('test2_val_1' in data)
ok_('test1_val_2' not in data)
# Test NOT IN filter
rv = client.get('/admin/model1/?flt0_6=test1_val_1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test2_val_1' not in data)
ok_('test1_val_2' in data)
ok_('test2_val_1' not in data)
# Test string filter
view = CustomModelView(Model1, db.session,
......@@ -1104,9 +1109,11 @@ def test_column_filters():
rv = client.get('/admin/_relation_test/?flt1_0=test1_val_1')
data = rv.data.decode('utf-8')
ok_('test1_val_1' in data)
ok_('test1_val_2' not in data)
def test_url_args():
app, db, admin = setup()
......@@ -1699,3 +1706,105 @@ def test_simple_list_pager():
count, data = view.get_list(0, None, None, None, None)
assert_true(count is None)
def test_advanced_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
model1 = db.relationship(Model1, backref='model2')
class Model3(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model2_id = db.Column(db.Integer, db.ForeignKey(Model2.id))
model2 = db.relationship(Model2, backref='model3')
view1 = CustomModelView(Model1, db.session)
admin.add_view(view1)
view2 = CustomModelView(Model2, db.session)
admin.add_view(view2)
view3 = CustomModelView(Model3, db.session)
admin.add_view(view3)
# Test joins
attr, path = view2._get_field_with_path('model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model2.model1])
attr, path = view1._get_field_with_path('model2.val2')
eq_(attr, Model2.val2)
eq_(id(path[0]), id(Model1.model2))
attr, path = view3._get_field_with_path('model2.model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model3.model2, Model2.model1])
# Test how joins are applied
query = view3.get_query()
joins = {}
q1, joins, alias = view3._apply_path_joins(query, joins, path)
ok_((True, Model3.model2) in joins)
ok_((True, Model2.model1) in joins)
ok_(alias is not None)
# Check if another join would use same path
attr, path = view2._get_field_with_path('model1.test')
q2, joins, alias = view2._apply_path_joins(query, joins, path)
eq_(len(joins), 2)
for p in q2._join_entities:
ok_(p in q1._join_entities)
ok_(alias is not None)
# Check if normal properties are supported by _get_field_with_path
attr, path = view2._get_field_with_path(Model1.test)
eq_(attr, Model1.test)
eq_(path, [Model1.__table__])
q3, joins, alias = view2._apply_path_joins(view2.get_query(), joins, path)
eq_(len(joins), 3)
ok_(alias is None)
def test_multipath_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
first_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
first = db.relationship(Model1, backref='first', foreign_keys=[first_id])
second_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
second = db.relationship(Model1, backref='second', foreign_keys=[second_id])
db.create_all()
view = CustomModelView(Model2, db.session, filters=['first.test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model2/')
eq_(rv.status_code, 200)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment