Commit a8b82eb8 authored by Paul Brown's avatar Paul Brown

SQLA - allow using model objects in column_list

parent 5a93a641
...@@ -471,20 +471,57 @@ class ModelView(BaseModelView): ...@@ -471,20 +471,57 @@ class ModelView(BaseModelView):
if isinstance(c, tuple): if isinstance(c, tuple):
column, path = tools.get_field_with_path(self.model, c[1]) column, path = tools.get_field_with_path(self.model, c[1])
column_name = c[0] column_name = c[0]
elif isinstance(c, InstrumentedAttribute):
column, path = tools.get_field_with_path(self.model, c)
column_name = str(c)
else: else:
column, path = tools.get_field_with_path(self.model, c) column, path = tools.get_field_with_path(self.model, c)
column_name = c column_name = c
result[column_name] = column
if path: if path:
# column is in another table, use full path as column_name
column_name = text_type(c)
self._sortable_joins[column_name] = path self._sortable_joins[column_name] = path
else:
# column is in same table, use only model attribute name
column_name = column.key
# column_name must match column_name used in `get_list_columns`
result[column_name] = column
return result return result
def get_list_columns(self):
"""
Returns a list of tuples with the model field name and formatted
field name. If `column_list` was set, returns it. Otherwise calls
`scaffold_list_columns` to generate the list from the model.
"""
if self.column_list is None:
columns = self.scaffold_list_columns()
# Filter excluded columns
if self.column_exclude_list:
columns = [c for c in columns
if c not in self.column_exclude_list]
return [(c, self.get_column_name(c)) for c in columns]
else:
columns = []
for c in self.column_list:
column, path = tools.get_field_with_path(self.model, c)
if path:
# column is in another table, use full path
column_name = text_type(c)
else:
# column is in same table, use only model attribute name
column_name = column.key
visible_name = self.get_column_name(column_name)
# column_name must match column_name in `get_sortable_columns`
columns.append((column_name, visible_name))
return columns
def init_search(self): def init_search(self):
""" """
Initialize search. Returns `True` if search is supported for this Initialize search. Returns `True` if search is supported for this
......
...@@ -168,6 +168,15 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -168,6 +168,15 @@ class BaseModelView(BaseView, ActionsMixin):
class MyModelView(BaseModelView): class MyModelView(BaseModelView):
column_list = ('name', 'last_name', 'email') column_list = ('name', 'last_name', 'email')
(Added in 1.4.0) SQLAlchemy model attributes can be used instead of strings::
class MyModelView(BaseModelView):
column_list = ('name', User.last_name)
When using SQLAlchemy models, you can reference related columns like this::
class MyModelView(BaseModelView):
column_list = ('<relationship>.<related column name>',)
""" """
column_exclude_list = ObsoleteAttr('column_exclude_list', column_exclude_list = ObsoleteAttr('column_exclude_list',
...@@ -505,6 +514,15 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -505,6 +514,15 @@ class BaseModelView(BaseView, ActionsMixin):
class MyModelView(BaseModelView): class MyModelView(BaseModelView):
form_columns = ('name', 'email') form_columns = ('name', 'email')
(Added in 1.4.0) SQLAlchemy model attributes can be used instead of
strings::
class MyModelView(BaseModelView):
form_columns = ('name', User.last_name)
SQLA Note: Model attributes must be on the same model as your ModelView
or you will need to use `inline_models`.
""" """
form_excluded_columns = ObsoleteAttr('form_excluded_columns', form_excluded_columns = ObsoleteAttr('form_excluded_columns',
...@@ -878,9 +896,9 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -878,9 +896,9 @@ class BaseModelView(BaseView, ActionsMixin):
def get_list_columns(self): def get_list_columns(self):
""" """
Returns a list of the model field names. If `column_list` was Returns a list of tuples with the model field name and formatted
set, returns it. Otherwise calls `scaffold_list_columns` field name. If `column_list` was set, returns it. Otherwise calls
to generate the list from the model. `scaffold_list_columns` to generate the list from the model.
""" """
columns = self.column_list columns = self.column_list
......
...@@ -218,6 +218,7 @@ def test_list_columns(): ...@@ -218,6 +218,7 @@ def test_list_columns():
Model1, Model2 = create_models(db) Model1, Model2 = create_models(db)
# test column_list with a list of strings
view = CustomModelView(Model1, db.session, view = CustomModelView(Model1, db.session,
column_list=['test1', 'test3'], column_list=['test1', 'test3'],
column_labels=dict(test1='Column1')) column_labels=dict(test1='Column1'))
...@@ -233,6 +234,58 @@ def test_list_columns(): ...@@ -233,6 +234,58 @@ def test_list_columns():
ok_('Column1' in data) ok_('Column1' in data)
ok_('Test2' not in data) ok_('Test2' not in data)
# test column_list with a list of SQLAlchemy columns
view2 = CustomModelView(Model1, db.session, endpoint='model1_2',
column_list=[Model1.test1, Model1.test3],
column_labels=dict(test1='Column1'))
admin.add_view(view2)
eq_(len(view2._list_columns), 2)
eq_(view2._list_columns, [('test1', 'Column1'), ('test3', 'Test3')])
rv = client.get('/admin/model1_2/')
data = rv.data.decode('utf-8')
ok_('Column1' in data)
ok_('Test2' not in data)
def test_complex_list_columns():
app, db, admin = setup()
M1, M2 = create_models(db)
m1 = M1('model1_val1')
db.session.add(m1)
db.session.add(M2('model2_val1', model1=m1))
db.session.commit()
# test column_list with a list of strings on a relation
view = CustomModelView(M2, db.session,
column_list=['model1.test1'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model2/')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('model1_val1' in data)
# TODO: Allow providing a list of related models
"""
# test column_list with a list of models on a relation
view2 = CustomModelView(M2, db.session, endpoint='model2_2',
column_list=[M1.test1])
admin.add_view(view2)
client = app.test_client()
rv = client.get('/admin/model2_2/')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('model1_val1' in data)
"""
def test_exclude_columns(): def test_exclude_columns():
app, db, admin = setup() app, db, admin = setup()
...@@ -1608,6 +1661,31 @@ def test_default_sort(): ...@@ -1608,6 +1661,31 @@ def test_default_sort():
eq_(data[1].test1, 'b') eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c') eq_(data[2].test1, 'c')
# test default sort on renamed columns - with column_list scaffolding
view2 = CustomModelView(M1, db.session, column_default_sort='test1',
column_labels={'test1': 'blah'}, endpoint='m1_2')
admin.add_view(view2)
_, data = view2.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'a')
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
# test default sort on renamed columns - without column_list scaffolding
view3 = CustomModelView(M1, db.session, column_default_sort='test1',
column_labels={'test1': 'blah'}, endpoint='m1_3',
column_list=['test1'])
admin.add_view(view3)
_, data = view3.get_list(0, None, None, None, None)
eq_(len(data), 3)
eq_(data[0].test1, 'a')
eq_(data[1].test1, 'b')
eq_(data[2].test1, 'c')
def test_complex_sort(): def test_complex_sort():
app, db, admin = setup() app, db, admin = setup()
...@@ -1625,8 +1703,8 @@ def test_complex_sort(): ...@@ -1625,8 +1703,8 @@ def test_complex_sort():
# test sorting on relation string - 'model1.test1' # test sorting on relation string - 'model1.test1'
view = CustomModelView(M2, db.session, view = CustomModelView(M2, db.session,
column_list = ['string_field', 'model1.test1'], column_list=['string_field', 'model1.test1'],
column_sortable_list = ['model1.test1']) column_sortable_list=['model1.test1'])
admin.add_view(view) admin.add_view(view)
client = app.test_client() client = app.test_client()
...@@ -1636,8 +1714,8 @@ def test_complex_sort(): ...@@ -1636,8 +1714,8 @@ def test_complex_sort():
# test sorting on relation object - M2.string_field # test sorting on relation object - M2.string_field
view2 = CustomModelView(M1, db.session, view2 = CustomModelView(M1, db.session,
column_list = ['model2.string_field'], column_list=['model2.string_field'],
column_sortable_list = [M2.string_field]) column_sortable_list=[M2.string_field])
admin.add_view(view2) admin.add_view(view2)
client = app.test_client() client = app.test_client()
...@@ -1647,6 +1725,19 @@ def test_complex_sort(): ...@@ -1647,6 +1725,19 @@ def test_complex_sort():
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('Sort by' in data) ok_('Sort by' in data)
# test sorting on relations with model in column_list
view3 = CustomModelView(M1, db.session, endpoint="model1_2",
column_list=[M2.string_field],
column_sortable_list=[M2.string_field])
admin.add_view(view3)
client = app.test_client()
rv = client.get('/admin/model1_2/?sort=1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('Sort by' in data)
def test_default_complex_sort(): def test_default_complex_sort():
app, db, admin = setup() app, db, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment