Commit 664f4622 authored by Serge S. Koval's avatar Serge S. Koval

Fixed #435. Support searching through intermediate relations

parent d11b26d4
...@@ -371,21 +371,36 @@ class ModelView(BaseModelView): ...@@ -371,21 +371,36 @@ class ModelView(BaseModelView):
return columns return columns
def _get_columns_for_field(self, field): def _get_columns_for_field(self, field):
if isinstance(field, string_types): if (not field or
attr = getattr(self.model, field, None) not hasattr(field, 'property') or
not hasattr(field.property, 'columns') or
not field.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
if field is None: return field.property.columns
raise Exception('Field %s was not found.' % field)
else:
attr = field
if (not attr or def _get_field_with_path(self, name):
not hasattr(attr, 'property') or join_tables = []
not hasattr(attr.property, 'columns') or
not attr.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
return attr.property.columns if isinstance(name, string_types):
model = self.model
for attribute in name.split('.'):
value = getattr(model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table):
join_tables.append(table)
attr = value
else:
attr = name
return join_tables, attr
def _need_join(self, table): def _need_join(self, table):
return table not in self.model._sa_class_manager.mapper.tables return table not in self.model._sa_class_manager.mapper.tables
...@@ -403,7 +418,12 @@ class ModelView(BaseModelView): ...@@ -403,7 +418,12 @@ class ModelView(BaseModelView):
self._search_joins = dict() self._search_joins = dict()
for p in self.column_searchable_list: for p in self.column_searchable_list:
for column in self._get_columns_for_field(p): join_tables, attr = self._get_field_with_path(p)
if not attr:
raise Exception('Failed to find field for search field: %s' % p)
for column in self._get_columns_for_field(attr):
column_type = type(column.type).__name__ column_type = type(column.type).__name__
if not self.is_text_column_type(column_type): if not self.is_text_column_type(column_type):
...@@ -412,9 +432,10 @@ class ModelView(BaseModelView): ...@@ -412,9 +432,10 @@ class ModelView(BaseModelView):
self._search_fields.append(column) self._search_fields.append(column)
# If it belongs to different table - add a join # Store joins, avoid duplicates
if self._need_join(column.table): if join_tables:
self._search_joins[column.table.name] = column.table for table in join_tables:
self._search_joins[table.name] = table
return bool(self.column_searchable_list) return bool(self.column_searchable_list)
...@@ -435,23 +456,7 @@ class ModelView(BaseModelView): ...@@ -435,23 +456,7 @@ class ModelView(BaseModelView):
Return list of enabled filters Return list of enabled filters
""" """
join_tables = [] join_tables, attr = self._get_field_with_path(name)
if isinstance(name, string_types):
model = self.model
for attribute in name.split('.'):
value = getattr(model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table):
join_tables.append(table)
attr = value
else:
attr = name
if attr is None: if attr is None:
raise Exception('Failed to find field for filter: %s' % name) raise Exception('Failed to find field for filter: %s' % name)
......
...@@ -38,6 +38,9 @@ def create_models(db): ...@@ -38,6 +38,9 @@ def create_models(db):
bool_field = db.Column(db.Boolean) bool_field = db.Column(db.Boolean)
enum_field = db.Column(db.Enum('model1_v1', 'model1_v1'), nullable=True) enum_field = db.Column(db.Enum('model1_v1', 'model1_v1'), nullable=True)
def __unicode__(self):
return self.test1
def __str__(self): def __str__(self):
return self.test1 return self.test1
...@@ -220,6 +223,29 @@ def test_column_searchable_list(): ...@@ -220,6 +223,29 @@ def test_column_searchable_list():
ok_('model2' not in data) ok_('model2' not in data)
def test_complex_searchable_list():
app, db, admin = setup()
Model1, Model2 = create_models(db)
view = CustomModelView(Model2, db.session,
column_searchable_list=['model1.test1'])
admin.add_view(view)
m1 = Model1('model1')
db.session.add(m1)
db.session.add(Model2('model2', model1=m1))
db.session.add(Model2('model3'))
db.session.commit()
client = app.test_client()
rv = client.get('/admin/model2/?search=model1')
data = rv.data.decode('utf-8')
ok_('model1' in data)
ok_('model3' not in data)
def test_column_filters(): def test_column_filters():
app, db, admin = setup() app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment