Commit 459010b0 authored by Paul Brown's avatar Paul Brown

allow using model objects in form_columns, columns in other tables throw exceptions

parent f17a9bbe
......@@ -9,13 +9,14 @@ from flask_admin.model.form import (converts, ModelConverterBase,
from flask_admin.model.fields import AjaxSelectField, AjaxSelectMultipleField
from flask_admin.model.helpers import prettify_name
from flask_admin._backwards import get_property
from flask_admin._compat import iteritems
from flask_admin._compat import iteritems, text_type
from .validators import Unique
from .fields import (QuerySelectField, QuerySelectMultipleField,
InlineModelFormList, InlineHstoreList, HstoreForm)
from flask_admin.model.fields import InlineFormField
from .tools import has_multiple_pks, filter_foreign_columns
from .tools import (has_multiple_pks, filter_foreign_columns,
get_field_with_path)
from .ajax import create_ajax_loader
......@@ -406,28 +407,26 @@ def get_form(model, converter,
properties = ((p.key, p) for p in mapper.iterate_properties)
if only:
props = dict(properties)
def find(name):
# If field is in extra_fields, it has higher priority
if extra_fields and name in extra_fields:
return FieldPlaceholder(extra_fields[name])
return name, FieldPlaceholder(extra_fields[name])
column, path = get_field_with_path(model, name)
# Try to look it up in properties list first
p = props.get(name)
if path and not hasattr(column.prop, 'direction'):
raise Exception("form column is located in another table and "
"requires inline_models: {0}".format(name))
if p is not None:
return p
name = column.key
# If it is hybrid property or alias, look it up in a model itself
p = getattr(model, name, None)
if p is not None and hasattr(p, 'property'):
return p.property
if column is not None and hasattr(column, 'property'):
return name, column.property
raise ValueError('Invalid model property name %s.%s' % (model, name))
# Filter properties while maintaining property order in 'only' list
properties = ((x, find(x)) for x in only)
properties = (find(x) for x in only)
elif exclude:
properties = (x for x in properties if x[0] not in exclude)
......
from sqlalchemy import tuple_, or_, and_
from sqlalchemy.sql.operators import eq
from sqlalchemy.exc import DBAPIError
from ast import literal_eval
from sqlalchemy.orm.attributes import InstrumentedAttribute
from flask_admin._compat import filter_list
from flask_admin._compat import filter_list, string_types
from flask_admin.tools import iterencode, iterdecode, escape
......@@ -108,3 +108,65 @@ def get_query_for_ids(modelquery, model, ids):
query = modelquery.filter(model_pk.in_(ids))
return query
def get_columns_for_field(field):
if (not field or
not hasattr(field, 'property') or
not hasattr(field.property, 'columns') or
not field.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
return field.property.columns
def need_join(model, table):
"""
Check if join to a table is necessary.
"""
return table not in model._sa_class_manager.mapper.tables
def get_field_with_path(model, name):
"""
Resolve property by name and figure out its join path.
Join path might contain both properties and tables.
"""
path = []
# For strings, resolve path
if isinstance(name, string_types):
# create a copy to keep original model as `model`
current_model = model
for attribute in name.split('.'):
value = getattr(current_model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
current_model = value.property.mapper.class_
table = current_model.__table__
if need_join(model, table):
path.append(value)
attr = value
else:
attr = name
# Determine joins if table.column (relation object) is provided
if isinstance(attr, InstrumentedAttribute):
columns = get_columns_for_field(attr)
if len(columns) > 1:
raise Exception('Can only handle one column for %s' % name)
column = columns[0]
# TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
if need_join(model, column.table):
path.append(column.table)
return attr, path
......@@ -21,7 +21,6 @@ from flask_admin._backwards import ObsoleteAttr
from flask_admin.contrib.sqla import form, filters as sqla_filters, tools
from .typefmt import DEFAULT_FORMATTERS
from .tools import get_query_for_ids
from .ajax import create_ajax_loader
# Set up logger
......@@ -338,64 +337,6 @@ class ModelView(BaseModelView):
return model._sa_class_manager.mapper.iterate_properties
def _get_columns_for_field(self, field):
if (not field or
not hasattr(field, 'property') or
not hasattr(field.property, 'columns') or
not field.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
return field.property.columns
def _get_field_with_path(self, name):
"""
Resolve property by name and figure out its join path.
Join path might contain both properties and tables.
"""
path = []
model = self.model
# For strings, resolve path
if isinstance(name, string_types):
for attribute in name.split('.'):
value = getattr(model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table):
path.append(value)
attr = value
else:
attr = name
# Determine joins if table.column (relation object) is provided
if isinstance(attr, InstrumentedAttribute):
columns = self._get_columns_for_field(attr)
if len(columns) > 1:
raise Exception('Can only handle one column for %s' % name)
column = columns[0]
# TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
if self._need_join(column.table):
path.append(column.table)
return attr, path
def _need_join(self, table):
"""
Check if join to a table is necessary.
"""
return table not in self.model._sa_class_manager.mapper.tables
def _apply_path_joins(self, query, joins, path, inner_join=True):
"""
Apply join path to the query.
......@@ -528,13 +469,13 @@ class ModelView(BaseModelView):
for c in self.column_sortable_list:
if isinstance(c, tuple):
column, path = self._get_field_with_path(c[1])
column, path = tools.get_field_with_path(self.model, c[1])
column_name = c[0]
elif isinstance(c, InstrumentedAttribute):
column, path = self._get_field_with_path(c)
column, path = tools.get_field_with_path(self.model, c)
column_name = str(c)
else:
column, path = self._get_field_with_path(c)
column, path = tools.get_field_with_path(self.model, c)
column_name = c
result[column_name] = column
......@@ -556,12 +497,12 @@ class ModelView(BaseModelView):
self._search_fields = []
for p in self.column_searchable_list:
attr, joins = self._get_field_with_path(p)
attr, joins = tools.get_field_with_path(self.model, p)
if not attr:
raise Exception('Failed to find field for search field: %s' % p)
for column in self._get_columns_for_field(attr):
for column in tools.get_columns_for_field(attr):
self._search_fields.append((column, joins))
return bool(self.column_searchable_list)
......@@ -571,7 +512,7 @@ class ModelView(BaseModelView):
Return list of enabled filters
"""
attr, joins = self._get_field_with_path(name)
attr, joins = tools.get_field_with_path(self.model, name)
if attr is None:
raise Exception('Failed to find field for filter: %s' % name)
......@@ -604,21 +545,22 @@ class ModelView(BaseModelView):
if joins:
self._filter_joins[column] = joins
elif self._need_join(table):
elif tools.need_join(self.model, table):
self._filter_joins[column] = [table]
filters.extend(flt)
return filters
else:
columns = self._get_columns_for_field(attr)
columns = tools.get_columns_for_field(attr)
if len(columns) > 1:
raise Exception('Can not filter more than on one column for %s' % name)
column = columns[0]
if self._need_join(column.table) and name not in self.column_labels:
if (tools.need_join(self.model, column.table) and
name not in self.column_labels):
visible_name = '%s / %s' % (
self.get_column_name(column.table.name),
self.get_column_name(column.name)
......@@ -640,7 +582,7 @@ class ModelView(BaseModelView):
if joins:
self._filter_joins[column] = joins
elif self._need_join(column.table):
elif tools.need_join(self.model, column.table):
self._filter_joins[column] = [column.table]
return flt
......@@ -651,7 +593,7 @@ class ModelView(BaseModelView):
# hybrid_property joins are not supported yet
if (isinstance(column, InstrumentedAttribute) and
self._need_join(column.table)):
tools.need_join(self.model, column.table)):
self._filter_joins[column] = [column.table]
return filter
......@@ -802,7 +744,7 @@ class ModelView(BaseModelView):
if order is not None:
field, direction = order
attr, joins = self._get_field_with_path(field)
attr, joins = tools.get_field_with_path(self.model, field)
return attr, joins, direction
......@@ -1100,7 +1042,7 @@ class ModelView(BaseModelView):
lazy_gettext('Are you sure you want to delete selected records?'))
def action_delete(self, ids):
try:
query = get_query_for_ids(self.get_query(), self.model, ids)
query = tools.get_query_for_ids(self.get_query(), self.model, ids)
if self.fast_mass_delete:
count = query.delete(synchronize_session=False)
......
......@@ -5,7 +5,7 @@ from wtforms import fields, validators
from flask_admin import form
from flask_admin._compat import as_unicode
from flask_admin._compat import iteritems
from flask_admin.contrib.sqla import ModelView, filters
from flask_admin.contrib.sqla import ModelView, filters, tools
from flask_babelex import Babel
from sqlalchemy.ext.hybrid import hybrid_property
......@@ -1444,6 +1444,22 @@ def test_form_columns():
ok_(type(form3.model).__name__ == 'QuerySelectField')
# test form_columns with model objects
view4 = CustomModelView(Model, db.session, endpoint='view1',
form_columns=[Model.int_field])
form4 = view4.create_form()
ok_('int_field' in form4._fields)
@raises(Exception)
def test_complex_form_columns():
app, db, admin = setup()
M1, M2 = create_models(db)
# test using a form column in another table
view = CustomModelView(M2, db.session, form_columns=['model1.test1'])
form = view.create_form()
def test_form_args():
app, db, admin = setup()
......@@ -1964,15 +1980,15 @@ def test_advanced_joins():
admin.add_view(view3)
# Test joins
attr, path = view2._get_field_with_path('model1.val1')
attr, path = tools.get_field_with_path(Model2, 'model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model2.model1])
attr, path = view1._get_field_with_path('model2.val2')
attr, path = tools.get_field_with_path(Model1, 'model2.val2')
eq_(attr, Model2.val2)
eq_(id(path[0]), id(Model1.model2))
attr, path = view3._get_field_with_path('model2.model1.val1')
attr, path = tools.get_field_with_path(Model3, 'model2.model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model3.model2, Model2.model1])
......@@ -1986,7 +2002,7 @@ def test_advanced_joins():
ok_(alias is not None)
# Check if another join would use same path
attr, path = view2._get_field_with_path('model1.test')
attr, path = tools.get_field_with_path(Model2, 'model1.test')
q2, joins, alias = view2._apply_path_joins(query, joins, path)
eq_(len(joins), 2)
......@@ -1995,8 +2011,8 @@ def test_advanced_joins():
ok_(alias is not None)
# Check if normal properties are supported by _get_field_with_path
attr, path = view2._get_field_with_path(Model1.test)
# Check if normal properties are supported by tools.get_field_with_path
attr, path = tools.get_field_with_path(Model2, Model1.test)
eq_(attr, Model1.test)
eq_(path, [Model1.__table__])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment