Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Sign in
Toggle navigation
F
flask-admin
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
JIRA
JIRA
Merge Requests
0
Merge Requests
0
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Commits
Issue Boards
Open sidebar
Python-Dev
flask-admin
Commits
78a66cda
Commit
78a66cda
authored
Jul 11, 2014
by
Serge S. Koval
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fixed #556 Added support for complex sortables
parent
664f4622
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
67 deletions
+107
-67
view.py
flask_admin/contrib/sqla/view.py
+84
-67
test_basic.py
flask_admin/tests/sqlamodel/test_basic.py
+23
-0
No files found.
flask_admin/contrib/sqla/view.py
View file @
78a66cda
...
@@ -262,10 +262,12 @@ class ModelView(BaseModelView):
...
@@ -262,10 +262,12 @@ class ModelView(BaseModelView):
self
.
session
=
session
self
.
session
=
session
self
.
_search_fields
=
None
self
.
_search_fields
=
None
self
.
_search_joins
=
dict
()
self
.
_search_joins
=
[]
self
.
_filter_joins
=
dict
()
self
.
_filter_joins
=
dict
()
self
.
_sortable_joins
=
dict
()
if
self
.
form_choices
is
None
:
if
self
.
form_choices
is
None
:
self
.
form_choices
=
{}
self
.
form_choices
=
{}
...
@@ -293,6 +295,41 @@ class ModelView(BaseModelView):
...
@@ -293,6 +295,41 @@ class ModelView(BaseModelView):
return
model
.
_sa_class_manager
.
mapper
.
iterate_properties
return
model
.
_sa_class_manager
.
mapper
.
iterate_properties
def
_get_columns_for_field
(
self
,
field
):
if
(
not
field
or
not
hasattr
(
field
,
'property'
)
or
not
hasattr
(
field
.
property
,
'columns'
)
or
not
field
.
property
.
columns
):
raise
Exception
(
'Invalid field
%
s: does not contains any columns.'
%
field
)
return
field
.
property
.
columns
def
_get_field_with_path
(
self
,
name
):
join_tables
=
[]
if
isinstance
(
name
,
string_types
):
model
=
self
.
model
for
attribute
in
name
.
split
(
'.'
):
value
=
getattr
(
model
,
attribute
)
if
(
hasattr
(
value
,
'property'
)
and
hasattr
(
value
.
property
,
'direction'
)):
model
=
value
.
property
.
mapper
.
class_
table
=
model
.
__table__
if
self
.
_need_join
(
table
):
join_tables
.
append
(
table
)
attr
=
value
else
:
attr
=
name
return
join_tables
,
attr
def
_need_join
(
self
,
table
):
return
table
not
in
self
.
model
.
_sa_class_manager
.
mapper
.
tables
# Scaffolding
# Scaffolding
def
scaffold_pk
(
self
):
def
scaffold_pk
(
self
):
"""
"""
...
@@ -370,40 +407,35 @@ class ModelView(BaseModelView):
...
@@ -370,40 +407,35 @@ class ModelView(BaseModelView):
return
columns
return
columns
def
_get_columns_for_field
(
self
,
field
):
def
get_sortable_columns
(
self
):
if
(
not
field
or
"""
not
hasattr
(
field
,
'property'
)
or
Returns a dictionary of the sortable columns. Key is a model
not
hasattr
(
field
.
property
,
'columns'
)
or
field name and value is sort column (for example - attribute).
not
field
.
property
.
columns
):
raise
Exception
(
'Invalid field
%
s: does not contains any columns.'
%
field
)
return
field
.
property
.
columns
def
_get_field_with_path
(
self
,
name
):
If `column_sortable_list` is set, will use it. Otherwise, will call
join_tables
=
[]
`scaffold_sortable_columns` to get them from the model.
"""
self
.
_sortable_joins
=
dict
()
if
isinstance
(
name
,
string_types
):
if
self
.
column_sortable_list
is
None
:
model
=
self
.
model
return
self
.
scaffold_sortable_columns
()
else
:
result
=
dict
()
for
attribute
in
name
.
split
(
'.'
):
for
c
in
self
.
column_sortable_list
:
value
=
getattr
(
model
,
attribute
)
if
isinstance
(
c
,
tuple
):
join_tables
,
column
=
self
.
_get_field_with_path
(
c
[
1
])
if
(
hasattr
(
value
,
'property'
)
and
result
[
c
[
0
]]
=
column
hasattr
(
value
.
property
,
'direction'
)):
model
=
value
.
property
.
mapper
.
class_
table
=
model
.
__table__
if
self
.
_need_join
(
table
):
if
join_tables
:
join_tables
.
append
(
table
)
self
.
_sortable_joins
[
c
[
0
]]
=
join_tables
else
:
attr
=
value
join_tables
,
column
=
self
.
_get_field_with_path
(
c
)
else
:
attr
=
name
return
join_tables
,
attr
result
[
c
]
=
column
def
_need_join
(
self
,
table
):
return
result
return
table
not
in
self
.
model
.
_sa_class_manager
.
mapper
.
tables
def
init_search
(
self
):
def
init_search
(
self
):
"""
"""
...
@@ -415,7 +447,9 @@ class ModelView(BaseModelView):
...
@@ -415,7 +447,9 @@ class ModelView(BaseModelView):
"""
"""
if
self
.
column_searchable_list
:
if
self
.
column_searchable_list
:
self
.
_search_fields
=
[]
self
.
_search_fields
=
[]
self
.
_search_joins
=
dict
()
self
.
_search_joins
=
[]
joins
=
set
()
for
p
in
self
.
column_searchable_list
:
for
p
in
self
.
column_searchable_list
:
join_tables
,
attr
=
self
.
_get_field_with_path
(
p
)
join_tables
,
attr
=
self
.
_get_field_with_path
(
p
)
...
@@ -433,9 +467,10 @@ class ModelView(BaseModelView):
...
@@ -433,9 +467,10 @@ class ModelView(BaseModelView):
self
.
_search_fields
.
append
(
column
)
self
.
_search_fields
.
append
(
column
)
# Store joins, avoid duplicates
# Store joins, avoid duplicates
if
join_tables
:
for
table
in
join_tables
:
for
table
in
join_tables
:
if
table
.
name
not
in
joins
:
self
.
_search_joins
[
table
.
name
]
=
table
self
.
_search_joins
.
append
(
table
)
joins
.
add
(
table
.
name
)
return
bool
(
self
.
column_searchable_list
)
return
bool
(
self
.
column_searchable_list
)
...
@@ -621,7 +656,7 @@ class ModelView(BaseModelView):
...
@@ -621,7 +656,7 @@ class ModelView(BaseModelView):
"""
"""
return
self
.
session
.
query
(
func
.
count
(
'*'
))
.
select_from
(
self
.
model
)
return
self
.
session
.
query
(
func
.
count
(
'*'
))
.
select_from
(
self
.
model
)
def
_order_by
(
self
,
query
,
joins
,
sort_field
,
sort_desc
):
def
_order_by
(
self
,
query
,
joins
,
sort_
joins
,
sort_
field
,
sort_desc
):
"""
"""
Apply order_by to the query
Apply order_by to the query
...
@@ -635,33 +670,13 @@ class ModelView(BaseModelView):
...
@@ -635,33 +670,13 @@ class ModelView(BaseModelView):
Ascending or descending
Ascending or descending
"""
"""
# TODO: Preprocessing for joins
# TODO: Preprocessing for joins
# Try to handle it as a string
# Handle joins
if
isinstance
(
sort_field
,
string_types
):
if
sort_joins
:
# Create automatic join against a table if column name
for
table
in
sort_joins
:
# contains dot.
if
table
.
name
not
in
joins
:
if
'.'
in
sort_field
:
parts
=
sort_field
.
split
(
'.'
,
1
)
if
parts
[
0
]
not
in
joins
:
query
=
query
.
join
(
parts
[
0
])
joins
.
add
(
parts
[
0
])
elif
isinstance
(
sort_field
,
InstrumentedAttribute
):
# SQLAlchemy 0.8+ uses 'parent' as a name
mapper
=
getattr
(
sort_field
,
'parent'
,
None
)
if
mapper
is
None
:
# SQLAlchemy 0.7.x uses parententity
mapper
=
getattr
(
sort_field
,
'parententity'
,
None
)
if
mapper
is
not
None
:
table
=
mapper
.
tables
[
0
]
if
self
.
_need_join
(
table
)
and
table
.
name
not
in
joins
:
query
=
query
.
outerjoin
(
table
)
query
=
query
.
outerjoin
(
table
)
joins
.
add
(
table
.
name
)
joins
.
add
(
table
.
name
)
elif
isinstance
(
sort_field
,
Column
):
pass
else
:
raise
TypeError
(
'Wrong argument type'
)
if
sort_field
is
not
None
:
if
sort_field
is
not
None
:
if
sort_desc
:
if
sort_desc
:
...
@@ -677,10 +692,9 @@ class ModelView(BaseModelView):
...
@@ -677,10 +692,9 @@ class ModelView(BaseModelView):
if
order
is
not
None
:
if
order
is
not
None
:
field
,
direction
=
order
field
,
direction
=
order
if
isinstance
(
field
,
string_types
):
join_tables
,
attr
=
self
.
_get_field_with_path
(
field
)
field
=
getattr
(
self
.
model
,
field
)
return
field
,
direction
return
join_tables
,
field
,
direction
return
None
return
None
...
@@ -712,11 +726,11 @@ class ModelView(BaseModelView):
...
@@ -712,11 +726,11 @@ class ModelView(BaseModelView):
if
self
.
_search_supported
and
search
:
if
self
.
_search_supported
and
search
:
# Apply search-related joins
# Apply search-related joins
if
self
.
_search_joins
:
if
self
.
_search_joins
:
for
jn
in
self
.
_search_joins
.
values
()
:
for
table
in
self
.
_search_joins
:
query
=
query
.
join
(
jn
)
query
=
query
.
join
(
table
)
count_query
=
count_query
.
join
(
jn
)
count_query
=
count_query
.
join
(
table
)
joins
=
set
(
self
.
_search_joins
.
keys
()
)
joins
.
add
(
table
.
name
)
# Apply terms
# Apply terms
terms
=
search
.
split
(
' '
)
terms
=
search
.
split
(
' '
)
...
@@ -761,13 +775,16 @@ class ModelView(BaseModelView):
...
@@ -761,13 +775,16 @@ class ModelView(BaseModelView):
if
sort_column
is
not
None
:
if
sort_column
is
not
None
:
if
sort_column
in
self
.
_sortable_columns
:
if
sort_column
in
self
.
_sortable_columns
:
sort_field
=
self
.
_sortable_columns
[
sort_column
]
sort_field
=
self
.
_sortable_columns
[
sort_column
]
sort_joins
=
self
.
_sortable_joins
.
get
(
sort_column
)
query
,
joins
=
self
.
_order_by
(
query
,
joins
,
sort_field
,
sort_desc
)
query
,
joins
=
self
.
_order_by
(
query
,
joins
,
sort_
joins
,
sort_
field
,
sort_desc
)
else
:
else
:
order
=
self
.
_get_default_order
()
order
=
self
.
_get_default_order
()
if
order
:
if
order
:
query
,
joins
=
self
.
_order_by
(
query
,
joins
,
order
[
0
],
order
[
1
])
sort_joins
,
sort_field
,
sort_desc
=
order
query
,
joins
=
self
.
_order_by
(
query
,
joins
,
sort_joins
,
sort_field
,
sort_desc
)
# Pagination
# Pagination
if
page
is
not
None
:
if
page
is
not
None
:
...
...
flask_admin/tests/sqlamodel/test_basic.py
View file @
78a66cda
...
@@ -660,6 +660,29 @@ def test_default_sort():
...
@@ -660,6 +660,29 @@ def test_default_sort():
eq_
(
data
[
2
]
.
test1
,
'c'
)
eq_
(
data
[
2
]
.
test1
,
'c'
)
def
test_default_complex_sort
():
app
,
db
,
admin
=
setup
()
M1
,
M2
=
create_models
(
db
)
m1
=
M1
(
'b'
)
db
.
session
.
add
(
m1
)
db
.
session
.
add
(
M2
(
'c'
,
model1
=
m1
))
m2
=
M1
(
'a'
)
db
.
session
.
add
(
m2
)
db
.
session
.
add
(
M2
(
'c'
,
model1
=
m2
))
db
.
session
.
commit
()
view
=
CustomModelView
(
M2
,
db
.
session
,
column_default_sort
=
'model1.test1'
)
admin
.
add_view
(
view
)
_
,
data
=
view
.
get_list
(
0
,
None
,
None
,
None
,
None
)
eq_
(
len
(
data
),
2
)
eq_
(
data
[
0
]
.
model1
.
test1
,
'a'
)
eq_
(
data
[
1
]
.
model1
.
test1
,
'b'
)
def
test_extra_fields
():
def
test_extra_fields
():
app
,
db
,
admin
=
setup
()
app
,
db
,
admin
=
setup
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment