Commit 459010b0 authored by Paul Brown's avatar Paul Brown

allow using model objects in form_columns, columns in other tables throw exceptions

parent f17a9bbe
...@@ -9,13 +9,14 @@ from flask_admin.model.form import (converts, ModelConverterBase, ...@@ -9,13 +9,14 @@ from flask_admin.model.form import (converts, ModelConverterBase,
from flask_admin.model.fields import AjaxSelectField, AjaxSelectMultipleField from flask_admin.model.fields import AjaxSelectField, AjaxSelectMultipleField
from flask_admin.model.helpers import prettify_name from flask_admin.model.helpers import prettify_name
from flask_admin._backwards import get_property from flask_admin._backwards import get_property
from flask_admin._compat import iteritems from flask_admin._compat import iteritems, text_type
from .validators import Unique from .validators import Unique
from .fields import (QuerySelectField, QuerySelectMultipleField, from .fields import (QuerySelectField, QuerySelectMultipleField,
InlineModelFormList, InlineHstoreList, HstoreForm) InlineModelFormList, InlineHstoreList, HstoreForm)
from flask_admin.model.fields import InlineFormField from flask_admin.model.fields import InlineFormField
from .tools import has_multiple_pks, filter_foreign_columns from .tools import (has_multiple_pks, filter_foreign_columns,
get_field_with_path)
from .ajax import create_ajax_loader from .ajax import create_ajax_loader
...@@ -406,28 +407,26 @@ def get_form(model, converter, ...@@ -406,28 +407,26 @@ def get_form(model, converter,
properties = ((p.key, p) for p in mapper.iterate_properties) properties = ((p.key, p) for p in mapper.iterate_properties)
if only: if only:
props = dict(properties)
def find(name): def find(name):
# If field is in extra_fields, it has higher priority # If field is in extra_fields, it has higher priority
if extra_fields and name in extra_fields: if extra_fields and name in extra_fields:
return FieldPlaceholder(extra_fields[name]) return name, FieldPlaceholder(extra_fields[name])
column, path = get_field_with_path(model, name)
# Try to look it up in properties list first if path and not hasattr(column.prop, 'direction'):
p = props.get(name) raise Exception("form column is located in another table and "
"requires inline_models: {0}".format(name))
if p is not None: name = column.key
return p
# If it is hybrid property or alias, look it up in a model itself if column is not None and hasattr(column, 'property'):
p = getattr(model, name, None) return name, column.property
if p is not None and hasattr(p, 'property'):
return p.property
raise ValueError('Invalid model property name %s.%s' % (model, name)) raise ValueError('Invalid model property name %s.%s' % (model, name))
# Filter properties while maintaining property order in 'only' list # Filter properties while maintaining property order in 'only' list
properties = ((x, find(x)) for x in only) properties = (find(x) for x in only)
elif exclude: elif exclude:
properties = (x for x in properties if x[0] not in exclude) properties = (x for x in properties if x[0] not in exclude)
......
from sqlalchemy import tuple_, or_, and_ from sqlalchemy import tuple_, or_, and_
from sqlalchemy.sql.operators import eq from sqlalchemy.sql.operators import eq
from sqlalchemy.exc import DBAPIError from sqlalchemy.exc import DBAPIError
from ast import literal_eval from sqlalchemy.orm.attributes import InstrumentedAttribute
from flask_admin._compat import filter_list from flask_admin._compat import filter_list, string_types
from flask_admin.tools import iterencode, iterdecode, escape from flask_admin.tools import iterencode, iterdecode, escape
...@@ -108,3 +108,65 @@ def get_query_for_ids(modelquery, model, ids): ...@@ -108,3 +108,65 @@ def get_query_for_ids(modelquery, model, ids):
query = modelquery.filter(model_pk.in_(ids)) query = modelquery.filter(model_pk.in_(ids))
return query return query
def get_columns_for_field(field):
if (not field or
not hasattr(field, 'property') or
not hasattr(field.property, 'columns') or
not field.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
return field.property.columns
def need_join(model, table):
"""
Check if join to a table is necessary.
"""
return table not in model._sa_class_manager.mapper.tables
def get_field_with_path(model, name):
"""
Resolve property by name and figure out its join path.
Join path might contain both properties and tables.
"""
path = []
# For strings, resolve path
if isinstance(name, string_types):
# create a copy to keep original model as `model`
current_model = model
for attribute in name.split('.'):
value = getattr(current_model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
current_model = value.property.mapper.class_
table = current_model.__table__
if need_join(model, table):
path.append(value)
attr = value
else:
attr = name
# Determine joins if table.column (relation object) is provided
if isinstance(attr, InstrumentedAttribute):
columns = get_columns_for_field(attr)
if len(columns) > 1:
raise Exception('Can only handle one column for %s' % name)
column = columns[0]
# TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
if need_join(model, column.table):
path.append(column.table)
return attr, path
...@@ -21,7 +21,6 @@ from flask_admin._backwards import ObsoleteAttr ...@@ -21,7 +21,6 @@ from flask_admin._backwards import ObsoleteAttr
from flask_admin.contrib.sqla import form, filters as sqla_filters, tools from flask_admin.contrib.sqla import form, filters as sqla_filters, tools
from .typefmt import DEFAULT_FORMATTERS from .typefmt import DEFAULT_FORMATTERS
from .tools import get_query_for_ids
from .ajax import create_ajax_loader from .ajax import create_ajax_loader
# Set up logger # Set up logger
...@@ -338,64 +337,6 @@ class ModelView(BaseModelView): ...@@ -338,64 +337,6 @@ class ModelView(BaseModelView):
return model._sa_class_manager.mapper.iterate_properties return model._sa_class_manager.mapper.iterate_properties
def _get_columns_for_field(self, field):
if (not field or
not hasattr(field, 'property') or
not hasattr(field.property, 'columns') or
not field.property.columns):
raise Exception('Invalid field %s: does not contains any columns.' % field)
return field.property.columns
def _get_field_with_path(self, name):
"""
Resolve property by name and figure out its join path.
Join path might contain both properties and tables.
"""
path = []
model = self.model
# For strings, resolve path
if isinstance(name, string_types):
for attribute in name.split('.'):
value = getattr(model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table):
path.append(value)
attr = value
else:
attr = name
# Determine joins if table.column (relation object) is provided
if isinstance(attr, InstrumentedAttribute):
columns = self._get_columns_for_field(attr)
if len(columns) > 1:
raise Exception('Can only handle one column for %s' % name)
column = columns[0]
# TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
if self._need_join(column.table):
path.append(column.table)
return attr, path
def _need_join(self, table):
"""
Check if join to a table is necessary.
"""
return table not in self.model._sa_class_manager.mapper.tables
def _apply_path_joins(self, query, joins, path, inner_join=True): def _apply_path_joins(self, query, joins, path, inner_join=True):
""" """
Apply join path to the query. Apply join path to the query.
...@@ -528,13 +469,13 @@ class ModelView(BaseModelView): ...@@ -528,13 +469,13 @@ class ModelView(BaseModelView):
for c in self.column_sortable_list: for c in self.column_sortable_list:
if isinstance(c, tuple): if isinstance(c, tuple):
column, path = self._get_field_with_path(c[1]) column, path = tools.get_field_with_path(self.model, c[1])
column_name = c[0] column_name = c[0]
elif isinstance(c, InstrumentedAttribute): elif isinstance(c, InstrumentedAttribute):
column, path = self._get_field_with_path(c) column, path = tools.get_field_with_path(self.model, c)
column_name = str(c) column_name = str(c)
else: else:
column, path = self._get_field_with_path(c) column, path = tools.get_field_with_path(self.model, c)
column_name = c column_name = c
result[column_name] = column result[column_name] = column
...@@ -556,12 +497,12 @@ class ModelView(BaseModelView): ...@@ -556,12 +497,12 @@ class ModelView(BaseModelView):
self._search_fields = [] self._search_fields = []
for p in self.column_searchable_list: for p in self.column_searchable_list:
attr, joins = self._get_field_with_path(p) attr, joins = tools.get_field_with_path(self.model, p)
if not attr: if not attr:
raise Exception('Failed to find field for search field: %s' % p) raise Exception('Failed to find field for search field: %s' % p)
for column in self._get_columns_for_field(attr): for column in tools.get_columns_for_field(attr):
self._search_fields.append((column, joins)) self._search_fields.append((column, joins))
return bool(self.column_searchable_list) return bool(self.column_searchable_list)
...@@ -571,7 +512,7 @@ class ModelView(BaseModelView): ...@@ -571,7 +512,7 @@ class ModelView(BaseModelView):
Return list of enabled filters Return list of enabled filters
""" """
attr, joins = self._get_field_with_path(name) attr, joins = tools.get_field_with_path(self.model, name)
if attr is None: if attr is None:
raise Exception('Failed to find field for filter: %s' % name) raise Exception('Failed to find field for filter: %s' % name)
...@@ -604,21 +545,22 @@ class ModelView(BaseModelView): ...@@ -604,21 +545,22 @@ class ModelView(BaseModelView):
if joins: if joins:
self._filter_joins[column] = joins self._filter_joins[column] = joins
elif self._need_join(table): elif tools.need_join(self.model, table):
self._filter_joins[column] = [table] self._filter_joins[column] = [table]
filters.extend(flt) filters.extend(flt)
return filters return filters
else: else:
columns = self._get_columns_for_field(attr) columns = tools.get_columns_for_field(attr)
if len(columns) > 1: if len(columns) > 1:
raise Exception('Can not filter more than on one column for %s' % name) raise Exception('Can not filter more than on one column for %s' % name)
column = columns[0] column = columns[0]
if self._need_join(column.table) and name not in self.column_labels: if (tools.need_join(self.model, column.table) and
name not in self.column_labels):
visible_name = '%s / %s' % ( visible_name = '%s / %s' % (
self.get_column_name(column.table.name), self.get_column_name(column.table.name),
self.get_column_name(column.name) self.get_column_name(column.name)
...@@ -640,7 +582,7 @@ class ModelView(BaseModelView): ...@@ -640,7 +582,7 @@ class ModelView(BaseModelView):
if joins: if joins:
self._filter_joins[column] = joins self._filter_joins[column] = joins
elif self._need_join(column.table): elif tools.need_join(self.model, column.table):
self._filter_joins[column] = [column.table] self._filter_joins[column] = [column.table]
return flt return flt
...@@ -651,7 +593,7 @@ class ModelView(BaseModelView): ...@@ -651,7 +593,7 @@ class ModelView(BaseModelView):
# hybrid_property joins are not supported yet # hybrid_property joins are not supported yet
if (isinstance(column, InstrumentedAttribute) and if (isinstance(column, InstrumentedAttribute) and
self._need_join(column.table)): tools.need_join(self.model, column.table)):
self._filter_joins[column] = [column.table] self._filter_joins[column] = [column.table]
return filter return filter
...@@ -802,7 +744,7 @@ class ModelView(BaseModelView): ...@@ -802,7 +744,7 @@ class ModelView(BaseModelView):
if order is not None: if order is not None:
field, direction = order field, direction = order
attr, joins = self._get_field_with_path(field) attr, joins = tools.get_field_with_path(self.model, field)
return attr, joins, direction return attr, joins, direction
...@@ -1100,7 +1042,7 @@ class ModelView(BaseModelView): ...@@ -1100,7 +1042,7 @@ class ModelView(BaseModelView):
lazy_gettext('Are you sure you want to delete selected records?')) lazy_gettext('Are you sure you want to delete selected records?'))
def action_delete(self, ids): def action_delete(self, ids):
try: try:
query = get_query_for_ids(self.get_query(), self.model, ids) query = tools.get_query_for_ids(self.get_query(), self.model, ids)
if self.fast_mass_delete: if self.fast_mass_delete:
count = query.delete(synchronize_session=False) count = query.delete(synchronize_session=False)
......
...@@ -5,7 +5,7 @@ from wtforms import fields, validators ...@@ -5,7 +5,7 @@ from wtforms import fields, validators
from flask_admin import form from flask_admin import form
from flask_admin._compat import as_unicode from flask_admin._compat import as_unicode
from flask_admin._compat import iteritems from flask_admin._compat import iteritems
from flask_admin.contrib.sqla import ModelView, filters from flask_admin.contrib.sqla import ModelView, filters, tools
from flask_babelex import Babel from flask_babelex import Babel
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
...@@ -1444,6 +1444,22 @@ def test_form_columns(): ...@@ -1444,6 +1444,22 @@ def test_form_columns():
ok_(type(form3.model).__name__ == 'QuerySelectField') ok_(type(form3.model).__name__ == 'QuerySelectField')
# test form_columns with model objects
view4 = CustomModelView(Model, db.session, endpoint='view1',
form_columns=[Model.int_field])
form4 = view4.create_form()
ok_('int_field' in form4._fields)
@raises(Exception)
def test_complex_form_columns():
app, db, admin = setup()
M1, M2 = create_models(db)
# test using a form column in another table
view = CustomModelView(M2, db.session, form_columns=['model1.test1'])
form = view.create_form()
def test_form_args(): def test_form_args():
app, db, admin = setup() app, db, admin = setup()
...@@ -1964,15 +1980,15 @@ def test_advanced_joins(): ...@@ -1964,15 +1980,15 @@ def test_advanced_joins():
admin.add_view(view3) admin.add_view(view3)
# Test joins # Test joins
attr, path = view2._get_field_with_path('model1.val1') attr, path = tools.get_field_with_path(Model2, 'model1.val1')
eq_(attr, Model1.val1) eq_(attr, Model1.val1)
eq_(path, [Model2.model1]) eq_(path, [Model2.model1])
attr, path = view1._get_field_with_path('model2.val2') attr, path = tools.get_field_with_path(Model1, 'model2.val2')
eq_(attr, Model2.val2) eq_(attr, Model2.val2)
eq_(id(path[0]), id(Model1.model2)) eq_(id(path[0]), id(Model1.model2))
attr, path = view3._get_field_with_path('model2.model1.val1') attr, path = tools.get_field_with_path(Model3, 'model2.model1.val1')
eq_(attr, Model1.val1) eq_(attr, Model1.val1)
eq_(path, [Model3.model2, Model2.model1]) eq_(path, [Model3.model2, Model2.model1])
...@@ -1986,7 +2002,7 @@ def test_advanced_joins(): ...@@ -1986,7 +2002,7 @@ def test_advanced_joins():
ok_(alias is not None) ok_(alias is not None)
# Check if another join would use same path # Check if another join would use same path
attr, path = view2._get_field_with_path('model1.test') attr, path = tools.get_field_with_path(Model2, 'model1.test')
q2, joins, alias = view2._apply_path_joins(query, joins, path) q2, joins, alias = view2._apply_path_joins(query, joins, path)
eq_(len(joins), 2) eq_(len(joins), 2)
...@@ -1995,8 +2011,8 @@ def test_advanced_joins(): ...@@ -1995,8 +2011,8 @@ def test_advanced_joins():
ok_(alias is not None) ok_(alias is not None)
# Check if normal properties are supported by _get_field_with_path # Check if normal properties are supported by tools.get_field_with_path
attr, path = view2._get_field_with_path(Model1.test) attr, path = tools.get_field_with_path(Model2, Model1.test)
eq_(attr, Model1.test) eq_(attr, Model1.test)
eq_(path, [Model1.__table__]) eq_(path, [Model1.__table__])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment