Commit db3ee4b0 authored by Serge S. Koval's avatar Serge S. Koval

Fixed #846, #808. Refactored JOIN logic so it now works with relationship...

Fixed #846, #808. Refactored JOIN logic so it now works with relationship properties instead of tables.

This is partially incompatible change - you will have to update your custom filters. `apply` method now
accepts additional `alias` argument. `alias` should be used to resolve column in query with specific JOIN path.

Long story short, old code which was looking like this:

```python
class MyFilter(SQLABaseFilter):
    def apply(self, query, value):
        return query.filter(self.column == value)
```

Should look like this:

```python
class MyFilter(SQLABaseFilter):
    def apply(self, query, value, alias=None):
        return query.filter(self.get_column(alias) == value)
```
parent 3935d7db
...@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView): ...@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView):
# List of columns that can be sorted. For 'user' column, use User.username as # List of columns that can be sorted. For 'user' column, use User.username as
# a column. # a column.
column_sortable_list = ('title', ('user', User.username), 'date') column_sortable_list = ('title', ('user', 'user.username'), 'date')
# Rename 'title' columns to 'Post Title' in list view # Rename 'title' columns to 'Post Title' in list view
column_labels = dict(title='Post Title') column_labels = dict(title='Post Title')
......
import warnings
import time
import datetime
from flask_admin.babel import lazy_gettext from flask_admin.babel import lazy_gettext
from flask_admin.model import filters from flask_admin.model import filters
from flask_admin.contrib.sqla import tools from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_ from sqlalchemy.sql import not_, or_
class BaseSQLAFilter(filters.BaseFilter): class BaseSQLAFilter(filters.BaseFilter):
""" """
Base SQLAlchemy filter. Base SQLAlchemy filter.
...@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter): ...@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter):
self.column = column self.column = column
def get_column(self, alias):
return self.column if alias is None else getattr(alias, self.column.key)
def apply(self, query, value, alias=None):
return super(self, BaseSQLAFilter).apply(query, value)
# Common filters # Common filters
class FilterEqual(BaseSQLAFilter): class FilterEqual(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column == value) return query.filter(self.get_column(alias) == value)
def operation(self): def operation(self):
return lazy_gettext('equals') return lazy_gettext('equals')
class FilterNotEqual(BaseSQLAFilter): class FilterNotEqual(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column != value) return query.filter(self.get_column(alias) != value)
def operation(self): def operation(self):
return lazy_gettext('not equal') return lazy_gettext('not equal')
class FilterLike(BaseSQLAFilter): class FilterLike(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value) stmt = tools.parse_like_term(value)
return query.filter(self.column.ilike(stmt)) return query.filter(self.get_column(alias).ilike(stmt))
def operation(self): def operation(self):
return lazy_gettext('contains') return lazy_gettext('contains')
class FilterNotLike(BaseSQLAFilter): class FilterNotLike(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value) stmt = tools.parse_like_term(value)
return query.filter(~self.column.ilike(stmt)) return query.filter(~self.get_column(alias).ilike(stmt))
def operation(self): def operation(self):
return lazy_gettext('not contains') return lazy_gettext('not contains')
class FilterGreater(BaseSQLAFilter): class FilterGreater(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column > value) return query.filter(self.get_column(alias) > value)
def operation(self): def operation(self):
return lazy_gettext('greater than') return lazy_gettext('greater than')
class FilterSmaller(BaseSQLAFilter): class FilterSmaller(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column < value) return query.filter(self.get_column(alias) < value)
def operation(self): def operation(self):
return lazy_gettext('smaller than') return lazy_gettext('smaller than')
class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter): class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
if value == '1': if value == '1':
return query.filter(self.column == None) return query.filter(self.get_column(alias) == None)
else: else:
return query.filter(self.column != None) return query.filter(self.get_column(alias) != None)
def operation(self): def operation(self):
return lazy_gettext('empty') return lazy_gettext('empty')
...@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter): ...@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter):
def clean(self, value): def clean(self, value):
return [v.strip() for v in value.split(',') if v.strip()] return [v.strip() for v in value.split(',') if v.strip()]
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column.in_(value)) return query.filter(self.get_column(alias).in_(value))
def operation(self): def operation(self):
return lazy_gettext('in list') return lazy_gettext('in list')
class FilterNotInList(FilterInList): class FilterNotInList(FilterInList):
def apply(self, query, value): def apply(self, query, value, alias=None):
# NOT IN can exclude NULL values, so "or_ == None" needed to be added # NOT IN can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(~self.column.in_(value), self.column == None)) column = self.get_column(alias)
return query.filter(or_(~column.in_(value), column == None))
def operation(self): def operation(self):
return lazy_gettext('not in list') return lazy_gettext('not in list')
...@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter): ...@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
options, options,
data_type='daterangepicker') data_type='daterangepicker')
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(self.column.between(start, end)) return query.filter(self.get_column(alias).between(start, end))
class DateNotBetweenFilter(DateBetweenFilter): class DateNotBetweenFilter(DateBetweenFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
# ~between() isn't possible until sqlalchemy 1.0.0 # ~between() isn't possible until sqlalchemy 1.0.0
return query.filter(not_(self.column.between(start, end))) return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self): def operation(self):
return lazy_gettext('not between') return lazy_gettext('not between')
...@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter): ...@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
options, options,
data_type='datetimerangepicker') data_type='datetimerangepicker')
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(self.column.between(start, end)) return query.filter(self.get_column(alias).between(start, end))
class DateTimeNotBetweenFilter(DateTimeBetweenFilter): class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(not_(self.column.between(start, end))) return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self): def operation(self):
return lazy_gettext('not between') return lazy_gettext('not between')
...@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter): ...@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
options, options,
data_type='timerangepicker') data_type='timerangepicker')
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(self.column.between(start, end)) return query.filter(self.get_column(alias).between(start, end))
class TimeNotBetweenFilter(TimeBetweenFilter): class TimeNotBetweenFilter(TimeBetweenFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(not_(self.column.between(start, end))) return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self): def operation(self):
return lazy_gettext('not between') return lazy_gettext('not between')
......
import logging import logging
import warnings import warnings
import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload, aliased
from sqlalchemy.sql.expression import desc from sqlalchemy.sql.expression import desc
from sqlalchemy import Boolean, func, or_ from sqlalchemy import Boolean, Table, func, or_
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from flask import flash from flask import flash
...@@ -276,7 +277,6 @@ class ModelView(BaseModelView): ...@@ -276,7 +277,6 @@ class ModelView(BaseModelView):
self.session = session self.session = session
self._search_fields = None self._search_fields = None
self._search_joins = []
self._filter_joins = dict() self._filter_joins = dict()
...@@ -322,43 +322,92 @@ class ModelView(BaseModelView): ...@@ -322,43 +322,92 @@ class ModelView(BaseModelView):
return field.property.columns return field.property.columns
def _get_field_with_path(self, name): def _get_field_with_path(self, name):
join_tables = [] """
Resolve property by name and figure out its join path.
if isinstance(name, string_types): Join path might contain both properties and tables.
model = self.model """
path = []
model = self.model
# For strings, resolve path
if isinstance(name, string_types):
for attribute in name.split('.'): for attribute in name.split('.'):
value = getattr(model, attribute) value = getattr(model, attribute)
if (hasattr(value, 'property') and if (hasattr(value, 'property') and
hasattr(value.property, 'direction')): hasattr(value.property, 'direction')):
model = value.property.mapper.class_ model = value.property.mapper.class_
table = model.__table__ table = model.__table__
if self._need_join(table): if self._need_join(table):
join_tables.append(table) path.append(value)
attr = value attr = value
else: else:
attr = name attr = name
# determine joins if Table.column (relation object) is given # Determine joins if table.column (relation object) is provided
if isinstance(name, InstrumentedAttribute): if isinstance(attr, InstrumentedAttribute):
columns = self._get_columns_for_field(name) columns = self._get_columns_for_field(attr)
if len(columns) > 1: if len(columns) > 1:
raise Exception('Can only handle one column for %s' % name) raise Exception('Can only handle one column for %s' % name)
column = columns[0] column = columns[0]
# TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
if self._need_join(column.table): if self._need_join(column.table):
join_tables.append(column.table) path.append(column.table)
return join_tables, attr return attr, path
def _need_join(self, table): def _need_join(self, table):
"""
Check if join to a table is necessary.
"""
return table not in self.model._sa_class_manager.mapper.tables return table not in self.model._sa_class_manager.mapper.tables
def _apply_path_joins(self, query, joins, path, inner_join=True):
"""
Apply join path to the query.
:param query:
Query to add joins to
:param joins:
List of current joins. Used to avoid joining on same relationship more than once
:param path:
Path to be joined
:param fn:
Join function
"""
last = None
if path:
for item in path:
key = (inner_join, item)
alias = joins.get(key)
if key not in joins:
if not isinstance(item, Table):
alias = aliased(item.property.mapper.class_)
fn = query.join if inner_join else query.outerjoin
if last is None:
query = fn(item) if alias is None else fn(alias, item)
else:
prop = getattr(last, item.key)
query = fn(prop) if alias is None else fn(alias, prop)
joins[key] = alias
last = alias
return query, joins, last
# Scaffolding # Scaffolding
def scaffold_pk(self): def scaffold_pk(self):
""" """
...@@ -453,19 +502,19 @@ class ModelView(BaseModelView): ...@@ -453,19 +502,19 @@ class ModelView(BaseModelView):
for c in self.column_sortable_list: for c in self.column_sortable_list:
if isinstance(c, tuple): if isinstance(c, tuple):
join_tables, column = self._get_field_with_path(c[1]) column, path = self._get_field_with_path(c[1])
column_name = c[0] column_name = c[0]
elif isinstance(c, InstrumentedAttribute): elif isinstance(c, InstrumentedAttribute):
join_tables, column = self._get_field_with_path(c) column, path = self._get_field_with_path(c)
column_name = str(c) column_name = str(c)
else: else:
join_tables, column = self._get_field_with_path(c) column, path = self._get_field_with_path(c)
column_name = c column_name = c
result[column_name] = column result[column_name] = column
if join_tables: if path:
self._sortable_joins[column_name] = join_tables self._sortable_joins[column_name] = path
return result return result
...@@ -479,26 +528,15 @@ class ModelView(BaseModelView): ...@@ -479,26 +528,15 @@ class ModelView(BaseModelView):
""" """
if self.column_searchable_list: if self.column_searchable_list:
self._search_fields = [] self._search_fields = []
self._search_joins = []
joins = set()
for p in self.column_searchable_list: for p in self.column_searchable_list:
join_tables, attr = self._get_field_with_path(p) attr, joins = self._get_field_with_path(p)
if not attr: if not attr:
raise Exception('Failed to find field for search field: %s' % p) raise Exception('Failed to find field for search field: %s' % p)
for column in self._get_columns_for_field(attr): for column in self._get_columns_for_field(attr):
column_type = type(column.type).__name__ self._search_fields.append((column, joins))
self._search_fields.append(column)
# Store joins, avoid duplicates
for table in join_tables:
if table.name not in joins:
self._search_joins.append(table)
joins.add(table.name)
return bool(self.column_searchable_list) return bool(self.column_searchable_list)
...@@ -507,7 +545,7 @@ class ModelView(BaseModelView): ...@@ -507,7 +545,7 @@ class ModelView(BaseModelView):
Return list of enabled filters Return list of enabled filters
""" """
join_tables, attr = self._get_field_with_path(name) attr, joins = self._get_field_with_path(name)
if attr is None: if attr is None:
raise Exception('Failed to find field for filter: %s' % name) raise Exception('Failed to find field for filter: %s' % name)
...@@ -535,10 +573,11 @@ class ModelView(BaseModelView): ...@@ -535,10 +573,11 @@ class ModelView(BaseModelView):
if flt: if flt:
table = column.table table = column.table
if join_tables: if joins:
self._filter_joins[table.name] = join_tables self._filter_joins[column] = joins
elif self._need_join(table): elif self._need_join(table):
self._filter_joins[table.name] = [table] self._filter_joins[column] = [table]
filters.extend(flt) filters.extend(flt)
return filters return filters
...@@ -563,9 +602,6 @@ class ModelView(BaseModelView): ...@@ -563,9 +602,6 @@ class ModelView(BaseModelView):
type_name = type(column.type).__name__ type_name = type(column.type).__name__
if join_tables:
self._filter_joins[column.table.name] = join_tables
flt = self.filter_converter.convert( flt = self.filter_converter.convert(
type_name, type_name,
column, column,
...@@ -573,8 +609,10 @@ class ModelView(BaseModelView): ...@@ -573,8 +609,10 @@ class ModelView(BaseModelView):
options=self.column_choices.get(name), options=self.column_choices.get(name),
) )
if flt and not join_tables and self._need_join(column.table): if joins:
self._filter_joins[column.table.name] = [column.table] self._filter_joins[column] = joins
elif self._need_join(column.table):
self._filter_joins[column] = [column.table]
return flt return flt
...@@ -583,7 +621,7 @@ class ModelView(BaseModelView): ...@@ -583,7 +621,7 @@ class ModelView(BaseModelView):
column = filter.column column = filter.column
if self._need_join(column.table): if self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table] self._filter_joins[column] = [column.table]
return filter return filter
...@@ -707,27 +745,25 @@ class ModelView(BaseModelView): ...@@ -707,27 +745,25 @@ class ModelView(BaseModelView):
:param query: :param query:
Query Query
:param joins: :pram joins:
Joins set Current joins
:param sort_joins:
Sort joins (properties or tables)
:param sort_field: :param sort_field:
Sort field Sort field
:param sort_desc: :param sort_desc:
Ascending or descending Ascending or descending
""" """
# TODO: Preprocessing for joins if sort_field is not None:
# Handle joins # Handle joins
if sort_joins: query, joins, alias = self._apply_path_joins(query, joins, sort_joins, inner_join=False)
for table in sort_joins:
if table.name not in joins:
query = query.outerjoin(table)
joins.add(table.name) column = sort_field if alias is None else getattr(alias, sort_field.key)
if sort_field is not None:
if sort_desc: if sort_desc:
query = query.order_by(desc(sort_field)) query = query.order_by(desc(column))
else: else:
query = query.order_by(sort_field) query = query.order_by(column)
return query, joins return query, joins
...@@ -737,12 +773,112 @@ class ModelView(BaseModelView): ...@@ -737,12 +773,112 @@ class ModelView(BaseModelView):
if order is not None: if order is not None:
field, direction = order field, direction = order
join_tables, attr = self._get_field_with_path(field) attr, joins = self._get_field_with_path(field)
return join_tables, attr, direction return attr, joins, direction
return None return None
def _apply_sorting(self, query, joins, sort_column, sort_desc):
if sort_column is not None:
if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_field, sort_joins, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
return query, joins
def _apply_search(self, query, count_query, joins, count_joins, search):
"""
Apply search to a query.
"""
terms = search.split(' ')
for term in terms:
if not term:
continue
stmt = tools.parse_like_term(term)
filter_stmt = []
count_filter_stmt = []
for field, path in self._search_fields:
query, joins, alias = self._apply_path_joins(query, joins, path, inner_join=False)
count_alias = None
if count_query is not None:
count_query, count_joins, count_alias = self._apply_path_joins(count_query,
count_joins,
path,
inner_join=False)
column = field if alias is None else getattr(alias, field.key)
filter_stmt.append(column.ilike(stmt))
if count_filter_stmt is not None:
column = field if count_alias is None else getattr(count_alias, field.key)
count_filter_stmt.append(column.ilike(stmt))
query = query.filter(or_(*filter_stmt))
if count_query is not None:
count_query = count_query.filter(or_(*count_filter_stmt))
return query, count_query, joins, count_joins
def _apply_filters(self, query, count_query, joins, count_joins, filters):
for idx, flt_name, value in filters:
flt = self._filters[idx]
alias = None
count_alias = None
# Figure out joins
if isinstance(flt, sqla_filters.BaseSQLAFilter):
path = self._filter_joins.get(flt.column, [])
query, joins, alias = self._apply_path_joins(query, joins, path, inner_join=False)
if count_query is not None:
count_query, count_joins, count_alias = self._apply_path_joins(
count_query,
count_joins,
path,
inner_join=False)
# Clean value .clean() and apply the filter
clean_value = flt.clean(value)
try:
query = flt.apply(query, clean_value, alias)
except TypeError:
spec = inspect.getargspec(flt.apply)
if len(spec.args) == 2:
warnings.warn('Please update your custom filter %s to include additional `alias` parameter.' % repr(flt))
else:
raise
query = flt.apply(query, clean_value)
if count_query is not None:
try:
count_query = flt.apply(count_query, clean_value, count_alias)
except TypeError:
count_query = flt.apply(count_query, clean_value)
return query, count_query, joins, count_joins
def get_list(self, page, sort_column, sort_desc, search, filters, execute=True): def get_list(self, page, sort_column, sort_desc, search, filters, execute=True):
""" """
Return models from the database. Return models from the database.
...@@ -761,8 +897,9 @@ class ModelView(BaseModelView): ...@@ -761,8 +897,9 @@ class ModelView(BaseModelView):
List of filter tuples List of filter tuples
""" """
# Will contain names of joined tables to avoid duplicate joins # Will contain join paths with optional aliased object
joins = set() joins = {}
count_joins = {}
query = self.get_query() query = self.get_query()
count_query = self.get_count_query() if not self.simple_list_pager else None count_query = self.get_count_query() if not self.simple_list_pager else None
...@@ -772,60 +909,23 @@ class ModelView(BaseModelView): ...@@ -772,60 +909,23 @@ class ModelView(BaseModelView):
if hasattr(query, '_join_entities'): if hasattr(query, '_join_entities'):
for entity in query._join_entities: for entity in query._join_entities:
for table in entity.tables: for table in entity.tables:
joins.add(table.name) joins[table] = None
# Apply search criteria # Apply search criteria
if self._search_supported and search: if self._search_supported and search:
# Apply search-related joins query, count_query, joins, count_joins = self._apply_search(query,
if self._search_joins: count_query,
for table in self._search_joins: joins,
if table.name not in joins: count_joins,
query = query.outerjoin(table) search)
if count_query is not None:
count_query = count_query.outerjoin(table)
joins.add(table.name)
# Apply terms
terms = search.split(' ')
for term in terms:
if not term:
continue
stmt = tools.parse_like_term(term)
filter_stmt = [c.ilike(stmt) for c in self._search_fields]
query = query.filter(or_(*filter_stmt))
if count_query is not None:
count_query = count_query.filter(or_(*filter_stmt))
# Apply filters # Apply filters
if filters and self._filters: if filters and self._filters:
for idx, flt_name, value in filters: query, count_query, joins, count_joins = self._apply_filters(query,
flt = self._filters[idx] count_query,
joins,
# Figure out joins count_joins,
if isinstance(flt, sqla_filters.BaseSQLAFilter): filters)
tbl = flt.column.table.name
join_tables = self._filter_joins.get(tbl, [])
for table in join_tables:
if table.name not in joins:
query = query.join(table)
if count_query is not None:
count_query = count_query.join(table)
joins.add(table.name)
# turn into python format with .clean() and apply filter
query = flt.apply(query, flt.clean(value))
if count_query is not None:
count_query = flt.apply(count_query, flt.clean(value))
# Calculate number of rows if necessary # Calculate number of rows if necessary
count = count_query.scalar() if count_query else None count = count_query.scalar() if count_query else None
...@@ -835,19 +935,7 @@ class ModelView(BaseModelView): ...@@ -835,19 +935,7 @@ class ModelView(BaseModelView):
query = query.options(joinedload(j)) query = query.options(joinedload(j))
# Sorting # Sorting
if sort_column is not None: query, joins = self._apply_sorting(query, joins, sort_column, sort_desc)
if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_joins, sort_field, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
# Pagination # Pagination
if page is not None: if page is not None:
......
...@@ -260,10 +260,11 @@ def test_column_searchable_list(): ...@@ -260,10 +260,11 @@ def test_column_searchable_list():
eq_(view._search_supported, True) eq_(view._search_supported, True)
eq_(len(view._search_fields), 2) eq_(len(view._search_fields), 2)
ok_(isinstance(view._search_fields[0], db.Column))
ok_(isinstance(view._search_fields[1], db.Column)) ok_(isinstance(view._search_fields[0][0], db.Column))
eq_(view._search_fields[0].name, 'string_field') ok_(isinstance(view._search_fields[1][0], db.Column))
eq_(view._search_fields[1].name, 'int_field') eq_(view._search_fields[0][0].name, 'string_field')
eq_(view._search_fields[1][0].name, 'int_field')
db.session.add(Model2('model1-test', 5000)) db.session.add(Model2('model1-test', 5000))
db.session.add(Model2('model2-test', 9000)) db.session.add(Model2('model2-test', 9000))
...@@ -418,6 +419,8 @@ def test_column_filters(): ...@@ -418,6 +419,8 @@ def test_column_filters():
) )
admin.add_view(view) admin.add_view(view)
client = app.test_client()
eq_(len(view._filters), 7) eq_(len(view._filters), 7)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']], eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']],
...@@ -516,21 +519,23 @@ def test_column_filters(): ...@@ -516,21 +519,23 @@ def test_column_filters():
fill_db(db, Model1, Model2) fill_db(db, Model1, Model2)
client = app.test_client() # Test equals
rv = client.get('/admin/model1/?flt0_0=test1_val_1') rv = client.get('/admin/model1/?flt0_0=test1_val_1')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
# the filter value is always in "data" # the filter value is always in "data"
# need to check a different column than test1 for the expected row # need to check a different column than test1 for the expected row
ok_('test2_val_1' in data) ok_('test2_val_1' in data)
ok_('test1_val_2' not in data) ok_('test1_val_2' not in data)
# Test NOT IN filter
rv = client.get('/admin/model1/?flt0_6=test1_val_1') rv = client.get('/admin/model1/?flt0_6=test1_val_1')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('test2_val_1' not in data)
ok_('test1_val_2' in data) ok_('test1_val_2' in data)
ok_('test2_val_1' not in data)
# Test string filter # Test string filter
view = CustomModelView(Model1, db.session, view = CustomModelView(Model1, db.session,
...@@ -1104,9 +1109,11 @@ def test_column_filters(): ...@@ -1104,9 +1109,11 @@ def test_column_filters():
rv = client.get('/admin/_relation_test/?flt1_0=test1_val_1') rv = client.get('/admin/_relation_test/?flt1_0=test1_val_1')
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('test1_val_1' in data) ok_('test1_val_1' in data)
ok_('test1_val_2' not in data) ok_('test1_val_2' not in data)
def test_url_args(): def test_url_args():
app, db, admin = setup() app, db, admin = setup()
...@@ -1699,3 +1706,105 @@ def test_simple_list_pager(): ...@@ -1699,3 +1706,105 @@ def test_simple_list_pager():
count, data = view.get_list(0, None, None, None, None) count, data = view.get_list(0, None, None, None, None)
assert_true(count is None) assert_true(count is None)
def test_advanced_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
model1 = db.relationship(Model1, backref='model2')
class Model3(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model2_id = db.Column(db.Integer, db.ForeignKey(Model2.id))
model2 = db.relationship(Model2, backref='model3')
view1 = CustomModelView(Model1, db.session)
admin.add_view(view1)
view2 = CustomModelView(Model2, db.session)
admin.add_view(view2)
view3 = CustomModelView(Model3, db.session)
admin.add_view(view3)
# Test joins
attr, path = view2._get_field_with_path('model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model2.model1])
attr, path = view1._get_field_with_path('model2.val2')
eq_(attr, Model2.val2)
eq_(id(path[0]), id(Model1.model2))
attr, path = view3._get_field_with_path('model2.model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model3.model2, Model2.model1])
# Test how joins are applied
query = view3.get_query()
joins = {}
q1, joins, alias = view3._apply_path_joins(query, joins, path)
ok_((True, Model3.model2) in joins)
ok_((True, Model2.model1) in joins)
ok_(alias is not None)
# Check if another join would use same path
attr, path = view2._get_field_with_path('model1.test')
q2, joins, alias = view2._apply_path_joins(query, joins, path)
eq_(len(joins), 2)
for p in q2._join_entities:
ok_(p in q1._join_entities)
ok_(alias is not None)
# Check if normal properties are supported by _get_field_with_path
attr, path = view2._get_field_with_path(Model1.test)
eq_(attr, Model1.test)
eq_(path, [Model1.__table__])
q3, joins, alias = view2._apply_path_joins(view2.get_query(), joins, path)
eq_(len(joins), 3)
ok_(alias is None)
def test_multipath_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
first_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
first = db.relationship(Model1, backref='first', foreign_keys=[first_id])
second_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
second = db.relationship(Model1, backref='second', foreign_keys=[second_id])
db.create_all()
view = CustomModelView(Model2, db.session, filters=['first.test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model2/')
eq_(rv.status_code, 200)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment