Commit 845f6588 authored by Serge S. Koval's avatar Serge S. Koval

Fixed #457, generally improved multi-PK models handling and inherited models

parent e97707e4
...@@ -11,7 +11,7 @@ from flask.ext.admin._compat import iteritems ...@@ -11,7 +11,7 @@ from flask.ext.admin._compat import iteritems
from .validators import Unique from .validators import Unique
from .fields import QuerySelectField, QuerySelectMultipleField, InlineModelFormList from .fields import QuerySelectField, QuerySelectMultipleField, InlineModelFormList
from .tools import is_inherited_primary_key, get_column_for_current_model, has_multiple_pks from .tools import has_multiple_pks, filter_foreign_columns
from .ajax import create_ajax_loader from .ajax import create_ajax_loader
try: try:
...@@ -151,11 +151,13 @@ class AdminModelConverter(ModelConverterBase): ...@@ -151,11 +151,13 @@ class AdminModelConverter(ModelConverterBase):
# Ignore pk/fk # Ignore pk/fk
if hasattr(prop, 'columns'): if hasattr(prop, 'columns'):
# Check if more than one column mapped to the property # Check if more than one column mapped to the property
if len(prop.columns) != 1: if len(prop.columns) > 1:
if is_inherited_primary_key(prop): columns = filter_foreign_columns(model.__table__, prop.columns)
column = get_column_for_current_model(prop)
else: if len(columns) > 1:
raise TypeError('Can not convert multiple-column properties (%s.%s)' % (model, prop.key)) raise TypeError('Can not convert multiple-column properties (%s.%s)' % (model, prop.key))
column = columns[0]
else: else:
# Grab column # Grab column
column = prop.columns[0] column = prop.columns[0]
......
...@@ -3,6 +3,9 @@ from sqlalchemy.sql.operators import eq ...@@ -3,6 +3,9 @@ from sqlalchemy.sql.operators import eq
from sqlalchemy.exc import DBAPIError from sqlalchemy.exc import DBAPIError
from ast import literal_eval from ast import literal_eval
from flask.ext.admin.tools import iterencode, iterdecode
def parse_like_term(term): def parse_like_term(term):
if term.startswith('^'): if term.startswith('^'):
stmt = '%s%%' % term[1:] stmt = '%s%%' % term[1:]
...@@ -14,6 +17,16 @@ def parse_like_term(term): ...@@ -14,6 +17,16 @@ def parse_like_term(term):
return stmt return stmt
def filter_foreign_columns(base_table, columns):
"""
Return list of columns that belong to passed table.
:param base_table: Table to check against
:param columns: List of columns to filter
"""
return filter(lambda c: c.table == base_table, columns)
def get_primary_key(model): def get_primary_key(model):
""" """
Return primary key name from a model. If the primary key consists of multiple columns, Return primary key name from a model. If the primary key consists of multiple columns,
...@@ -26,18 +39,11 @@ def get_primary_key(model): ...@@ -26,18 +39,11 @@ def get_primary_key(model):
pks = [] pks = []
for p in props: for p in props:
if hasattr(p, 'expression'): # expression = primary column or expression for this ColumnProperty if hasattr(p, 'columns'):
if p.expression.primary_key: for c in filter_foreign_columns(model.__table__, p.columns):
if is_inherited_primary_key(p): if c.primary_key:
pks.append(get_column_for_current_model(p).key)
else:
pks.append(p.key) pks.append(p.key)
else: break
if hasattr(p, 'columns'):
for c in p.columns:
if c.primary_key:
pks.append(p.key)
break
if len(pks) == 1: if len(pks) == 1:
return pks[0] return pks[0]
...@@ -46,54 +52,16 @@ def get_primary_key(model): ...@@ -46,54 +52,16 @@ def get_primary_key(model):
else: else:
return None return None
def is_inherited_primary_key(prop):
"""
Return True, if the ColumnProperty is an inherited primary key
Check if all columns are primary keys and _one_ does not have a foreign key -> looks like joined
table inheritance: http://docs.sqlalchemy.org/en/latest/orm/inheritance.html with "standard
practice" of same column name.
:param prop: The ColumnProperty to check
:return: Boolean
:raises: Exceptions as they occur - no ExceptionHandling here
"""
if not hasattr(prop, 'expression'):
return False
if prop.expression.primary_key:
return len(prop._orig_columns) == len(prop.columns)-1
return False
def get_column_for_current_model(prop):
"""
Return the Column() of the ColumnProperty "prop", that refers to the current model
When using inheritance, a ColumnProperty may contain multiple columns. This function
returns the Column(), the belongs to the Model of the ColumnProperty - the "current"
model
:param prop: The ColumnProperty
:return: The column for the current model
:raises: TypeError if not exactely one Column() for the current model could be found.
All other Exceptions not handled here but raised
"""
candidates = [column for column in prop.columns if column.expression == prop.expression]
if len(candidates) != 1:
raise TypeError('Not exactly one column for the current model found. ' +
'Found %d columns for property %s' % (len(candidates), prop))
else:
return candidates[0]
def has_multiple_pks(model): def has_multiple_pks(model):
"""Return True, if the model has more than one primary key """
Return True, if the model has more than one primary key
""" """
if not hasattr(model, '_sa_class_manager'): if not hasattr(model, '_sa_class_manager'):
raise TypeError('model must be a sqlalchemy mapped model') raise TypeError('model must be a sqlalchemy mapped model')
pks = model._sa_class_manager.mapper.primary_key
return len(pks) > 1 return len(model._sa_class_manager.mapper.primary_key) > 1
def tuple_operator_in(model_pk, ids): def tuple_operator_in(model_pk, ids):
"""The tuple_ Operator only works on certain engines like MySQL or Postgresql. It does not work with sqlite. """The tuple_ Operator only works on certain engines like MySQL or Postgresql. It does not work with sqlite.
...@@ -123,36 +91,27 @@ def tuple_operator_in(model_pk, ids): ...@@ -123,36 +91,27 @@ def tuple_operator_in(model_pk, ids):
def get_query_for_ids(modelquery, model, ids): def get_query_for_ids(modelquery, model, ids):
""" """
Return a query object, that contains all entities of the given model for Return a query object filtered by primary key values passed in `ids` argument.
the primary keys provided in the ids-parameter.
Unfortunately, it is not possible to use `in_` filter if model has more than one
The ``pks`` parameter is a tuple, that contains the different primary key values, primary key.
that should be returned. If the primary key of the model consists of multiple columns
every entry of the ``pks`` parameter must be a tuple containing the columns-values in the
correct order, that make up the primary key of the model
If the model has multiple primary keys, the
`tuple_ <http://docs.sqlalchemy.org/en/latest/core/expression_api.html#sqlalchemy.sql.expression.tuple_>`_
operator will be used. As this operator does not work on certain databases,
notably on sqlite, a workaround function :func:`tuple_operator_in` is provided
that implements the same logic using OR and AND operations.
When having multiple primary keys, the pks are provided as a list of tuple-look-alike-strings,
``[u'(1, 2)', u'(1, 1)']``. These needs to be evaluated into real tuples, where
`Stackoverflow Question 3945856 <http://stackoverflow.com/questions/3945856/converting-string-to-tuple-and-adding-to-tuple>`_
pointed to `Literal Eval <http://docs.python.org/2/library/ast.html#ast.literal_eval>`_, which is now used.
""" """
if has_multiple_pks(model): if has_multiple_pks(model):
model_pk = [getattr(model, pk_name).expression for pk_name in get_primary_key(model)] # Decode keys to tuples
ids = [literal_eval(id) for id in ids] decoded_ids = [iterdecode(v) for v in ids]
# Get model primary key property references
model_pk = [getattr(model, name) for name in get_primary_key(model)]
try: try:
query = modelquery.filter(tuple_(*model_pk).in_(ids)) query = modelquery.filter(tuple_(*model_pk).in_(decoded_ids))
# Only the execution of the query will tell us, if the tuple_ # Only the execution of the query will tell us, if the tuple_
# operator really works # operator really works
query.all() query.all()
except DBAPIError: except DBAPIError:
query = modelquery.filter(tuple_operator_in(model_pk, ids)) query = modelquery.filter(tuple_operator_in(model_pk, decoded_ids))
else: else:
model_pk = getattr(model, get_primary_key(model)) model_pk = getattr(model, get_primary_key(model))
query = modelquery.filter(model_pk.in_(ids)) query = modelquery.filter(model_pk.in_(ids))
return query return query
...@@ -16,7 +16,7 @@ from flask.ext.admin._backwards import ObsoleteAttr ...@@ -16,7 +16,7 @@ from flask.ext.admin._backwards import ObsoleteAttr
from flask.ext.admin.contrib.sqla import form, filters, tools from flask.ext.admin.contrib.sqla import form, filters, tools
from .typefmt import DEFAULT_FORMATTERS from .typefmt import DEFAULT_FORMATTERS
from .tools import is_inherited_primary_key, get_column_for_current_model, get_query_for_ids from .tools import get_query_for_ids
from .ajax import create_ajax_loader from .ajax import create_ajax_loader
...@@ -296,23 +296,20 @@ class ModelView(BaseModelView): ...@@ -296,23 +296,20 @@ class ModelView(BaseModelView):
# Scaffolding # Scaffolding
def scaffold_pk(self): def scaffold_pk(self):
""" """
Return the primary key name from a model Return the primary key name(s) from a model
PK can be a single value or a tuple if multiple PKs exist If model has single primary key, will return a string and tuple otherwise
""" """
return tools.get_primary_key(self.model) return tools.get_primary_key(self.model)
def get_pk_value(self, model): def get_pk_value(self, model):
""" """
Return the PK value from a model object. Return the primary key value from a model object.
PK can be a single value or a tuple if multiple PKs exist If there are multiple primary keys, they're encoded into string representation.
""" """
try: if isinstance(self._primary_key, tuple):
return tools.iterencode(getattr(model, attr) for attr in self._primary_key)
else:
return getattr(model, self._primary_key) return getattr(model, self._primary_key)
except TypeError:
v = []
for attr in self._primary_key:
v.append(getattr(model, attr))
return tuple(v)
def scaffold_list_columns(self): def scaffold_list_columns(self):
""" """
...@@ -321,26 +318,21 @@ class ModelView(BaseModelView): ...@@ -321,26 +318,21 @@ class ModelView(BaseModelView):
columns = [] columns = []
for p in self._get_model_iterator(): for p in self._get_model_iterator():
# Verify type
if hasattr(p, 'direction'): if hasattr(p, 'direction'):
if self.column_display_all_relations or p.direction.name == 'MANYTOONE': if self.column_display_all_relations or p.direction.name == 'MANYTOONE':
columns.append(p.key) columns.append(p.key)
elif hasattr(p, 'columns'): elif hasattr(p, 'columns'):
column_inherited_primary_key = False if len(p.columns) > 1:
filtered = tools.filter_foreign_columns(self.model.__table__, p.columns)
if len(p.columns) != 1: if len(filtered) > 1:
if is_inherited_primary_key(p): # TODO: Skip column and issue a warning
column = get_column_for_current_model(p)
else:
raise TypeError('Can not convert multiple-column properties (%s.%s)' % (self.model, p.key)) raise TypeError('Can not convert multiple-column properties (%s.%s)' % (self.model, p.key))
column = filtered[0]
else: else:
# Grab column
column = p.columns[0] column = p.columns[0]
# An inherited primary key has a foreign key as well
if column.foreign_keys and not is_inherited_primary_key(p):
continue
if not self.column_display_pk and column.primary_key: if not self.column_display_pk and column.primary_key:
continue continue
...@@ -783,7 +775,7 @@ class ModelView(BaseModelView): ...@@ -783,7 +775,7 @@ class ModelView(BaseModelView):
:param id: :param id:
Model id Model id
""" """
return self.session.query(self.model).get(id) return self.session.query(self.model).get(tools.iterdecode(id))
# Error handler # Error handler
def handle_view_exception(self, exc): def handle_view_exception(self, exc):
...@@ -882,7 +874,6 @@ class ModelView(BaseModelView): ...@@ -882,7 +874,6 @@ class ModelView(BaseModelView):
lazy_gettext('Are you sure you want to delete selected models?')) lazy_gettext('Are you sure you want to delete selected models?'))
def action_delete(self, ids): def action_delete(self, ids):
try: try:
query = get_query_for_ids(self.get_query(), self.model, ids) query = get_query_for_ids(self.get_query(), self.model, ids)
if self.fast_mass_delete: if self.fast_mass_delete:
......
...@@ -8,7 +8,7 @@ def setup(): ...@@ -8,7 +8,7 @@ def setup():
app.config['SECRET_KEY'] = '1' app.config['SECRET_KEY'] = '1'
app.config['CSRF_ENABLED'] = False app.config['CSRF_ENABLED'] = False
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
#app.config['SQLALCHEMY_ECHO'] = True app.config['SQLALCHEMY_ECHO'] = True
db = SQLAlchemy(app) db = SQLAlchemy(app)
admin = Admin(app) admin = Admin(app)
......
...@@ -468,42 +468,6 @@ def test_non_int_pk(): ...@@ -468,42 +468,6 @@ def test_non_int_pk():
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('test2' in data) ok_('test2' in data)
def test_multiple__pk():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Model(db.Model):
id = db.Column(db.Integer, primary_key=True)
id2 = db.Column(db.String(20), primary_key=True)
test = db.Column(db.String)
db.create_all()
view = CustomModelView(Model, db.session, form_columns=['id', 'id2', 'test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model/')
eq_(rv.status_code, 200)
rv = client.post('/admin/model/new/',
data=dict(id=1, id2='two', test='test3'))
eq_(rv.status_code, 302)
rv = client.get('/admin/model/')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test3' in data)
rv = client.get('/admin/model/edit/?id=1&id=two')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test3' in data)
# Correct order is mandatory -> fail here
rv = client.get('/admin/model/edit/?id=two&id=1')
eq_(rv.status_code, 302)
def test_form_columns(): def test_form_columns():
app, db, admin = setup() app, db, admin = setup()
......
from nose.tools import eq_, ok_
from . import setup
from .test_basic import CustomModelView
from flask.ext.sqlalchemy import Model
from sqlalchemy.ext.declarative import declarative_base
def test_multiple_pk():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Model(db.Model):
id = db.Column(db.Integer, primary_key=True)
id2 = db.Column(db.String(20), primary_key=True)
test = db.Column(db.String)
db.create_all()
view = CustomModelView(Model, db.session, form_columns=['id', 'id2', 'test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model/')
eq_(rv.status_code, 200)
rv = client.post('/admin/model/new/',
data=dict(id=1, id2='two', test='test3'))
eq_(rv.status_code, 302)
rv = client.get('/admin/model/')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test3' in data)
rv = client.get('/admin/model/edit/?id=1,two')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test3' in data)
# Correct order is mandatory -> fail here
rv = client.get('/admin/model/edit/?id=two,1')
eq_(rv.status_code, 302)
def test_joined_inheritance():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Parent(db.Model):
id = db.Column(db.Integer, primary_key=True)
test = db.Column(db.String)
discriminator = db.Column('type', db.String(50))
__mapper_args__ = {'polymorphic_on': discriminator}
class Child(Parent):
__tablename__ = 'children'
__mapper_args__ = {'polymorphic_identity': 'child'}
id = db.Column(db.ForeignKey(Parent.id), primary_key=True)
name = db.Column(db.String(100))
db.create_all()
view = CustomModelView(Child, db.session, form_columns=['id', 'test', 'name'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/child/')
eq_(rv.status_code, 200)
rv = client.post('/admin/child/new/',
data=dict(id=1, test='foo', name='bar'))
eq_(rv.status_code, 302)
rv = client.get('/admin/child/edit/?id=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('foo' in data)
ok_('bar' in data)
def test_single_table_inheritance():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
CustomModel = declarative_base(Model, name='Model')
class Parent(CustomModel):
__tablename__ = 'parent'
id = db.Column(db.Integer, primary_key=True)
test = db.Column(db.String)
discriminator = db.Column('type', db.String(50))
__mapper_args__ = {'polymorphic_on': discriminator}
class Child(Parent):
__mapper_args__ = {'polymorphic_identity': 'child'}
name = db.Column(db.String(100))
CustomModel.metadata.create_all(db.engine)
view = CustomModelView(Child, db.session, form_columns=['id', 'test', 'name'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/child/')
eq_(rv.status_code, 200)
rv = client.post('/admin/child/new/',
data=dict(id=1, test='foo', name='bar'))
eq_(rv.status_code, 302)
rv = client.get('/admin/child/edit/?id=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('foo' in data)
ok_('bar' in data)
def test_concrete_table_inheritance():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Parent(db.Model):
id = db.Column(db.Integer, primary_key=True)
test = db.Column(db.String)
class Child(Parent):
__mapper_args__ = {'concrete': True}
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(100))
test = db.Column(db.String)
db.create_all()
view = CustomModelView(Child, db.session, form_columns=['id', 'test', 'name'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/child/')
eq_(rv.status_code, 200)
rv = client.post('/admin/child/new/',
data=dict(id=1, test='foo', name='bar'))
eq_(rv.status_code, 302)
rv = client.get('/admin/child/edit/?id=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('foo' in data)
ok_('bar' in data)
def test_concrete_multipk_inheritance():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Parent(db.Model):
id = db.Column(db.Integer, primary_key=True)
test = db.Column(db.String)
class Child(Parent):
__mapper_args__ = {'concrete': True}
id = db.Column(db.Integer, primary_key=True)
id2 = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(100))
test = db.Column(db.String)
db.create_all()
view = CustomModelView(Child, db.session, form_columns=['id', 'id2', 'test', 'name'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/child/')
eq_(rv.status_code, 200)
rv = client.post('/admin/child/new/',
data=dict(id=1, id2=2, test='foo', name='bar'))
eq_(rv.status_code, 302)
rv = client.get('/admin/child/edit/?id=1,2')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('foo' in data)
ok_('bar' in data)
from nose.tools import eq_, ok_
from flask.ext.admin import tools
def test_encode_decode():
eq_(tools.iterdecode(tools.iterencode([1, 2, 3])), (u'1', u'2', u'3'))
eq_(tools.iterdecode(tools.iterencode([',', ',', ','])), (u',', u',', u','))
eq_(tools.iterdecode(tools.iterencode(['.hello.,', ',', ','])), (u'.hello.,', u',', u','))
eq_(tools.iterdecode(tools.iterencode(['.....,,,.,,..,.,,.,'])), (u'.....,,,.,,..,.,,.,',))
eq_(tools.iterdecode(tools.iterencode([])), tuple())
# Malformed inputs should not crash
ok_(tools.iterdecode('.'))
eq_(tools.iterdecode(','), (u'', u''))
...@@ -2,7 +2,10 @@ import sys ...@@ -2,7 +2,10 @@ import sys
import traceback import traceback
# Python 3 compatibility # Python 3 compatibility
from ._compat import reduce from ._compat import reduce, as_unicode
CHAR_ESCAPE = u'.'
CHAR_SEPARATOR = u','
def import_module(name, required=True): def import_module(name, required=True):
...@@ -96,3 +99,47 @@ def get_dict_attr(obj, attr, default=None): ...@@ -96,3 +99,47 @@ def get_dict_attr(obj, attr, default=None):
return obj.__dict__[attr] return obj.__dict__[attr]
return default return default
def iterencode(iter):
"""
Encode enumerable as compact string representation.
:param iter:
Enumerable
"""
return ','.join(as_unicode(v)
.replace(CHAR_ESCAPE, CHAR_ESCAPE + CHAR_ESCAPE)
.replace(CHAR_SEPARATOR, CHAR_ESCAPE + CHAR_SEPARATOR)
for v in iter)
def iterdecode(value):
"""
Decode enumerable from string presentation as a tuple
"""
if not value:
return tuple()
result = []
accumulator = u''
escaped = False
for c in value:
if not escaped:
if c == CHAR_ESCAPE:
escaped = True
continue
elif c == CHAR_SEPARATOR:
result.append(accumulator)
accumulator = u''
continue
else:
escaped = False
accumulator += c
result.append(accumulator)
return tuple(result)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment