Commit df9a3dd1 authored by Tom Kedem's avatar Tom Kedem

Added support for sqla association proxy in model view.

parent 391bdc8e
Example of how to use (and filter on) an association proxy with the SQLAlchemy backend.
For information about association proxies and how to use them, please visit:
http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html
To run this example:
1. Clone the repository::
git clone https://github.com/flask-admin/flask-admin.git
cd flask-admin
2. Create and activate a virtual environment::
virtualenv env
source env/bin/activate
3. Install requirements::
pip install -r 'examples/sqla-association_proxy/requirements.txt'
4. Run the application::
python examples/sqla-association_proxy/app.py
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship, backref
import flask_admin as admin
from flask_admin.contrib import sqla
# Create application
app = Flask(__name__)
# Create dummy secrey key so we can use sessions
app.config['SECRET_KEY'] = '123456790'
# Create in-memory database
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite://'
app.config['SQLALCHEMY_ECHO'] = True
db = SQLAlchemy(app)
# Flask views
@app.route('/')
def index():
return '<a href="/admin/">Click me to get to Admin!</a>'
class User(db.Model):
__tablename__ = 'user'
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(64))
# Association proxy of "user_keywords" collection to "keyword" attribute - a list of keywords objects.
keywords = association_proxy('user_keywords', 'keyword')
# Association proxy to association proxy - a list of keywords strings.
keywords_values = association_proxy('user_keywords', 'keyword_value')
def __init__(self, name=None):
self.name = name
class UserKeyword(db.Model):
__tablename__ = 'user_keyword'
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), primary_key=True)
keyword_id = db.Column(db.Integer, db.ForeignKey('keyword.id'), primary_key=True)
special_key = db.Column(db.String(50))
# bidirectional attribute/collection of "user"/"user_keywords"
user = relationship(User, backref=backref("user_keywords", cascade="all, delete-orphan"))
# reference to the "Keyword" object
keyword = relationship("Keyword")
# Reference to the "keyword" column inside the "Keyword" object.
keyword_value = association_proxy('keyword', 'keyword')
def __init__(self, keyword=None, user=None, special_key=None):
self.user = user
self.keyword = keyword
self.special_key = special_key
class Keyword(db.Model):
__tablename__ = 'keyword'
id = db.Column(db.Integer, primary_key=True)
keyword = db.Column('keyword', db.String(64))
def __init__(self, keyword=None):
self.keyword = keyword
def __repr__(self):
return 'Keyword(%s)' % repr(self.keyword)
class UserAdmin(sqla.ModelView):
# Support for association proxies to association proxies (e.g.: keywords_values) is currently limited
# to column_list only.
column_list = ('id', 'name', 'keywords', 'keywords_values')
column_sortable_list = ('id', 'name')
column_filters = ('id', 'name', 'keywords')
form_columns = ('name', 'keywords')
class KeywordAdmin(sqla.ModelView):
column_list = ('id', 'keyword')
# Create admin
admin = admin.Admin(app, name='Example: SQLAlchemy Association Proxy', template_mode='bootstrap3')
admin.add_view(UserAdmin(User, db.session))
admin.add_view(KeywordAdmin(Keyword, db.session))
if __name__ == '__main__':
# Create DB
db.create_all()
# Add sample data
user = User('log')
for kw in (Keyword('new_from_blammo'), Keyword('its_big')):
user.keywords.append(kw)
db.session.add(user)
db.session.commit()
# Start app
app.run(debug=True)
Flask
Flask-Admin
Flask-SQLAlchemy
...@@ -3,7 +3,7 @@ from sqlalchemy import or_ ...@@ -3,7 +3,7 @@ from sqlalchemy import or_
from flask_admin._compat import as_unicode, string_types from flask_admin._compat import as_unicode, string_types
from flask_admin.model.ajax import AjaxModelLoader, DEFAULT_PAGE_SIZE from flask_admin.model.ajax import AjaxModelLoader, DEFAULT_PAGE_SIZE
from .tools import get_primary_key, has_multiple_pks from .tools import get_primary_key, has_multiple_pks, is_relationship, is_association_proxy
class QueryAjaxModelLoader(AjaxModelLoader): class QueryAjaxModelLoader(AjaxModelLoader):
...@@ -75,8 +75,11 @@ def create_ajax_loader(model, session, name, field_name, options): ...@@ -75,8 +75,11 @@ def create_ajax_loader(model, session, name, field_name, options):
if attr is None: if attr is None:
raise ValueError('Model %s does not have field %s.' % (model, field_name)) raise ValueError('Model %s does not have field %s.' % (model, field_name))
if not hasattr(attr, 'property') or not hasattr(attr.property, 'direction'): if not is_relationship(attr) and not is_association_proxy(attr):
raise ValueError('%s.%s is not a relation.' % (model, field_name)) raise ValueError('%s.%s is not a relation.' % (model, field_name))
if is_association_proxy(attr):
attr = attr.remote_attr
remote_model = attr.prop.mapper.class_ remote_model = attr.prop.mapper.class_
return QueryAjaxModelLoader(name, session, remote_model, **options) return QueryAjaxModelLoader(name, session, remote_model, **options)
...@@ -16,7 +16,7 @@ from .fields import (QuerySelectField, QuerySelectMultipleField, ...@@ -16,7 +16,7 @@ from .fields import (QuerySelectField, QuerySelectMultipleField,
InlineModelFormList, InlineHstoreList, HstoreForm) InlineModelFormList, InlineHstoreList, HstoreForm)
from flask_admin.model.fields import InlineFormField from flask_admin.model.fields import InlineFormField
from .tools import (has_multiple_pks, filter_foreign_columns, from .tools import (has_multiple_pks, filter_foreign_columns,
get_field_with_path) get_field_with_path, is_association_proxy, is_relationship)
from .ajax import create_ajax_loader from .ajax import create_ajax_loader
...@@ -86,10 +86,10 @@ class AdminModelConverter(ModelConverterBase): ...@@ -86,10 +86,10 @@ class AdminModelConverter(ModelConverterBase):
else: else:
return QuerySelectField(**kwargs) return QuerySelectField(**kwargs)
def _convert_relation(self, prop, kwargs): def _convert_relation(self, name, prop, property_is_association_proxy, kwargs):
# Check if relation is specified # Check if relation is specified
form_columns = getattr(self.view, 'form_columns', None) form_columns = getattr(self.view, 'form_columns', None)
if form_columns and prop.key not in form_columns: if form_columns and name not in form_columns:
return None return None
remote_model = prop.mapper.class_ remote_model = prop.mapper.class_
...@@ -100,13 +100,13 @@ class AdminModelConverter(ModelConverterBase): ...@@ -100,13 +100,13 @@ class AdminModelConverter(ModelConverterBase):
if not column.foreign_keys: if not column.foreign_keys:
column = prop.local_remote_pairs[0][1] column = prop.local_remote_pairs[0][1]
kwargs['label'] = self._get_label(prop.key, kwargs) kwargs['label'] = self._get_label(name, kwargs)
kwargs['description'] = self._get_description(prop.key, kwargs) kwargs['description'] = self._get_description(name, kwargs)
# determine optional/required, or respect existing # determine optional/required, or respect existing
requirement_options = (validators.Optional, validators.InputRequired) requirement_options = (validators.Optional, validators.InputRequired)
if not any(isinstance(v, requirement_options) for v in kwargs['validators']): if not any(isinstance(v, requirement_options) for v in kwargs['validators']):
if column.nullable or prop.direction.name != 'MANYTOONE': if property_is_association_proxy or column.nullable or prop.direction.name != 'MANYTOONE':
kwargs['validators'].append(validators.Optional()) kwargs['validators'].append(validators.Optional())
else: else:
kwargs['validators'].append(validators.InputRequired()) kwargs['validators'].append(validators.InputRequired())
...@@ -120,14 +120,11 @@ class AdminModelConverter(ModelConverterBase): ...@@ -120,14 +120,11 @@ class AdminModelConverter(ModelConverterBase):
if override: if override:
return override(**kwargs) return override(**kwargs)
if prop.direction.name == 'MANYTOONE' or not prop.uselist: multiple = (property_is_association_proxy or
return self._model_select_field(prop, False, remote_model, **kwargs) (prop.direction.name in ('ONETOMANY', 'MANYTOMANY') and prop.uselist))
elif prop.direction.name == 'ONETOMANY': return self._model_select_field(prop, multiple, remote_model, **kwargs)
return self._model_select_field(prop, True, remote_model, **kwargs)
elif prop.direction.name == 'MANYTOMANY':
return self._model_select_field(prop, True, remote_model, **kwargs)
def convert(self, model, mapper, prop, field_args, hidden_pk): def convert(self, model, mapper, name, prop, field_args, hidden_pk):
# Properly handle forced fields # Properly handle forced fields
if isinstance(prop, FieldPlaceholder): if isinstance(prop, FieldPlaceholder):
return form.recreate_field(prop.field) return form.recreate_field(prop.field)
...@@ -145,8 +142,13 @@ class AdminModelConverter(ModelConverterBase): ...@@ -145,8 +142,13 @@ class AdminModelConverter(ModelConverterBase):
kwargs['validators'] = list(kwargs['validators']) kwargs['validators'] = list(kwargs['validators'])
# Check if it is relation or property # Check if it is relation or property
if hasattr(prop, 'direction'): if hasattr(prop, 'direction') or is_association_proxy(prop):
return self._convert_relation(prop, kwargs) property_is_association_proxy = is_association_proxy(prop)
if property_is_association_proxy:
if not hasattr(prop.remote_attr, 'prop'):
raise Exception('Association proxy referencing another association proxy is not supported.')
prop = prop.remote_attr.prop
return self._convert_relation(name, prop, property_is_association_proxy, kwargs)
elif hasattr(prop, 'columns'): # Ignore pk/fk elif hasattr(prop, 'columns'): # Ignore pk/fk
# Check if more than one column mapped to the property # Check if more than one column mapped to the property
if len(prop.columns) > 1: if len(prop.columns) > 1:
...@@ -414,16 +416,19 @@ def get_form(model, converter, ...@@ -414,16 +416,19 @@ def get_form(model, converter,
if extra_fields and name in extra_fields: if extra_fields and name in extra_fields:
return name, FieldPlaceholder(extra_fields[name]) return name, FieldPlaceholder(extra_fields[name])
column, path = get_field_with_path(model, name) column, path = get_field_with_path(model, name, return_remote_proxy_attr=False)
if path and not hasattr(column.prop, 'direction'): if path and not (is_relationship(column) or is_association_proxy(column)):
raise Exception("form column is located in another table and " raise Exception("form column is located in another table and "
"requires inline_models: {0}".format(name)) "requires inline_models: {0}".format(name))
name = column.key if is_association_proxy(column):
return name, column
relation_name = column.key
if column is not None and hasattr(column, 'property'): if column is not None and hasattr(column, 'property'):
return name, column.property return relation_name, column.property
raise ValueError('Invalid model property name %s.%s' % (model, name)) raise ValueError('Invalid model property name %s.%s' % (model, name))
...@@ -440,7 +445,7 @@ def get_form(model, converter, ...@@ -440,7 +445,7 @@ def get_form(model, converter,
prop = _resolve_prop(p) prop = _resolve_prop(p)
field = converter.convert(model, mapper, prop, field_args.get(name), hidden_pk) field = converter.convert(model, mapper, name, prop, field_args.get(name), hidden_pk)
if field is not None: if field is not None:
field_dict[name] = field field_dict[name] = field
......
from sqlalchemy import tuple_, or_, and_, inspect from sqlalchemy import tuple_, or_, and_, inspect
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.ext.associationproxy import ASSOCIATION_PROXY
from sqlalchemy.sql.operators import eq from sqlalchemy.sql.operators import eq
from sqlalchemy.exc import DBAPIError from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
...@@ -128,7 +129,7 @@ def need_join(model, table): ...@@ -128,7 +129,7 @@ def need_join(model, table):
return table not in model._sa_class_manager.mapper.tables return table not in model._sa_class_manager.mapper.tables
def get_field_with_path(model, name): def get_field_with_path(model, name, return_remote_proxy_attr=True):
""" """
Resolve property by name and figure out its join path. Resolve property by name and figure out its join path.
...@@ -141,24 +142,30 @@ def get_field_with_path(model, name): ...@@ -141,24 +142,30 @@ def get_field_with_path(model, name):
# create a copy to keep original model as `model` # create a copy to keep original model as `model`
current_model = model current_model = model
value = None
for attribute in name.split('.'): for attribute in name.split('.'):
value = getattr(current_model, attribute) value = getattr(current_model, attribute)
if (hasattr(value, 'property') and if is_association_proxy(value):
hasattr(value.property, 'direction')): relation_values = value.attr
current_model = value.property.mapper.class_ if return_remote_proxy_attr:
value = value.remote_attr
table = current_model.__table__ else:
relation_values = [value]
if need_join(model, table):
path.append(value) for relation_value in relation_values:
if is_relationship(relation_value):
attr = value current_model = relation_value.property.mapper.class_
table = current_model.__table__
if need_join(model, table):
path.append(relation_value)
attr = value
else: else:
attr = name attr = name
# Determine joins if table.column (relation object) is provided # Determine joins if table.column (relation object) is provided
if isinstance(attr, InstrumentedAttribute): if isinstance(attr, InstrumentedAttribute) or is_association_proxy(attr):
columns = get_columns_for_field(attr) columns = get_columns_for_field(attr)
if len(columns) > 1: if len(columns) > 1:
...@@ -184,3 +191,17 @@ def get_hybrid_properties(model): ...@@ -184,3 +191,17 @@ def get_hybrid_properties(model):
def is_hybrid_property(model, attr_name): def is_hybrid_property(model, attr_name):
return attr_name in get_hybrid_properties(model) return attr_name in get_hybrid_properties(model)
def is_relationship(attr):
return hasattr(attr, 'property') and hasattr(attr.property, 'direction')
def is_association_proxy(attr):
return hasattr(attr, 'extension_type') and attr.extension_type == ASSOCIATION_PROXY
def get_association_proxy_column_name(attr):
# TODO find a better way to get the name
name, = [key for key, value in inspect(attr.owning_class).all_orm_descriptors.items() if value is attr]
return name
from sqlalchemy.ext.associationproxy import _AssociationList
from flask_admin.model.typefmt import BASE_FORMATTERS, list_formatter from flask_admin.model.typefmt import BASE_FORMATTERS, list_formatter
from sqlalchemy.orm.collections import InstrumentedList from sqlalchemy.orm.collections import InstrumentedList
DEFAULT_FORMATTERS = BASE_FORMATTERS.copy() DEFAULT_FORMATTERS = BASE_FORMATTERS.copy()
DEFAULT_FORMATTERS.update({ DEFAULT_FORMATTERS.update({
InstrumentedList: list_formatter InstrumentedList: list_formatter,
_AssociationList: list_formatter
}) })
...@@ -14,6 +14,7 @@ from flask import current_app, flash ...@@ -14,6 +14,7 @@ from flask import current_app, flash
from flask_admin._compat import string_types, text_type from flask_admin._compat import string_types, text_type
from flask_admin.babel import gettext, ngettext, lazy_gettext from flask_admin.babel import gettext, ngettext, lazy_gettext
from flask_admin.contrib.sqla.tools import is_relationship
from flask_admin.model import BaseModelView from flask_admin.model import BaseModelView
from flask_admin.model.form import create_editable_list_form from flask_admin.model.form import create_editable_list_form
from flask_admin.actions import action from flask_admin.actions import action
...@@ -572,7 +573,7 @@ class ModelView(BaseModelView): ...@@ -572,7 +573,7 @@ class ModelView(BaseModelView):
raise Exception('Failed to find field for filter: %s' % name) raise Exception('Failed to find field for filter: %s' % name)
# Figure out filters for related column # Figure out filters for related column
if hasattr(attr, 'property') and hasattr(attr.property, 'direction'): if is_relationship(attr):
filters = [] filters = []
for p in self._get_model_iterator(attr.property.mapper.class_): for p in self._get_model_iterator(attr.property.mapper.class_):
...@@ -791,12 +792,12 @@ class ModelView(BaseModelView): ...@@ -791,12 +792,12 @@ class ModelView(BaseModelView):
if isinstance(column, tuple): if isinstance(column, tuple):
query = query.order_by(*map(desc, column)) query = query.order_by(*map(desc, column))
else: else:
query = query.order_by(desc(column)) query = query.order_by(desc(column))
else: else:
if isinstance(column, tuple): if isinstance(column, tuple):
query = query.order_by(*column) query = query.order_by(*column)
else: else:
query = query.order_by(column) query = query.order_by(column)
return query, joins return query, joins
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment