Commit 845f6588 authored by Serge S. Koval's avatar Serge S. Koval

Fixed #457, generally improved multi-PK models handling and inherited models

parent e97707e4
......@@ -11,7 +11,7 @@ from flask.ext.admin._compat import iteritems
from .validators import Unique
from .fields import QuerySelectField, QuerySelectMultipleField, InlineModelFormList
from .tools import is_inherited_primary_key, get_column_for_current_model, has_multiple_pks
from .tools import has_multiple_pks, filter_foreign_columns
from .ajax import create_ajax_loader
try:
......@@ -151,11 +151,13 @@ class AdminModelConverter(ModelConverterBase):
# Ignore pk/fk
if hasattr(prop, 'columns'):
# Check if more than one column mapped to the property
if len(prop.columns) != 1:
if is_inherited_primary_key(prop):
column = get_column_for_current_model(prop)
else:
if len(prop.columns) > 1:
columns = filter_foreign_columns(model.__table__, prop.columns)
if len(columns) > 1:
raise TypeError('Can not convert multiple-column properties (%s.%s)' % (model, prop.key))
column = columns[0]
else:
# Grab column
column = prop.columns[0]
......
......@@ -3,6 +3,9 @@ from sqlalchemy.sql.operators import eq
from sqlalchemy.exc import DBAPIError
from ast import literal_eval
from flask.ext.admin.tools import iterencode, iterdecode
def parse_like_term(term):
if term.startswith('^'):
stmt = '%s%%' % term[1:]
......@@ -14,6 +17,16 @@ def parse_like_term(term):
return stmt
def filter_foreign_columns(base_table, columns):
"""
Return list of columns that belong to passed table.
:param base_table: Table to check against
:param columns: List of columns to filter
"""
return filter(lambda c: c.table == base_table, columns)
def get_primary_key(model):
"""
Return primary key name from a model. If the primary key consists of multiple columns,
......@@ -26,18 +39,11 @@ def get_primary_key(model):
pks = []
for p in props:
if hasattr(p, 'expression'): # expression = primary column or expression for this ColumnProperty
if p.expression.primary_key:
if is_inherited_primary_key(p):
pks.append(get_column_for_current_model(p).key)
else:
if hasattr(p, 'columns'):
for c in filter_foreign_columns(model.__table__, p.columns):
if c.primary_key:
pks.append(p.key)
else:
if hasattr(p, 'columns'):
for c in p.columns:
if c.primary_key:
pks.append(p.key)
break
break
if len(pks) == 1:
return pks[0]
......@@ -46,54 +52,16 @@ def get_primary_key(model):
else:
return None
def is_inherited_primary_key(prop):
"""
Return True, if the ColumnProperty is an inherited primary key
Check if all columns are primary keys and _one_ does not have a foreign key -> looks like joined
table inheritance: http://docs.sqlalchemy.org/en/latest/orm/inheritance.html with "standard
practice" of same column name.
:param prop: The ColumnProperty to check
:return: Boolean
:raises: Exceptions as they occur - no ExceptionHandling here
"""
if not hasattr(prop, 'expression'):
return False
if prop.expression.primary_key:
return len(prop._orig_columns) == len(prop.columns)-1
return False
def get_column_for_current_model(prop):
"""
Return the Column() of the ColumnProperty "prop", that refers to the current model
When using inheritance, a ColumnProperty may contain multiple columns. This function
returns the Column(), the belongs to the Model of the ColumnProperty - the "current"
model
:param prop: The ColumnProperty
:return: The column for the current model
:raises: TypeError if not exactely one Column() for the current model could be found.
All other Exceptions not handled here but raised
"""
candidates = [column for column in prop.columns if column.expression == prop.expression]
if len(candidates) != 1:
raise TypeError('Not exactly one column for the current model found. ' +
'Found %d columns for property %s' % (len(candidates), prop))
else:
return candidates[0]
def has_multiple_pks(model):
"""Return True, if the model has more than one primary key
"""
Return True, if the model has more than one primary key
"""
if not hasattr(model, '_sa_class_manager'):
raise TypeError('model must be a sqlalchemy mapped model')
pks = model._sa_class_manager.mapper.primary_key
return len(pks) > 1
return len(model._sa_class_manager.mapper.primary_key) > 1
def tuple_operator_in(model_pk, ids):
"""The tuple_ Operator only works on certain engines like MySQL or Postgresql. It does not work with sqlite.
......@@ -123,36 +91,27 @@ def tuple_operator_in(model_pk, ids):
def get_query_for_ids(modelquery, model, ids):
"""
Return a query object, that contains all entities of the given model for
the primary keys provided in the ids-parameter.
The ``pks`` parameter is a tuple, that contains the different primary key values,
that should be returned. If the primary key of the model consists of multiple columns
every entry of the ``pks`` parameter must be a tuple containing the columns-values in the
correct order, that make up the primary key of the model
If the model has multiple primary keys, the
`tuple_ <http://docs.sqlalchemy.org/en/latest/core/expression_api.html#sqlalchemy.sql.expression.tuple_>`_
operator will be used. As this operator does not work on certain databases,
notably on sqlite, a workaround function :func:`tuple_operator_in` is provided
that implements the same logic using OR and AND operations.
When having multiple primary keys, the pks are provided as a list of tuple-look-alike-strings,
``[u'(1, 2)', u'(1, 1)']``. These needs to be evaluated into real tuples, where
`Stackoverflow Question 3945856 <http://stackoverflow.com/questions/3945856/converting-string-to-tuple-and-adding-to-tuple>`_
pointed to `Literal Eval <http://docs.python.org/2/library/ast.html#ast.literal_eval>`_, which is now used.
Return a query object filtered by primary key values passed in `ids` argument.
Unfortunately, it is not possible to use `in_` filter if model has more than one
primary key.
"""
if has_multiple_pks(model):
model_pk = [getattr(model, pk_name).expression for pk_name in get_primary_key(model)]
ids = [literal_eval(id) for id in ids]
# Decode keys to tuples
decoded_ids = [iterdecode(v) for v in ids]
# Get model primary key property references
model_pk = [getattr(model, name) for name in get_primary_key(model)]
try:
query = modelquery.filter(tuple_(*model_pk).in_(ids))
query = modelquery.filter(tuple_(*model_pk).in_(decoded_ids))
# Only the execution of the query will tell us, if the tuple_
# operator really works
query.all()
except DBAPIError:
query = modelquery.filter(tuple_operator_in(model_pk, ids))
query = modelquery.filter(tuple_operator_in(model_pk, decoded_ids))
else:
model_pk = getattr(model, get_primary_key(model))
query = modelquery.filter(model_pk.in_(ids))
return query
......@@ -16,7 +16,7 @@ from flask.ext.admin._backwards import ObsoleteAttr
from flask.ext.admin.contrib.sqla import form, filters, tools
from .typefmt import DEFAULT_FORMATTERS
from .tools import is_inherited_primary_key, get_column_for_current_model, get_query_for_ids
from .tools import get_query_for_ids
from .ajax import create_ajax_loader
......@@ -296,23 +296,20 @@ class ModelView(BaseModelView):
# Scaffolding
def scaffold_pk(self):
"""
Return the primary key name from a model
PK can be a single value or a tuple if multiple PKs exist
Return the primary key name(s) from a model
If model has single primary key, will return a string and tuple otherwise
"""
return tools.get_primary_key(self.model)
def get_pk_value(self, model):
"""
Return the PK value from a model object.
PK can be a single value or a tuple if multiple PKs exist
Return the primary key value from a model object.
If there are multiple primary keys, they're encoded into string representation.
"""
try:
if isinstance(self._primary_key, tuple):
return tools.iterencode(getattr(model, attr) for attr in self._primary_key)
else:
return getattr(model, self._primary_key)
except TypeError:
v = []
for attr in self._primary_key:
v.append(getattr(model, attr))
return tuple(v)
def scaffold_list_columns(self):
"""
......@@ -321,26 +318,21 @@ class ModelView(BaseModelView):
columns = []
for p in self._get_model_iterator():
# Verify type
if hasattr(p, 'direction'):
if self.column_display_all_relations or p.direction.name == 'MANYTOONE':
columns.append(p.key)
elif hasattr(p, 'columns'):
column_inherited_primary_key = False
if len(p.columns) > 1:
filtered = tools.filter_foreign_columns(self.model.__table__, p.columns)
if len(p.columns) != 1:
if is_inherited_primary_key(p):
column = get_column_for_current_model(p)
else:
if len(filtered) > 1:
# TODO: Skip column and issue a warning
raise TypeError('Can not convert multiple-column properties (%s.%s)' % (self.model, p.key))
column = filtered[0]
else:
# Grab column
column = p.columns[0]
# An inherited primary key has a foreign key as well
if column.foreign_keys and not is_inherited_primary_key(p):
continue
if not self.column_display_pk and column.primary_key:
continue
......@@ -783,7 +775,7 @@ class ModelView(BaseModelView):
:param id:
Model id
"""
return self.session.query(self.model).get(id)
return self.session.query(self.model).get(tools.iterdecode(id))
# Error handler
def handle_view_exception(self, exc):
......@@ -882,7 +874,6 @@ class ModelView(BaseModelView):
lazy_gettext('Are you sure you want to delete selected models?'))
def action_delete(self, ids):
try:
query = get_query_for_ids(self.get_query(), self.model, ids)
if self.fast_mass_delete:
......
......@@ -8,7 +8,7 @@ def setup():
app.config['SECRET_KEY'] = '1'
app.config['CSRF_ENABLED'] = False
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
#app.config['SQLALCHEMY_ECHO'] = True
app.config['SQLALCHEMY_ECHO'] = True
db = SQLAlchemy(app)
admin = Admin(app)
......
......@@ -468,42 +468,6 @@ def test_non_int_pk():
data = rv.data.decode('utf-8')
ok_('test2' in data)
def test_multiple__pk():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Model(db.Model):
id = db.Column(db.Integer, primary_key=True)
id2 = db.Column(db.String(20), primary_key=True)
test = db.Column(db.String)
db.create_all()
view = CustomModelView(Model, db.session, form_columns=['id', 'id2', 'test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model/')
eq_(rv.status_code, 200)
rv = client.post('/admin/model/new/',
data=dict(id=1, id2='two', test='test3'))
eq_(rv.status_code, 302)
rv = client.get('/admin/model/')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test3' in data)
rv = client.get('/admin/model/edit/?id=1&id=two')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test3' in data)
# Correct order is mandatory -> fail here
rv = client.get('/admin/model/edit/?id=two&id=1')
eq_(rv.status_code, 302)
def test_form_columns():
app, db, admin = setup()
......
from nose.tools import eq_, ok_
from . import setup
from .test_basic import CustomModelView
from flask.ext.sqlalchemy import Model
from sqlalchemy.ext.declarative import declarative_base
def test_multiple_pk():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Model(db.Model):
id = db.Column(db.Integer, primary_key=True)
id2 = db.Column(db.String(20), primary_key=True)
test = db.Column(db.String)
db.create_all()
view = CustomModelView(Model, db.session, form_columns=['id', 'id2', 'test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model/')
eq_(rv.status_code, 200)
rv = client.post('/admin/model/new/',
data=dict(id=1, id2='two', test='test3'))
eq_(rv.status_code, 302)
rv = client.get('/admin/model/')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test3' in data)
rv = client.get('/admin/model/edit/?id=1,two')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test3' in data)
# Correct order is mandatory -> fail here
rv = client.get('/admin/model/edit/?id=two,1')
eq_(rv.status_code, 302)
def test_joined_inheritance():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Parent(db.Model):
id = db.Column(db.Integer, primary_key=True)
test = db.Column(db.String)
discriminator = db.Column('type', db.String(50))
__mapper_args__ = {'polymorphic_on': discriminator}
class Child(Parent):
__tablename__ = 'children'
__mapper_args__ = {'polymorphic_identity': 'child'}
id = db.Column(db.ForeignKey(Parent.id), primary_key=True)
name = db.Column(db.String(100))
db.create_all()
view = CustomModelView(Child, db.session, form_columns=['id', 'test', 'name'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/child/')
eq_(rv.status_code, 200)
rv = client.post('/admin/child/new/',
data=dict(id=1, test='foo', name='bar'))
eq_(rv.status_code, 302)
rv = client.get('/admin/child/edit/?id=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('foo' in data)
ok_('bar' in data)
def test_single_table_inheritance():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
CustomModel = declarative_base(Model, name='Model')
class Parent(CustomModel):
__tablename__ = 'parent'
id = db.Column(db.Integer, primary_key=True)
test = db.Column(db.String)
discriminator = db.Column('type', db.String(50))
__mapper_args__ = {'polymorphic_on': discriminator}
class Child(Parent):
__mapper_args__ = {'polymorphic_identity': 'child'}
name = db.Column(db.String(100))
CustomModel.metadata.create_all(db.engine)
view = CustomModelView(Child, db.session, form_columns=['id', 'test', 'name'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/child/')
eq_(rv.status_code, 200)
rv = client.post('/admin/child/new/',
data=dict(id=1, test='foo', name='bar'))
eq_(rv.status_code, 302)
rv = client.get('/admin/child/edit/?id=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('foo' in data)
ok_('bar' in data)
def test_concrete_table_inheritance():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Parent(db.Model):
id = db.Column(db.Integer, primary_key=True)
test = db.Column(db.String)
class Child(Parent):
__mapper_args__ = {'concrete': True}
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(100))
test = db.Column(db.String)
db.create_all()
view = CustomModelView(Child, db.session, form_columns=['id', 'test', 'name'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/child/')
eq_(rv.status_code, 200)
rv = client.post('/admin/child/new/',
data=dict(id=1, test='foo', name='bar'))
eq_(rv.status_code, 302)
rv = client.get('/admin/child/edit/?id=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('foo' in data)
ok_('bar' in data)
def test_concrete_multipk_inheritance():
# Test multiple primary keys - mix int and string together
app, db, admin = setup()
class Parent(db.Model):
id = db.Column(db.Integer, primary_key=True)
test = db.Column(db.String)
class Child(Parent):
__mapper_args__ = {'concrete': True}
id = db.Column(db.Integer, primary_key=True)
id2 = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(100))
test = db.Column(db.String)
db.create_all()
view = CustomModelView(Child, db.session, form_columns=['id', 'id2', 'test', 'name'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/child/')
eq_(rv.status_code, 200)
rv = client.post('/admin/child/new/',
data=dict(id=1, id2=2, test='foo', name='bar'))
eq_(rv.status_code, 302)
rv = client.get('/admin/child/edit/?id=1,2')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('foo' in data)
ok_('bar' in data)
from nose.tools import eq_, ok_
from flask.ext.admin import tools
def test_encode_decode():
eq_(tools.iterdecode(tools.iterencode([1, 2, 3])), (u'1', u'2', u'3'))
eq_(tools.iterdecode(tools.iterencode([',', ',', ','])), (u',', u',', u','))
eq_(tools.iterdecode(tools.iterencode(['.hello.,', ',', ','])), (u'.hello.,', u',', u','))
eq_(tools.iterdecode(tools.iterencode(['.....,,,.,,..,.,,.,'])), (u'.....,,,.,,..,.,,.,',))
eq_(tools.iterdecode(tools.iterencode([])), tuple())
# Malformed inputs should not crash
ok_(tools.iterdecode('.'))
eq_(tools.iterdecode(','), (u'', u''))
......@@ -2,7 +2,10 @@ import sys
import traceback
# Python 3 compatibility
from ._compat import reduce
from ._compat import reduce, as_unicode
CHAR_ESCAPE = u'.'
CHAR_SEPARATOR = u','
def import_module(name, required=True):
......@@ -96,3 +99,47 @@ def get_dict_attr(obj, attr, default=None):
return obj.__dict__[attr]
return default
def iterencode(iter):
"""
Encode enumerable as compact string representation.
:param iter:
Enumerable
"""
return ','.join(as_unicode(v)
.replace(CHAR_ESCAPE, CHAR_ESCAPE + CHAR_ESCAPE)
.replace(CHAR_SEPARATOR, CHAR_ESCAPE + CHAR_SEPARATOR)
for v in iter)
def iterdecode(value):
"""
Decode enumerable from string presentation as a tuple
"""
if not value:
return tuple()
result = []
accumulator = u''
escaped = False
for c in value:
if not escaped:
if c == CHAR_ESCAPE:
escaped = True
continue
elif c == CHAR_SEPARATOR:
result.append(accumulator)
accumulator = u''
continue
else:
escaped = False
accumulator += c
result.append(accumulator)
return tuple(result)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment