Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Sign in
Toggle navigation
F
flask-admin
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
JIRA
JIRA
Merge Requests
0
Merge Requests
0
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Commits
Issue Boards
Open sidebar
Python-Dev
flask-admin
Commits
207d23fd
Commit
207d23fd
authored
Jun 14, 2015
by
Petrus J.v.Rensburg
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into examples
parents
2d3d1d63
6b6fe519
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
22 changed files
with
616 additions
and
215 deletions
+616
-215
changelog.rst
doc/changelog.rst
+12
-18
app.py
examples/forms/app.py
+2
-1
app.py
examples/sqla/app.py
+1
-1
__init__.py
flask_admin/__init__.py
+1
-1
base.py
flask_admin/base.py
+13
-9
view.py
flask_admin/contrib/mongoengine/view.py
+2
-2
view.py
flask_admin/contrib/peewee/view.py
+3
-2
view.py
flask_admin/contrib/pymongo/view.py
+2
-2
filters.py
flask_admin/contrib/sqla/filters.py
+39
-35
view.py
flask_admin/contrib/sqla/view.py
+210
-113
upload.py
flask_admin/form/upload.py
+25
-2
base.py
flask_admin/model/base.py
+48
-15
layout.html
flask_admin/templates/bootstrap2/admin/layout.html
+2
-2
lib.html
flask_admin/templates/bootstrap2/admin/lib.html
+25
-0
list.html
flask_admin/templates/bootstrap2/admin/model/list.html
+11
-1
layout.html
flask_admin/templates/bootstrap3/admin/layout.html
+2
-2
lib.html
flask_admin/templates/bootstrap3/admin/lib.html
+23
-0
list.html
flask_admin/templates/bootstrap3/admin/model/list.html
+11
-1
test_basic.py
flask_admin/tests/mongoengine/test_basic.py
+17
-0
test_basic.py
flask_admin/tests/sqla/test_basic.py
+135
-7
test_base.py
flask_admin/tests/test_base.py
+10
-1
test_form_upload.py
flask_admin/tests/test_form_upload.py
+22
-0
No files found.
doc/changelog.rst
View file @
207d23fd
Changelog
Changelog
=========
=========
1.2.0
-----
* Codebase was migrated to Flask-Admin GitHub organization
* Automatically inject Flask-WTF CSRF token to internal Flask-Admin forms
* MapBox v4 support for GeoAlchemy
* Updated translations with help of CrowdIn
* Show warning if field was ignored in form rendering rules
* Simple AppEngine backend
* Optional support for Font Awesome in templates and menus
* Bug fixes
1.1.0
1.1.0
-----
-----
...
@@ -43,21 +55,3 @@ Highlights:
...
@@ -43,21 +55,3 @@ Highlights:
* Support for newer wtforms versions
* Support for newer wtforms versions
* `form_rules` property that affects both create and edit forms
* `form_rules` property that affects both create and edit forms
* Lots of bugfixes
* Lots of bugfixes
1.0.7
-----
Full change log and feature walkthrough can be found `here <http://mrjoes.github.io/2013/10/21/flask-admin-107.html>`_.
Highlights:
* Python 3 support
* AJAX-based foreign-key data loading for all backends
* New, optional, rule-based form rendering engine
* MongoEngine fixes and features: GridFS support, nested subdocument configuration and much more
* Greatly improved and more configurable inline models
* New WTForms fields and widgets
* `form_extra_columns` allows adding custom columns to the form declaratively
* Redis cli
* SQLAlchemy backend can handle inherited models with multiple PKs
* Lots of bug fixes
examples/forms/app.py
View file @
207d23fd
...
@@ -101,7 +101,8 @@ class FileView(sqla.ModelView):
...
@@ -101,7 +101,8 @@ class FileView(sqla.ModelView):
form_args
=
{
form_args
=
{
'path'
:
{
'path'
:
{
'label'
:
'File'
,
'label'
:
'File'
,
'base_path'
:
file_path
'base_path'
:
file_path
,
'allow_overwrite'
:
False
}
}
}
}
...
...
examples/sqla/app.py
View file @
207d23fd
...
@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView):
...
@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView):
# List of columns that can be sorted. For 'user' column, use User.username as
# List of columns that can be sorted. For 'user' column, use User.username as
# a column.
# a column.
column_sortable_list
=
(
'title'
,
(
'user'
,
User
.
username
),
'date'
)
column_sortable_list
=
(
'title'
,
(
'user'
,
'user.username'
),
'date'
)
# Rename 'title' columns to 'Post Title' in list view
# Rename 'title' columns to 'Post Title' in list view
column_labels
=
dict
(
title
=
'Post Title'
)
column_labels
=
dict
(
title
=
'Post Title'
)
...
...
flask_admin/__init__.py
View file @
207d23fd
__version__
=
'1.
1.1-dev
'
__version__
=
'1.
2.0
'
__author__
=
'Serge S. Koval'
__author__
=
'Serge S. Koval'
__email__
=
'serge.koval+github@gmail.com'
__email__
=
'serge.koval+github@gmail.com'
...
...
flask_admin/base.py
View file @
207d23fd
...
@@ -188,7 +188,7 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
...
@@ -188,7 +188,7 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
"""
"""
self
.
name
=
name
self
.
name
=
name
self
.
category
=
category
self
.
category
=
category
self
.
endpoint
=
endpoint
self
.
endpoint
=
self
.
_get_endpoint
(
endpoint
)
self
.
url
=
url
self
.
url
=
url
self
.
static_folder
=
static_folder
self
.
static_folder
=
static_folder
self
.
static_url_path
=
static_url_path
self
.
static_url_path
=
static_url_path
...
@@ -206,6 +206,16 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
...
@@ -206,6 +206,16 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
if
self
.
_default_view
is
None
:
if
self
.
_default_view
is
None
:
raise
Exception
(
u'Attempted to instantiate admin view
%
s without default view'
%
self
.
__class__
.
__name__
)
raise
Exception
(
u'Attempted to instantiate admin view
%
s without default view'
%
self
.
__class__
.
__name__
)
def
_get_endpoint
(
self
,
endpoint
):
"""
Generate Flask endpoint name. By default converts class name to lower case if endpoint is
not explicitly provided.
"""
if
endpoint
:
return
endpoint
return
self
.
__class__
.
__name__
.
lower
()
def
create_blueprint
(
self
,
admin
):
def
create_blueprint
(
self
,
admin
):
"""
"""
Create Flask blueprint.
Create Flask blueprint.
...
@@ -213,10 +223,6 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
...
@@ -213,10 +223,6 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
# Store admin instance
# Store admin instance
self
.
admin
=
admin
self
.
admin
=
admin
# If endpoint name is not provided, get it from the class name
if
self
.
endpoint
is
None
:
self
.
endpoint
=
self
.
__class__
.
__name__
.
lower
()
# If the static_url_path is not provided, use the admin's
# If the static_url_path is not provided, use the admin's
if
not
self
.
static_url_path
:
if
not
self
.
static_url_path
:
self
.
static_url_path
=
admin
.
static_url_path
self
.
static_url_path
=
admin
.
static_url_path
...
@@ -234,15 +240,13 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
...
@@ -234,15 +240,13 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
if
not
self
.
url
.
startswith
(
'/'
):
if
not
self
.
url
.
startswith
(
'/'
):
self
.
url
=
'
%
s/
%
s'
%
(
self
.
admin
.
url
,
self
.
url
)
self
.
url
=
'
%
s/
%
s'
%
(
self
.
admin
.
url
,
self
.
url
)
# If we're working from the root of the site, set prefix to None
# If we're working from the root of the site, set prefix to None
if
self
.
url
==
'/'
:
if
self
.
url
==
'/'
:
self
.
url
=
None
self
.
url
=
None
# prevent admin static files from conflicting with flask static files
# prevent admin static files from conflicting with flask static files
if
not
self
.
static_url_path
:
if
not
self
.
static_url_path
:
self
.
static_folder
=
'static'
self
.
static_folder
=
'static'
self
.
static_url_path
=
'/static/admin'
self
.
static_url_path
=
'/static/admin'
# If name is not povided, use capitalized endpoint name
# If name is not povided, use capitalized endpoint name
if
self
.
name
is
None
:
if
self
.
name
is
None
:
...
...
flask_admin/contrib/mongoengine/view.py
View file @
207d23fd
...
@@ -484,7 +484,7 @@ class ModelView(BaseModelView):
...
@@ -484,7 +484,7 @@ class ModelView(BaseModelView):
query
=
self
.
_search
(
query
,
search
)
query
=
self
.
_search
(
query
,
search
)
# Get count
# Get count
count
=
query
.
count
()
count
=
query
.
count
()
if
not
self
.
simple_list_pager
else
None
# Sorting
# Sorting
if
sort_column
:
if
sort_column
:
...
@@ -592,7 +592,7 @@ class ModelView(BaseModelView):
...
@@ -592,7 +592,7 @@ class ModelView(BaseModelView):
return
False
return
False
else
:
else
:
self
.
after_model_delete
(
model
)
self
.
after_model_delete
(
model
)
return
True
return
True
...
...
flask_admin/contrib/peewee/view.py
View file @
207d23fd
...
@@ -339,7 +339,7 @@ class ModelView(BaseModelView):
...
@@ -339,7 +339,7 @@ class ModelView(BaseModelView):
query
=
f
.
apply
(
query
,
f
.
clean
(
value
))
query
=
f
.
apply
(
query
,
f
.
clean
(
value
))
# Get count
# Get count
count
=
query
.
count
()
count
=
query
.
count
()
if
not
self
.
simple_list_pager
else
None
# Apply sorting
# Apply sorting
if
sort_column
is
not
None
:
if
sort_column
is
not
None
:
...
@@ -417,7 +417,7 @@ class ModelView(BaseModelView):
...
@@ -417,7 +417,7 @@ class ModelView(BaseModelView):
return
False
return
False
else
:
else
:
self
.
after_model_delete
(
model
)
self
.
after_model_delete
(
model
)
return
True
return
True
# Default model actions
# Default model actions
...
@@ -443,6 +443,7 @@ class ModelView(BaseModelView):
...
@@ -443,6 +443,7 @@ class ModelView(BaseModelView):
query
=
self
.
model
.
select
()
.
filter
(
model_pk
<<
ids
)
query
=
self
.
model
.
select
()
.
filter
(
model_pk
<<
ids
)
for
m
in
query
:
for
m
in
query
:
self
.
on_model_delete
(
m
)
m
.
delete_instance
(
recursive
=
True
)
m
.
delete_instance
(
recursive
=
True
)
count
+=
1
count
+=
1
...
...
flask_admin/contrib/pymongo/view.py
View file @
207d23fd
...
@@ -222,7 +222,7 @@ class ModelView(BaseModelView):
...
@@ -222,7 +222,7 @@ class ModelView(BaseModelView):
query
=
self
.
_search
(
query
,
search
)
query
=
self
.
_search
(
query
,
search
)
# Get count
# Get count
count
=
self
.
coll
.
find
(
query
)
.
count
()
count
=
self
.
coll
.
find
(
query
)
.
count
()
if
not
self
.
simple_list_pager
else
None
# Sorting
# Sorting
sort_by
=
None
sort_by
=
None
...
@@ -337,7 +337,7 @@ class ModelView(BaseModelView):
...
@@ -337,7 +337,7 @@ class ModelView(BaseModelView):
return
False
return
False
else
:
else
:
self
.
after_model_delete
(
model
)
self
.
after_model_delete
(
model
)
return
True
return
True
# Default model actions
# Default model actions
...
...
flask_admin/contrib/sqla/filters.py
View file @
207d23fd
import
warnings
import
time
import
datetime
from
flask_admin.babel
import
lazy_gettext
from
flask_admin.babel
import
lazy_gettext
from
flask_admin.model
import
filters
from
flask_admin.model
import
filters
from
flask_admin.contrib.sqla
import
tools
from
flask_admin.contrib.sqla
import
tools
from
sqlalchemy.sql
import
not_
,
or_
from
sqlalchemy.sql
import
not_
,
or_
class
BaseSQLAFilter
(
filters
.
BaseFilter
):
class
BaseSQLAFilter
(
filters
.
BaseFilter
):
"""
"""
Base SQLAlchemy filter.
Base SQLAlchemy filter.
...
@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter):
...
@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter):
self
.
column
=
column
self
.
column
=
column
def
get_column
(
self
,
alias
):
return
self
.
column
if
alias
is
None
else
getattr
(
alias
,
self
.
column
.
key
)
def
apply
(
self
,
query
,
value
,
alias
=
None
):
return
super
(
self
,
BaseSQLAFilter
)
.
apply
(
query
,
value
)
# Common filters
# Common filters
class
FilterEqual
(
BaseSQLAFilter
):
class
FilterEqual
(
BaseSQLAFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
return
query
.
filter
(
self
.
column
==
value
)
return
query
.
filter
(
self
.
get_column
(
alias
)
==
value
)
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'equals'
)
return
lazy_gettext
(
'equals'
)
class
FilterNotEqual
(
BaseSQLAFilter
):
class
FilterNotEqual
(
BaseSQLAFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
return
query
.
filter
(
self
.
column
!=
value
)
return
query
.
filter
(
self
.
get_column
(
alias
)
!=
value
)
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'not equal'
)
return
lazy_gettext
(
'not equal'
)
class
FilterLike
(
BaseSQLAFilter
):
class
FilterLike
(
BaseSQLAFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
stmt
=
tools
.
parse_like_term
(
value
)
stmt
=
tools
.
parse_like_term
(
value
)
return
query
.
filter
(
self
.
column
.
ilike
(
stmt
))
return
query
.
filter
(
self
.
get_column
(
alias
)
.
ilike
(
stmt
))
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'contains'
)
return
lazy_gettext
(
'contains'
)
class
FilterNotLike
(
BaseSQLAFilter
):
class
FilterNotLike
(
BaseSQLAFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
stmt
=
tools
.
parse_like_term
(
value
)
stmt
=
tools
.
parse_like_term
(
value
)
return
query
.
filter
(
~
self
.
column
.
ilike
(
stmt
))
return
query
.
filter
(
~
self
.
get_column
(
alias
)
.
ilike
(
stmt
))
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'not contains'
)
return
lazy_gettext
(
'not contains'
)
class
FilterGreater
(
BaseSQLAFilter
):
class
FilterGreater
(
BaseSQLAFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
return
query
.
filter
(
self
.
column
>
value
)
return
query
.
filter
(
self
.
get_column
(
alias
)
>
value
)
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'greater than'
)
return
lazy_gettext
(
'greater than'
)
class
FilterSmaller
(
BaseSQLAFilter
):
class
FilterSmaller
(
BaseSQLAFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
return
query
.
filter
(
self
.
column
<
value
)
return
query
.
filter
(
self
.
get_column
(
alias
)
<
value
)
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'smaller than'
)
return
lazy_gettext
(
'smaller than'
)
class
FilterEmpty
(
BaseSQLAFilter
,
filters
.
BaseBooleanFilter
):
class
FilterEmpty
(
BaseSQLAFilter
,
filters
.
BaseBooleanFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
if
value
==
'1'
:
if
value
==
'1'
:
return
query
.
filter
(
self
.
column
==
None
)
return
query
.
filter
(
self
.
get_column
(
alias
)
==
None
)
else
:
else
:
return
query
.
filter
(
self
.
column
!=
None
)
return
query
.
filter
(
self
.
get_column
(
alias
)
!=
None
)
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'empty'
)
return
lazy_gettext
(
'empty'
)
...
@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter):
...
@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter):
def
clean
(
self
,
value
):
def
clean
(
self
,
value
):
return
[
v
.
strip
()
for
v
in
value
.
split
(
','
)
if
v
.
strip
()]
return
[
v
.
strip
()
for
v
in
value
.
split
(
','
)
if
v
.
strip
()]
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
return
query
.
filter
(
self
.
column
.
in_
(
value
))
return
query
.
filter
(
self
.
get_column
(
alias
)
.
in_
(
value
))
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'in list'
)
return
lazy_gettext
(
'in list'
)
class
FilterNotInList
(
FilterInList
):
class
FilterNotInList
(
FilterInList
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
# NOT IN can exclude NULL values, so "or_ == None" needed to be added
# NOT IN can exclude NULL values, so "or_ == None" needed to be added
return
query
.
filter
(
or_
(
~
self
.
column
.
in_
(
value
),
self
.
column
==
None
))
column
=
self
.
get_column
(
alias
)
return
query
.
filter
(
or_
(
~
column
.
in_
(
value
),
column
==
None
))
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'not in list'
)
return
lazy_gettext
(
'not in list'
)
...
@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
...
@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
options
,
options
,
data_type
=
'daterangepicker'
)
data_type
=
'daterangepicker'
)
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
start
,
end
=
value
start
,
end
=
value
return
query
.
filter
(
self
.
column
.
between
(
start
,
end
))
return
query
.
filter
(
self
.
get_column
(
alias
)
.
between
(
start
,
end
))
class
DateNotBetweenFilter
(
DateBetweenFilter
):
class
DateNotBetweenFilter
(
DateBetweenFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
start
,
end
=
value
start
,
end
=
value
# ~between() isn't possible until sqlalchemy 1.0.0
# ~between() isn't possible until sqlalchemy 1.0.0
return
query
.
filter
(
not_
(
self
.
column
.
between
(
start
,
end
)))
return
query
.
filter
(
not_
(
self
.
get_column
(
alias
)
.
between
(
start
,
end
)))
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'not between'
)
return
lazy_gettext
(
'not between'
)
...
@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
...
@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
options
,
options
,
data_type
=
'datetimerangepicker'
)
data_type
=
'datetimerangepicker'
)
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
start
,
end
=
value
start
,
end
=
value
return
query
.
filter
(
self
.
column
.
between
(
start
,
end
))
return
query
.
filter
(
self
.
get_column
(
alias
)
.
between
(
start
,
end
))
class
DateTimeNotBetweenFilter
(
DateTimeBetweenFilter
):
class
DateTimeNotBetweenFilter
(
DateTimeBetweenFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
start
,
end
=
value
start
,
end
=
value
return
query
.
filter
(
not_
(
self
.
column
.
between
(
start
,
end
)))
return
query
.
filter
(
not_
(
self
.
get_column
(
alias
)
.
between
(
start
,
end
)))
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'not between'
)
return
lazy_gettext
(
'not between'
)
...
@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
...
@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
options
,
options
,
data_type
=
'timerangepicker'
)
data_type
=
'timerangepicker'
)
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
start
,
end
=
value
start
,
end
=
value
return
query
.
filter
(
self
.
column
.
between
(
start
,
end
))
return
query
.
filter
(
self
.
get_column
(
alias
)
.
between
(
start
,
end
))
class
TimeNotBetweenFilter
(
TimeBetweenFilter
):
class
TimeNotBetweenFilter
(
TimeBetweenFilter
):
def
apply
(
self
,
query
,
value
):
def
apply
(
self
,
query
,
value
,
alias
=
None
):
start
,
end
=
value
start
,
end
=
value
return
query
.
filter
(
not_
(
self
.
column
.
between
(
start
,
end
)))
return
query
.
filter
(
not_
(
self
.
get_column
(
alias
)
.
between
(
start
,
end
)))
def
operation
(
self
):
def
operation
(
self
):
return
lazy_gettext
(
'not between'
)
return
lazy_gettext
(
'not between'
)
...
...
flask_admin/contrib/sqla/view.py
View file @
207d23fd
import
logging
import
logging
import
warnings
import
warnings
import
inspect
from
sqlalchemy.orm.attributes
import
InstrumentedAttribute
from
sqlalchemy.orm.attributes
import
InstrumentedAttribute
from
sqlalchemy.orm
import
joinedload
from
sqlalchemy.orm
import
joinedload
,
aliased
from
sqlalchemy.sql.expression
import
desc
from
sqlalchemy.sql.expression
import
desc
from
sqlalchemy
import
Column
,
Boolean
,
func
,
or_
from
sqlalchemy
import
Boolean
,
Table
,
func
,
or_
from
sqlalchemy.exc
import
IntegrityError
from
sqlalchemy.exc
import
IntegrityError
from
flask
import
flash
from
flask
import
flash
...
@@ -276,7 +277,6 @@ class ModelView(BaseModelView):
...
@@ -276,7 +277,6 @@ class ModelView(BaseModelView):
self
.
session
=
session
self
.
session
=
session
self
.
_search_fields
=
None
self
.
_search_fields
=
None
self
.
_search_joins
=
[]
self
.
_filter_joins
=
dict
()
self
.
_filter_joins
=
dict
()
...
@@ -322,43 +322,92 @@ class ModelView(BaseModelView):
...
@@ -322,43 +322,92 @@ class ModelView(BaseModelView):
return
field
.
property
.
columns
return
field
.
property
.
columns
def
_get_field_with_path
(
self
,
name
):
def
_get_field_with_path
(
self
,
name
):
join_tables
=
[]
"""
Resolve property by name and figure out its join path.
if
isinstance
(
name
,
string_types
):
Join path might contain both properties and tables.
model
=
self
.
model
"""
path
=
[]
model
=
self
.
model
# For strings, resolve path
if
isinstance
(
name
,
string_types
):
for
attribute
in
name
.
split
(
'.'
):
for
attribute
in
name
.
split
(
'.'
):
value
=
getattr
(
model
,
attribute
)
value
=
getattr
(
model
,
attribute
)
if
(
hasattr
(
value
,
'property'
)
and
if
(
hasattr
(
value
,
'property'
)
and
hasattr
(
value
.
property
,
'direction'
)):
hasattr
(
value
.
property
,
'direction'
)):
model
=
value
.
property
.
mapper
.
class_
model
=
value
.
property
.
mapper
.
class_
table
=
model
.
__table__
table
=
model
.
__table__
if
self
.
_need_join
(
table
):
if
self
.
_need_join
(
table
):
join_tables
.
append
(
tabl
e
)
path
.
append
(
valu
e
)
attr
=
value
attr
=
value
else
:
else
:
attr
=
name
attr
=
name
#
determine joins if Table.column (relation object) is given
#
Determine joins if table.column (relation object) is provided
if
isinstance
(
name
,
InstrumentedAttribute
):
if
isinstance
(
attr
,
InstrumentedAttribute
):
columns
=
self
.
_get_columns_for_field
(
name
)
columns
=
self
.
_get_columns_for_field
(
attr
)
if
len
(
columns
)
>
1
:
if
len
(
columns
)
>
1
:
raise
Exception
(
'Can only handle one column for
%
s'
%
name
)
raise
Exception
(
'Can only handle one column for
%
s'
%
name
)
column
=
columns
[
0
]
column
=
columns
[
0
]
# TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
if
self
.
_need_join
(
column
.
table
):
if
self
.
_need_join
(
column
.
table
):
join_tables
.
append
(
column
.
table
)
path
.
append
(
column
.
table
)
return
join_tables
,
attr
return
attr
,
path
def
_need_join
(
self
,
table
):
def
_need_join
(
self
,
table
):
"""
Check if join to a table is necessary.
"""
return
table
not
in
self
.
model
.
_sa_class_manager
.
mapper
.
tables
return
table
not
in
self
.
model
.
_sa_class_manager
.
mapper
.
tables
def
_apply_path_joins
(
self
,
query
,
joins
,
path
,
inner_join
=
True
):
"""
Apply join path to the query.
:param query:
Query to add joins to
:param joins:
List of current joins. Used to avoid joining on same relationship more than once
:param path:
Path to be joined
:param fn:
Join function
"""
last
=
None
if
path
:
for
item
in
path
:
key
=
(
inner_join
,
item
)
alias
=
joins
.
get
(
key
)
if
key
not
in
joins
:
if
not
isinstance
(
item
,
Table
):
alias
=
aliased
(
item
.
property
.
mapper
.
class_
)
fn
=
query
.
join
if
inner_join
else
query
.
outerjoin
if
last
is
None
:
query
=
fn
(
item
)
if
alias
is
None
else
fn
(
alias
,
item
)
else
:
prop
=
getattr
(
last
,
item
.
key
)
query
=
fn
(
prop
)
if
alias
is
None
else
fn
(
alias
,
prop
)
joins
[
key
]
=
alias
last
=
alias
return
query
,
joins
,
last
# Scaffolding
# Scaffolding
def
scaffold_pk
(
self
):
def
scaffold_pk
(
self
):
"""
"""
...
@@ -453,19 +502,19 @@ class ModelView(BaseModelView):
...
@@ -453,19 +502,19 @@ class ModelView(BaseModelView):
for
c
in
self
.
column_sortable_list
:
for
c
in
self
.
column_sortable_list
:
if
isinstance
(
c
,
tuple
):
if
isinstance
(
c
,
tuple
):
join_tables
,
column
=
self
.
_get_field_with_path
(
c
[
1
])
column
,
path
=
self
.
_get_field_with_path
(
c
[
1
])
column_name
=
c
[
0
]
column_name
=
c
[
0
]
elif
isinstance
(
c
,
InstrumentedAttribute
):
elif
isinstance
(
c
,
InstrumentedAttribute
):
join_tables
,
column
=
self
.
_get_field_with_path
(
c
)
column
,
path
=
self
.
_get_field_with_path
(
c
)
column_name
=
str
(
c
)
column_name
=
str
(
c
)
else
:
else
:
join_tables
,
column
=
self
.
_get_field_with_path
(
c
)
column
,
path
=
self
.
_get_field_with_path
(
c
)
column_name
=
c
column_name
=
c
result
[
column_name
]
=
column
result
[
column_name
]
=
column
if
join_tables
:
if
path
:
self
.
_sortable_joins
[
column_name
]
=
join_tables
self
.
_sortable_joins
[
column_name
]
=
path
return
result
return
result
...
@@ -479,26 +528,15 @@ class ModelView(BaseModelView):
...
@@ -479,26 +528,15 @@ class ModelView(BaseModelView):
"""
"""
if
self
.
column_searchable_list
:
if
self
.
column_searchable_list
:
self
.
_search_fields
=
[]
self
.
_search_fields
=
[]
self
.
_search_joins
=
[]
joins
=
set
()
for
p
in
self
.
column_searchable_list
:
for
p
in
self
.
column_searchable_list
:
join_tables
,
attr
=
self
.
_get_field_with_path
(
p
)
attr
,
joins
=
self
.
_get_field_with_path
(
p
)
if
not
attr
:
if
not
attr
:
raise
Exception
(
'Failed to find field for search field:
%
s'
%
p
)
raise
Exception
(
'Failed to find field for search field:
%
s'
%
p
)
for
column
in
self
.
_get_columns_for_field
(
attr
):
for
column
in
self
.
_get_columns_for_field
(
attr
):
column_type
=
type
(
column
.
type
)
.
__name__
self
.
_search_fields
.
append
((
column
,
joins
))
self
.
_search_fields
.
append
(
column
)
# Store joins, avoid duplicates
for
table
in
join_tables
:
if
table
.
name
not
in
joins
:
self
.
_search_joins
.
append
(
table
)
joins
.
add
(
table
.
name
)
return
bool
(
self
.
column_searchable_list
)
return
bool
(
self
.
column_searchable_list
)
...
@@ -507,7 +545,7 @@ class ModelView(BaseModelView):
...
@@ -507,7 +545,7 @@ class ModelView(BaseModelView):
Return list of enabled filters
Return list of enabled filters
"""
"""
join_tables
,
attr
=
self
.
_get_field_with_path
(
name
)
attr
,
joins
=
self
.
_get_field_with_path
(
name
)
if
attr
is
None
:
if
attr
is
None
:
raise
Exception
(
'Failed to find field for filter:
%
s'
%
name
)
raise
Exception
(
'Failed to find field for filter:
%
s'
%
name
)
...
@@ -535,10 +573,11 @@ class ModelView(BaseModelView):
...
@@ -535,10 +573,11 @@ class ModelView(BaseModelView):
if
flt
:
if
flt
:
table
=
column
.
table
table
=
column
.
table
if
join
_table
s
:
if
joins
:
self
.
_filter_joins
[
table
.
name
]
=
join_table
s
self
.
_filter_joins
[
column
]
=
join
s
elif
self
.
_need_join
(
table
):
elif
self
.
_need_join
(
table
):
self
.
_filter_joins
[
table
.
name
]
=
[
table
]
self
.
_filter_joins
[
column
]
=
[
table
]
filters
.
extend
(
flt
)
filters
.
extend
(
flt
)
return
filters
return
filters
...
@@ -563,9 +602,6 @@ class ModelView(BaseModelView):
...
@@ -563,9 +602,6 @@ class ModelView(BaseModelView):
type_name
=
type
(
column
.
type
)
.
__name__
type_name
=
type
(
column
.
type
)
.
__name__
if
join_tables
:
self
.
_filter_joins
[
column
.
table
.
name
]
=
join_tables
flt
=
self
.
filter_converter
.
convert
(
flt
=
self
.
filter_converter
.
convert
(
type_name
,
type_name
,
column
,
column
,
...
@@ -573,8 +609,10 @@ class ModelView(BaseModelView):
...
@@ -573,8 +609,10 @@ class ModelView(BaseModelView):
options
=
self
.
column_choices
.
get
(
name
),
options
=
self
.
column_choices
.
get
(
name
),
)
)
if
flt
and
not
join_tables
and
self
.
_need_join
(
column
.
table
):
if
joins
:
self
.
_filter_joins
[
column
.
table
.
name
]
=
[
column
.
table
]
self
.
_filter_joins
[
column
]
=
joins
elif
self
.
_need_join
(
column
.
table
):
self
.
_filter_joins
[
column
]
=
[
column
.
table
]
return
flt
return
flt
...
@@ -583,7 +621,7 @@ class ModelView(BaseModelView):
...
@@ -583,7 +621,7 @@ class ModelView(BaseModelView):
column
=
filter
.
column
column
=
filter
.
column
if
self
.
_need_join
(
column
.
table
):
if
self
.
_need_join
(
column
.
table
):
self
.
_filter_joins
[
column
.
table
.
name
]
=
[
column
.
table
]
self
.
_filter_joins
[
column
]
=
[
column
.
table
]
return
filter
return
filter
...
@@ -707,27 +745,25 @@ class ModelView(BaseModelView):
...
@@ -707,27 +745,25 @@ class ModelView(BaseModelView):
:param query:
:param query:
Query
Query
:param joins:
:pram joins:
Joins set
Current joins
:param sort_joins:
Sort joins (properties or tables)
:param sort_field:
:param sort_field:
Sort field
Sort field
:param sort_desc:
:param sort_desc:
Ascending or descending
Ascending or descending
"""
"""
# TODO: Preprocessing for joins
if
sort_field
is
not
None
:
# Handle joins
# Handle joins
if
sort_joins
:
query
,
joins
,
alias
=
self
.
_apply_path_joins
(
query
,
joins
,
sort_joins
,
inner_join
=
False
)
for
table
in
sort_joins
:
if
table
.
name
not
in
joins
:
query
=
query
.
outerjoin
(
table
)
joins
.
add
(
table
.
name
)
column
=
sort_field
if
alias
is
None
else
getattr
(
alias
,
sort_field
.
key
)
if
sort_field
is
not
None
:
if
sort_desc
:
if
sort_desc
:
query
=
query
.
order_by
(
desc
(
sort_field
))
query
=
query
.
order_by
(
desc
(
column
))
else
:
else
:
query
=
query
.
order_by
(
sort_field
)
query
=
query
.
order_by
(
column
)
return
query
,
joins
return
query
,
joins
...
@@ -737,12 +773,112 @@ class ModelView(BaseModelView):
...
@@ -737,12 +773,112 @@ class ModelView(BaseModelView):
if
order
is
not
None
:
if
order
is
not
None
:
field
,
direction
=
order
field
,
direction
=
order
join_tables
,
attr
=
self
.
_get_field_with_path
(
field
)
attr
,
joins
=
self
.
_get_field_with_path
(
field
)
return
join_tables
,
attr
,
direction
return
attr
,
joins
,
direction
return
None
return
None
def
_apply_sorting
(
self
,
query
,
joins
,
sort_column
,
sort_desc
):
if
sort_column
is
not
None
:
if
sort_column
in
self
.
_sortable_columns
:
sort_field
=
self
.
_sortable_columns
[
sort_column
]
sort_joins
=
self
.
_sortable_joins
.
get
(
sort_column
)
query
,
joins
=
self
.
_order_by
(
query
,
joins
,
sort_joins
,
sort_field
,
sort_desc
)
else
:
order
=
self
.
_get_default_order
()
if
order
:
sort_field
,
sort_joins
,
sort_desc
=
order
query
,
joins
=
self
.
_order_by
(
query
,
joins
,
sort_joins
,
sort_field
,
sort_desc
)
return
query
,
joins
def
_apply_search
(
self
,
query
,
count_query
,
joins
,
count_joins
,
search
):
"""
Apply search to a query.
"""
terms
=
search
.
split
(
' '
)
for
term
in
terms
:
if
not
term
:
continue
stmt
=
tools
.
parse_like_term
(
term
)
filter_stmt
=
[]
count_filter_stmt
=
[]
for
field
,
path
in
self
.
_search_fields
:
query
,
joins
,
alias
=
self
.
_apply_path_joins
(
query
,
joins
,
path
,
inner_join
=
False
)
count_alias
=
None
if
count_query
is
not
None
:
count_query
,
count_joins
,
count_alias
=
self
.
_apply_path_joins
(
count_query
,
count_joins
,
path
,
inner_join
=
False
)
column
=
field
if
alias
is
None
else
getattr
(
alias
,
field
.
key
)
filter_stmt
.
append
(
column
.
ilike
(
stmt
))
if
count_filter_stmt
is
not
None
:
column
=
field
if
count_alias
is
None
else
getattr
(
count_alias
,
field
.
key
)
count_filter_stmt
.
append
(
column
.
ilike
(
stmt
))
query
=
query
.
filter
(
or_
(
*
filter_stmt
))
if
count_query
is
not
None
:
count_query
=
count_query
.
filter
(
or_
(
*
count_filter_stmt
))
return
query
,
count_query
,
joins
,
count_joins
def
_apply_filters
(
self
,
query
,
count_query
,
joins
,
count_joins
,
filters
):
for
idx
,
flt_name
,
value
in
filters
:
flt
=
self
.
_filters
[
idx
]
alias
=
None
count_alias
=
None
# Figure out joins
if
isinstance
(
flt
,
sqla_filters
.
BaseSQLAFilter
):
path
=
self
.
_filter_joins
.
get
(
flt
.
column
,
[])
query
,
joins
,
alias
=
self
.
_apply_path_joins
(
query
,
joins
,
path
,
inner_join
=
False
)
if
count_query
is
not
None
:
count_query
,
count_joins
,
count_alias
=
self
.
_apply_path_joins
(
count_query
,
count_joins
,
path
,
inner_join
=
False
)
# Clean value .clean() and apply the filter
clean_value
=
flt
.
clean
(
value
)
try
:
query
=
flt
.
apply
(
query
,
clean_value
,
alias
)
except
TypeError
:
spec
=
inspect
.
getargspec
(
flt
.
apply
)
if
len
(
spec
.
args
)
==
2
:
warnings
.
warn
(
'Please update your custom filter
%
s to include additional `alias` parameter.'
%
repr
(
flt
))
else
:
raise
query
=
flt
.
apply
(
query
,
clean_value
)
if
count_query
is
not
None
:
try
:
count_query
=
flt
.
apply
(
count_query
,
clean_value
,
count_alias
)
except
TypeError
:
count_query
=
flt
.
apply
(
count_query
,
clean_value
)
return
query
,
count_query
,
joins
,
count_joins
def
get_list
(
self
,
page
,
sort_column
,
sort_desc
,
search
,
filters
,
execute
=
True
):
def
get_list
(
self
,
page
,
sort_column
,
sort_desc
,
search
,
filters
,
execute
=
True
):
"""
"""
Return models from the database.
Return models from the database.
...
@@ -761,84 +897,45 @@ class ModelView(BaseModelView):
...
@@ -761,84 +897,45 @@ class ModelView(BaseModelView):
List of filter tuples
List of filter tuples
"""
"""
# Will contain names of joined tables to avoid duplicate joins
# Will contain join paths with optional aliased object
joins
=
set
()
joins
=
{}
count_joins
=
{}
query
=
self
.
get_query
()
query
=
self
.
get_query
()
count_query
=
self
.
get_count_query
()
count_query
=
self
.
get_count_query
()
if
not
self
.
simple_list_pager
else
None
# Ignore eager-loaded relations (prevent unnecessary joins)
# Ignore eager-loaded relations (prevent unnecessary joins)
# TODO: Separate join detection for query and count query?
# TODO: Separate join detection for query and count query?
if
hasattr
(
query
,
'_join_entities'
):
if
hasattr
(
query
,
'_join_entities'
):
for
entity
in
query
.
_join_entities
:
for
entity
in
query
.
_join_entities
:
for
table
in
entity
.
tables
:
for
table
in
entity
.
tables
:
joins
.
add
(
table
.
name
)
joins
[
table
]
=
None
# Apply search criteria
# Apply search criteria
if
self
.
_search_supported
and
search
:
if
self
.
_search_supported
and
search
:
# Apply search-related joins
query
,
count_query
,
joins
,
count_joins
=
self
.
_apply_search
(
query
,
if
self
.
_search_joins
:
count_query
,
for
table
in
self
.
_search_joins
:
joins
,
if
table
.
name
not
in
joins
:
count_joins
,
query
=
query
.
outerjoin
(
table
)
search
)
count_query
=
count_query
.
outerjoin
(
table
)
joins
.
add
(
table
.
name
)
# Apply terms
terms
=
search
.
split
(
' '
)
for
term
in
terms
:
if
not
term
:
continue
stmt
=
tools
.
parse_like_term
(
term
)
filter_stmt
=
[
c
.
ilike
(
stmt
)
for
c
in
self
.
_search_fields
]
query
=
query
.
filter
(
or_
(
*
filter_stmt
))
count_query
=
count_query
.
filter
(
or_
(
*
filter_stmt
))
# Apply filters
# Apply filters
if
filters
and
self
.
_filters
:
if
filters
and
self
.
_filters
:
for
idx
,
flt_name
,
value
in
filters
:
query
,
count_query
,
joins
,
count_joins
=
self
.
_apply_filters
(
query
,
flt
=
self
.
_filters
[
idx
]
count_query
,
joins
,
count_joins
,
filters
)
# Figure out joins
# Calculate number of rows if necessary
if
isinstance
(
flt
,
sqla_filters
.
BaseSQLAFilter
):
count
=
count_query
.
scalar
()
if
count_query
else
None
tbl
=
flt
.
column
.
table
.
name
join_tables
=
self
.
_filter_joins
.
get
(
tbl
,
[])
for
table
in
join_tables
:
if
table
.
name
not
in
joins
:
query
=
query
.
join
(
table
)
count_query
=
count_query
.
join
(
table
)
joins
.
add
(
table
.
name
)
# turn into python format with .clean() and apply filter
query
=
flt
.
apply
(
query
,
flt
.
clean
(
value
))
count_query
=
flt
.
apply
(
count_query
,
flt
.
clean
(
value
))
# Calculate number of rows
count
=
count_query
.
scalar
()
# Auto join
# Auto join
for
j
in
self
.
_auto_joins
:
for
j
in
self
.
_auto_joins
:
query
=
query
.
options
(
joinedload
(
j
))
query
=
query
.
options
(
joinedload
(
j
))
# Sorting
# Sorting
if
sort_column
is
not
None
:
query
,
joins
=
self
.
_apply_sorting
(
query
,
joins
,
sort_column
,
sort_desc
)
if
sort_column
in
self
.
_sortable_columns
:
sort_field
=
self
.
_sortable_columns
[
sort_column
]
sort_joins
=
self
.
_sortable_joins
.
get
(
sort_column
)
query
,
joins
=
self
.
_order_by
(
query
,
joins
,
sort_joins
,
sort_field
,
sort_desc
)
else
:
order
=
self
.
_get_default_order
()
if
order
:
sort_joins
,
sort_field
,
sort_desc
=
order
query
,
joins
=
self
.
_order_by
(
query
,
joins
,
sort_joins
,
sort_field
,
sort_desc
)
# Pagination
# Pagination
if
page
is
not
None
:
if
page
is
not
None
:
...
@@ -944,7 +1041,7 @@ class ModelView(BaseModelView):
...
@@ -944,7 +1041,7 @@ class ModelView(BaseModelView):
return
False
return
False
else
:
else
:
self
.
after_model_delete
(
model
)
self
.
after_model_delete
(
model
)
return
True
return
True
# Default model actions
# Default model actions
...
...
flask_admin/form/upload.py
View file @
207d23fd
...
@@ -51,12 +51,21 @@ class FileUploadInput(object):
...
@@ -51,12 +51,21 @@ class FileUploadInput(object):
template
=
self
.
data_template
if
field
.
data
else
self
.
empty_template
template
=
self
.
data_template
if
field
.
data
else
self
.
empty_template
if
field
.
errors
:
template
=
self
.
empty_template
if
field
.
data
and
isinstance
(
field
.
data
,
FileStorage
):
value
=
field
.
data
.
filename
else
:
value
=
field
.
data
or
''
return
HTMLString
(
template
%
{
return
HTMLString
(
template
%
{
'text'
:
html_params
(
type
=
'text'
,
'text'
:
html_params
(
type
=
'text'
,
readonly
=
'readonly'
,
readonly
=
'readonly'
,
value
=
field
.
data
,
value
=
value
,
name
=
field
.
name
),
name
=
field
.
name
),
'file'
:
html_params
(
type
=
'file'
,
'file'
:
html_params
(
type
=
'file'
,
value
=
value
,
**
kwargs
),
**
kwargs
),
'marker'
:
'_
%
s-delete'
%
field
.
name
'marker'
:
'_
%
s-delete'
%
field
.
name
})
})
...
@@ -122,7 +131,7 @@ class FileUploadField(fields.StringField):
...
@@ -122,7 +131,7 @@ class FileUploadField(fields.StringField):
def
__init__
(
self
,
label
=
None
,
validators
=
None
,
def
__init__
(
self
,
label
=
None
,
validators
=
None
,
base_path
=
None
,
relative_path
=
None
,
base_path
=
None
,
relative_path
=
None
,
namegen
=
None
,
allowed_extensions
=
None
,
namegen
=
None
,
allowed_extensions
=
None
,
permission
=
0o666
,
permission
=
0o666
,
allow_overwrite
=
True
,
**
kwargs
):
**
kwargs
):
"""
"""
Constructor.
Constructor.
...
@@ -154,6 +163,11 @@ class FileUploadField(fields.StringField):
...
@@ -154,6 +163,11 @@ class FileUploadField(fields.StringField):
:param allowed_extensions:
:param allowed_extensions:
List of allowed extensions. If not provided, will allow any file.
List of allowed extensions. If not provided, will allow any file.
:param allow_overwrite:
Whether to overwrite existing files in upload directory. Defaults to `True`.
.. versionadded:: 1.1.1
The `allow_overwrite` parameter was added.
"""
"""
self
.
base_path
=
base_path
self
.
base_path
=
base_path
self
.
relative_path
=
relative_path
self
.
relative_path
=
relative_path
...
@@ -161,6 +175,7 @@ class FileUploadField(fields.StringField):
...
@@ -161,6 +175,7 @@ class FileUploadField(fields.StringField):
self
.
namegen
=
namegen
or
namegen_filename
self
.
namegen
=
namegen
or
namegen_filename
self
.
allowed_extensions
=
allowed_extensions
self
.
allowed_extensions
=
allowed_extensions
self
.
permission
=
permission
self
.
permission
=
permission
self
.
_allow_overwrite
=
allow_overwrite
self
.
_should_delete
=
False
self
.
_should_delete
=
False
...
@@ -188,6 +203,11 @@ class FileUploadField(fields.StringField):
...
@@ -188,6 +203,11 @@ class FileUploadField(fields.StringField):
def
pre_validate
(
self
,
form
):
def
pre_validate
(
self
,
form
):
if
self
.
_is_uploaded_file
(
self
.
data
)
and
not
self
.
is_file_allowed
(
self
.
data
.
filename
):
if
self
.
_is_uploaded_file
(
self
.
data
)
and
not
self
.
is_file_allowed
(
self
.
data
.
filename
):
raise
ValidationError
(
gettext
(
'Invalid file extension'
))
raise
ValidationError
(
gettext
(
'Invalid file extension'
))
# Handle overwriting existing content
if
not
self
.
_is_uploaded_file
(
self
.
data
):
return
if
self
.
_allow_overwrite
==
False
and
os
.
path
.
exists
(
self
.
_get_path
(
self
.
data
.
filename
)):
raise
ValidationError
(
gettext
(
'File "
%
s" already exists.'
%
self
.
data
.
filename
))
def
process
(
self
,
formdata
,
data
=
unset_value
):
def
process
(
self
,
formdata
,
data
=
unset_value
):
if
formdata
:
if
formdata
:
...
@@ -253,6 +273,9 @@ class FileUploadField(fields.StringField):
...
@@ -253,6 +273,9 @@ class FileUploadField(fields.StringField):
if
not
op
.
exists
(
op
.
dirname
(
path
)):
if
not
op
.
exists
(
op
.
dirname
(
path
)):
os
.
makedirs
(
os
.
path
.
dirname
(
path
),
self
.
permission
|
0o111
)
os
.
makedirs
(
os
.
path
.
dirname
(
path
),
self
.
permission
|
0o111
)
if
self
.
_allow_overwrite
==
False
and
os
.
path
.
exists
(
path
):
raise
ValueError
(
gettext
(
'File "
%
s" already exists.'
%
path
))
data
.
save
(
path
)
data
.
save
(
path
)
return
filename
return
filename
...
...
flask_admin/model/base.py
View file @
207d23fd
...
@@ -15,7 +15,7 @@ from flask_admin.form import BaseForm, FormOpts, rules
...
@@ -15,7 +15,7 @@ from flask_admin.form import BaseForm, FormOpts, rules
from
flask_admin.model
import
filters
,
typefmt
from
flask_admin.model
import
filters
,
typefmt
from
flask_admin.actions
import
ActionsMixin
from
flask_admin.actions
import
ActionsMixin
from
flask_admin.helpers
import
(
get_form_data
,
validate_form_on_submit
,
from
flask_admin.helpers
import
(
get_form_data
,
validate_form_on_submit
,
get_redirect_target
,
flash_errors
)
get_redirect_target
,
flash_errors
)
from
flask_admin.tools
import
rec_getattr
from
flask_admin.tools
import
rec_getattr
from
flask_admin._backwards
import
ObsoleteAttr
from
flask_admin._backwards
import
ObsoleteAttr
from
flask_admin._compat
import
iteritems
,
OrderedDict
,
as_unicode
from
flask_admin._compat
import
iteritems
,
OrderedDict
,
as_unicode
...
@@ -321,6 +321,12 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -321,6 +321,12 @@ class BaseModelView(BaseView, ActionsMixin):
Controls if the primary key should be displayed in the list view.
Controls if the primary key should be displayed in the list view.
"""
"""
simple_list_pager
=
False
"""
Enable or disable simple list pager.
If enabled, model interface would not run count query and will only show prev/next pager buttons.
"""
form
=
None
form
=
None
"""
"""
Form class. Override if you want to use custom form for your model.
Form class. Override if you want to use custom form for your model.
...
@@ -563,28 +569,30 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -563,28 +569,30 @@ class BaseModelView(BaseView, ActionsMixin):
:param menu_icon_value:
:param menu_icon_value:
Icon glyph name or URL, depending on `menu_icon_type` setting
Icon glyph name or URL, depending on `menu_icon_type` setting
"""
"""
self
.
model
=
model
# If name not provided, it is model name
# If name not provided, it is model name
if
name
is
None
:
if
name
is
None
:
name
=
'
%
s'
%
self
.
_prettify_class_name
(
model
.
__name__
)
name
=
'
%
s'
%
self
.
_prettify_class_name
(
model
.
__name__
)
# If endpoint not provided, it is model name
if
endpoint
is
None
:
endpoint
=
model
.
__name__
.
lower
()
super
(
BaseModelView
,
self
)
.
__init__
(
name
,
category
,
endpoint
,
url
,
static_folder
,
super
(
BaseModelView
,
self
)
.
__init__
(
name
,
category
,
endpoint
,
url
,
static_folder
,
menu_class_name
=
menu_class_name
,
menu_class_name
=
menu_class_name
,
menu_icon_type
=
menu_icon_type
,
menu_icon_type
=
menu_icon_type
,
menu_icon_value
=
menu_icon_value
)
menu_icon_value
=
menu_icon_value
)
self
.
model
=
model
# Actions
# Actions
self
.
init_actions
()
self
.
init_actions
()
# Scaffolding
# Scaffolding
self
.
_refresh_cache
()
self
.
_refresh_cache
()
# Endpoint
def
_get_endpoint
(
self
,
endpoint
):
if
endpoint
:
return
super
(
BaseModelView
,
self
)
.
_get_endpoint
(
endpoint
)
return
self
.
model
.
__name__
.
lower
()
# Caching
# Caching
def
_refresh_forms_cache
(
self
):
def
_refresh_forms_cache
(
self
):
# Forms
# Forms
...
@@ -617,7 +625,7 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -617,7 +625,7 @@ class BaseModelView(BaseView, ActionsMixin):
self
.
_filter_groups
[
flt
.
name
]
.
append
({
self
.
_filter_groups
[
flt
.
name
]
.
append
({
'index'
:
i
,
'index'
:
i
,
'arg'
:
self
.
get_filter_arg
(
i
,
flt
),
'arg'
:
self
.
get_filter_arg
(
i
,
flt
),
'operation'
:
as_unicode
(
flt
.
operation
()
),
'operation'
:
flt
.
operation
(
),
'options'
:
flt
.
get_options
(
self
)
or
None
,
'options'
:
flt
.
get_options
(
self
)
or
None
,
'type'
:
flt
.
data_type
'type'
:
flt
.
data_type
})
})
...
@@ -852,6 +860,27 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -852,6 +860,27 @@ class BaseModelView(BaseView, ActionsMixin):
else
:
else
:
return
str
(
index
)
return
str
(
index
)
def
_get_filter_groups
(
self
):
"""
Returns non-lazy version of filter strings
"""
if
self
.
_filter_groups
:
results
=
OrderedDict
()
for
key
,
value
in
iteritems
(
self
.
_filter_groups
):
items
=
[]
for
item
in
value
:
copy
=
dict
(
item
)
copy
[
'operation'
]
=
as_unicode
(
copy
[
'operation'
])
items
.
append
(
copy
)
results
[
key
]
=
items
return
results
return
None
# Form helpers
# Form helpers
def
scaffold_form
(
self
):
def
scaffold_form
(
self
):
"""
"""
...
@@ -1018,7 +1047,7 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -1018,7 +1047,7 @@ class BaseModelView(BaseView, ActionsMixin):
missing_fields
.
append
(
field
.
name
)
missing_fields
.
append
(
field
.
name
)
return
missing_fields
return
missing_fields
def
_show_missing_fields_warning
(
self
,
text
):
def
_show_missing_fields_warning
(
self
,
text
):
warnings
.
warn
(
text
)
warnings
.
warn
(
text
)
...
@@ -1200,7 +1229,7 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -1200,7 +1229,7 @@ class BaseModelView(BaseView, ActionsMixin):
By default do nothing.
By default do nothing.
"""
"""
pass
pass
def
after_model_delete
(
self
,
model
):
def
after_model_delete
(
self
,
model
):
"""
"""
Perform some actions after a model was deleted and
Perform some actions after a model was deleted and
...
@@ -1214,7 +1243,7 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -1214,7 +1243,7 @@ class BaseModelView(BaseView, ActionsMixin):
:param model:
:param model:
Model that was deleted
Model that was deleted
"""
"""
pass
pass
def
on_form_prefill
(
self
,
form
,
id
):
def
on_form_prefill
(
self
,
form
,
id
):
"""
"""
...
@@ -1463,9 +1492,12 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -1463,9 +1492,12 @@ class BaseModelView(BaseView, ActionsMixin):
view_args
.
search
,
view_args
.
filters
)
view_args
.
search
,
view_args
.
filters
)
# Calculate number of pages
# Calculate number of pages
num_pages
=
count
//
self
.
page_size
if
count
is
not
None
:
if
count
%
self
.
page_size
!=
0
:
num_pages
=
count
//
self
.
page_size
num_pages
+=
1
if
count
%
self
.
page_size
!=
0
:
num_pages
+=
1
else
:
num_pages
=
None
# Various URL generation helpers
# Various URL generation helpers
def
pager_url
(
p
):
def
pager_url
(
p
):
...
@@ -1508,6 +1540,7 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -1508,6 +1540,7 @@ class BaseModelView(BaseView, ActionsMixin):
pager_url
=
pager_url
,
pager_url
=
pager_url
,
num_pages
=
num_pages
,
num_pages
=
num_pages
,
page
=
view_args
.
page
,
page
=
view_args
.
page
,
page_size
=
self
.
page_size
,
# Sorting
# Sorting
sort_column
=
view_args
.
sort
,
sort_column
=
view_args
.
sort
,
...
@@ -1521,7 +1554,7 @@ class BaseModelView(BaseView, ActionsMixin):
...
@@ -1521,7 +1554,7 @@ class BaseModelView(BaseView, ActionsMixin):
# Filters
# Filters
filters
=
self
.
_filters
,
filters
=
self
.
_filters
,
filter_groups
=
self
.
_
filter_groups
,
filter_groups
=
self
.
_
get_filter_groups
()
,
active_filters
=
view_args
.
filters
,
active_filters
=
view_args
.
filters
,
# Actions
# Actions
...
...
flask_admin/templates/bootstrap2/admin/layout.html
View file @
207d23fd
...
@@ -7,9 +7,9 @@
...
@@ -7,9 +7,9 @@
{% elif icon_type == 'fa' %}
{% elif icon_type == 'fa' %}
<i
class=
"fa fa-{{ icon_value }}"
></i>
<i
class=
"fa fa-{{ icon_value }}"
></i>
{% elif icon_type == 'image' %}
{% elif icon_type == 'image' %}
<img
src=
"{{ url_for('static', filename=icon_value) }}"
alt=
"menu image"
>
</img>
<img
src=
"{{ url_for('static', filename=icon_value) }}"
alt=
"menu image"
>
{% elif icon_type == 'image-url' %}
{% elif icon_type == 'image-url' %}
<img
src=
"item.icon_value"
alt=
"menu image"
>
</img>
<img
src=
"item.icon_value"
alt=
"menu image"
>
{% endif %}
{% endif %}
{% endif %}
{% endif %}
{%- endmacro %}
{%- endmacro %}
...
...
flask_admin/templates/bootstrap2/admin/lib.html
View file @
207d23fd
...
@@ -76,6 +76,31 @@
...
@@ -76,6 +76,31 @@
{% endif %}
{% endif %}
{%- endmacro %}
{%- endmacro %}
{% macro simple_pager(page, have_next, generator) -%}
<div
class=
"pagination"
>
<ul>
{% if page > 0 %}
<li>
<a
href=
"{{ generator(page - 1) }}"
>
<
</a>
</li>
{% else %}
<li
class=
"disabled"
>
<a
href=
"{{ generator(0) }}"
>
<
</a>
</li>
{% endif %}
{% if have_next %}
<li>
<a
href=
"{{ generator(page + 1) }}"
>
>
</a>
</li>
{% else %}
<li
class=
"disabled"
>
<a
href=
"{{ generator(page) }}"
>
>
</a>
</li>
{% endif %}
</ul>
</div>
{%- endmacro %}
{# ---------------------- Forms -------------------------- #}
{# ---------------------- Forms -------------------------- #}
{% macro render_field(form, field, kwargs={}, caller=None) %}
{% macro render_field(form, field, kwargs={}, caller=None) %}
{% set direct_error = h.is_field_error(field.errors) %}
{% set direct_error = h.is_field_error(field.errors) %}
...
...
flask_admin/templates/bootstrap2/admin/model/list.html
View file @
207d23fd
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
{% block model_menu_bar %}
{% block model_menu_bar %}
<ul
class=
"nav nav-tabs actions-nav"
>
<ul
class=
"nav nav-tabs actions-nav"
>
<li
class=
"active"
>
<li
class=
"active"
>
<a
href=
"javascript:void(0)"
>
{{ _gettext('List') }}
({{ count }})
</a>
<a
href=
"javascript:void(0)"
>
{{ _gettext('List') }}
{% if count %} ({{ count }}){% endif %}
</a>
</li>
</li>
{% if admin_view.can_create %}
{% if admin_view.can_create %}
<li>
<li>
...
@@ -110,7 +110,11 @@
...
@@ -110,7 +110,11 @@
<form
class=
"icon"
method=
"POST"
action=
"{{ get_url('.delete_view') }}"
>
<form
class=
"icon"
method=
"POST"
action=
"{{ get_url('.delete_view') }}"
>
{{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.url(value=return_url) }}
{{ delete_form.url(value=return_url) }}
{% if delete_form.csrf_token %}
{{ delete_form.csrf_token }}
{{ delete_form.csrf_token }}
{% elif csrf_token %}
<input
type=
"hidden"
name=
"csrf_token"
value=
"{{ csrf_token() }}"
/>
{% endif %}
<button
onclick=
"return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');"
title=
"{{ _gettext('Delete record') }}"
>
<button
onclick=
"return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');"
title=
"{{ _gettext('Delete record') }}"
>
<i
class=
"fa fa-trash icon-trash"
></i>
<i
class=
"fa fa-trash icon-trash"
></i>
</button>
</button>
...
@@ -147,7 +151,13 @@
...
@@ -147,7 +151,13 @@
</tr>
</tr>
{% endfor %}
{% endfor %}
</table>
</table>
{% block list_pager %}
{% if num_pages is not none %}
{{ lib.pager(page, num_pages, pager_url) }}
{{ lib.pager(page, num_pages, pager_url) }}
{% else %}
{{ lib.simple_pager(page, data|length == page_size, pager_url) }}
{% endif %}
{% endblock %}
{% endblock %}
{% endblock %}
{{ actionlib.form(actions, get_url('.action_view')) }}
{{ actionlib.form(actions, get_url('.action_view')) }}
...
...
flask_admin/templates/bootstrap3/admin/layout.html
View file @
207d23fd
...
@@ -7,9 +7,9 @@
...
@@ -7,9 +7,9 @@
{% elif icon_type == 'fa' %}
{% elif icon_type == 'fa' %}
<i
class=
"fa {{ icon_value }}"
></i>
<i
class=
"fa {{ icon_value }}"
></i>
{% elif icon_type == 'image' %}
{% elif icon_type == 'image' %}
<img
src=
"{{ url_for('static', filename=icon_value) }}"
alt=
"menu image"
>
</img>
<img
src=
"{{ url_for('static', filename=icon_value) }}"
alt=
"menu image"
>
{% elif icon_type == 'image-url' %}
{% elif icon_type == 'image-url' %}
<img
src=
"item.icon_value"
alt=
"menu image"
>
</img>
<img
src=
"item.icon_value"
alt=
"menu image"
>
{% endif %}
{% endif %}
{% endif %}
{% endif %}
{%- endmacro %}
{%- endmacro %}
...
...
flask_admin/templates/bootstrap3/admin/lib.html
View file @
207d23fd
...
@@ -74,6 +74,29 @@
...
@@ -74,6 +74,29 @@
{% endif %}
{% endif %}
{%- endmacro %}
{%- endmacro %}
{% macro simple_pager(page, have_next, generator) -%}
<ul
class=
"pagination"
>
{% if page > 0 %}
<li>
<a
href=
"{{ generator(page - 1) }}"
>
<
</a>
</li>
{% else %}
<li
class=
"disabled"
>
<a
href=
"{{ generator(0) }}"
>
<
</a>
</li>
{% endif %}
{% if have_next %}
<li>
<a
href=
"{{ generator(page + 1) }}"
>
>
</a>
</li>
{% else %}
<li
class=
"disabled"
>
<a
href=
"{{ generator(page) }}"
>
>
</a>
</li>
{% endif %}
</ul>
{%- endmacro %}
{# ---------------------- Forms -------------------------- #}
{# ---------------------- Forms -------------------------- #}
{% macro render_field(form, field, kwargs={}, caller=None) %}
{% macro render_field(form, field, kwargs={}, caller=None) %}
{% set direct_error = h.is_field_error(field.errors) %}
{% set direct_error = h.is_field_error(field.errors) %}
...
...
flask_admin/templates/bootstrap3/admin/model/list.html
View file @
207d23fd
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
{% block model_menu_bar %}
{% block model_menu_bar %}
<ul
class=
"nav nav-tabs actions-nav"
>
<ul
class=
"nav nav-tabs actions-nav"
>
<li
class=
"active"
>
<li
class=
"active"
>
<a
href=
"javascript:void(0)"
>
{{ _gettext('List') }}
({{ count }})
</a>
<a
href=
"javascript:void(0)"
>
{{ _gettext('List') }}
{% if count %} ({{ count }}){% endif %}
</a>
</li>
</li>
{% if admin_view.can_create %}
{% if admin_view.can_create %}
<li>
<li>
...
@@ -110,7 +110,11 @@
...
@@ -110,7 +110,11 @@
<form
class=
"icon"
method=
"POST"
action=
"{{ get_url('.delete_view') }}"
>
<form
class=
"icon"
method=
"POST"
action=
"{{ get_url('.delete_view') }}"
>
{{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.url(value=return_url) }}
{{ delete_form.url(value=return_url) }}
{% if delete_form.csrf_token %}
{{ delete_form.csrf_token }}
{{ delete_form.csrf_token }}
{% elif csrf_token %}
<input
type=
"hidden"
name=
"csrf_token"
value=
"{{ csrf_token() }}"
/>
{% endif %}
<button
onclick=
"return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');"
title=
"Delete record"
>
<button
onclick=
"return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');"
title=
"Delete record"
>
<span
class=
"fa fa-trash glyphicon glyphicon-trash"
></span>
<span
class=
"fa fa-trash glyphicon glyphicon-trash"
></span>
</button>
</button>
...
@@ -146,7 +150,13 @@
...
@@ -146,7 +150,13 @@
</tr>
</tr>
{% endfor %}
{% endfor %}
</table>
</table>
{% block list_pager %}
{% if num_pages is not none %}
{{ lib.pager(page, num_pages, pager_url) }}
{{ lib.pager(page, num_pages, pager_url) }}
{% else %}
{{ lib.simple_pager(page, data|length == page_size, pager_url) }}
{% endif %}
{% endblock %}
{% endblock %}
{% endblock %}
{{ actionlib.form(actions, get_url('.action_view')) }}
{{ actionlib.form(actions, get_url('.action_view')) }}
...
...
flask_admin/tests/mongoengine/test_basic.py
View file @
207d23fd
...
@@ -948,3 +948,20 @@ def test_form_args_embeddeddoc():
...
@@ -948,3 +948,20 @@ def test_form_args_embeddeddoc():
eq_
(
form
.
timestamp
.
label
.
text
,
'Last Updated Time'
)
eq_
(
form
.
timestamp
.
label
.
text
,
'Last Updated Time'
)
# This is the failure
# This is the failure
eq_
(
form
.
info
.
label
.
text
,
'Information'
)
eq_
(
form
.
info
.
label
.
text
,
'Information'
)
def
test_simple_list_pager
():
app
,
db
,
admin
=
setup
()
Model1
,
_
=
create_models
(
db
)
class
TestModelView
(
CustomModelView
):
simple_list_pager
=
True
def
get_count_query
(
self
):
assert
False
view
=
TestModelView
(
Model1
)
admin
.
add_view
(
view
)
count
,
data
=
view
.
get_list
(
0
,
None
,
None
,
None
,
None
)
ok_
(
count
is
None
)
flask_admin/tests/sqla/test_basic.py
View file @
207d23fd
...
@@ -12,6 +12,7 @@ from . import setup
...
@@ -12,6 +12,7 @@ from . import setup
from
datetime
import
datetime
,
time
,
date
from
datetime
import
datetime
,
time
,
date
class
CustomModelView
(
ModelView
):
class
CustomModelView
(
ModelView
):
def
__init__
(
self
,
model
,
session
,
def
__init__
(
self
,
model
,
session
,
name
=
None
,
category
=
None
,
endpoint
=
None
,
url
=
None
,
name
=
None
,
category
=
None
,
endpoint
=
None
,
url
=
None
,
...
@@ -259,10 +260,11 @@ def test_column_searchable_list():
...
@@ -259,10 +260,11 @@ def test_column_searchable_list():
eq_
(
view
.
_search_supported
,
True
)
eq_
(
view
.
_search_supported
,
True
)
eq_
(
len
(
view
.
_search_fields
),
2
)
eq_
(
len
(
view
.
_search_fields
),
2
)
ok_
(
isinstance
(
view
.
_search_fields
[
0
],
db
.
Column
))
ok_
(
isinstance
(
view
.
_search_fields
[
1
],
db
.
Column
))
ok_
(
isinstance
(
view
.
_search_fields
[
0
][
0
],
db
.
Column
))
eq_
(
view
.
_search_fields
[
0
]
.
name
,
'string_field'
)
ok_
(
isinstance
(
view
.
_search_fields
[
1
][
0
],
db
.
Column
))
eq_
(
view
.
_search_fields
[
1
]
.
name
,
'int_field'
)
eq_
(
view
.
_search_fields
[
0
][
0
]
.
name
,
'string_field'
)
eq_
(
view
.
_search_fields
[
1
][
0
]
.
name
,
'int_field'
)
db
.
session
.
add
(
Model2
(
'model1-test'
,
5000
))
db
.
session
.
add
(
Model2
(
'model1-test'
,
5000
))
db
.
session
.
add
(
Model2
(
'model2-test'
,
9000
))
db
.
session
.
add
(
Model2
(
'model2-test'
,
9000
))
...
@@ -417,6 +419,8 @@ def test_column_filters():
...
@@ -417,6 +419,8 @@ def test_column_filters():
)
)
admin
.
add_view
(
view
)
admin
.
add_view
(
view
)
client
=
app
.
test_client
()
eq_
(
len
(
view
.
_filters
),
7
)
eq_
(
len
(
view
.
_filters
),
7
)
eq_
([(
f
[
'index'
],
f
[
'operation'
])
for
f
in
view
.
_filter_groups
[
u'Test1'
]],
eq_
([(
f
[
'index'
],
f
[
'operation'
])
for
f
in
view
.
_filter_groups
[
u'Test1'
]],
...
@@ -515,21 +519,23 @@ def test_column_filters():
...
@@ -515,21 +519,23 @@ def test_column_filters():
fill_db
(
db
,
Model1
,
Model2
)
fill_db
(
db
,
Model1
,
Model2
)
client
=
app
.
test_client
()
# Test equals
rv
=
client
.
get
(
'/admin/model1/?flt0_0=test1_val_1'
)
rv
=
client
.
get
(
'/admin/model1/?flt0_0=test1_val_1'
)
eq_
(
rv
.
status_code
,
200
)
eq_
(
rv
.
status_code
,
200
)
data
=
rv
.
data
.
decode
(
'utf-8'
)
data
=
rv
.
data
.
decode
(
'utf-8'
)
# the filter value is always in "data"
# the filter value is always in "data"
# need to check a different column than test1 for the expected row
# need to check a different column than test1 for the expected row
ok_
(
'test2_val_1'
in
data
)
ok_
(
'test2_val_1'
in
data
)
ok_
(
'test1_val_2'
not
in
data
)
ok_
(
'test1_val_2'
not
in
data
)
# Test NOT IN filter
rv
=
client
.
get
(
'/admin/model1/?flt0_6=test1_val_1'
)
rv
=
client
.
get
(
'/admin/model1/?flt0_6=test1_val_1'
)
eq_
(
rv
.
status_code
,
200
)
eq_
(
rv
.
status_code
,
200
)
data
=
rv
.
data
.
decode
(
'utf-8'
)
data
=
rv
.
data
.
decode
(
'utf-8'
)
ok_
(
'test2_val_1'
not
in
data
)
ok_
(
'test1_val_2'
in
data
)
ok_
(
'test1_val_2'
in
data
)
ok_
(
'test2_val_1'
not
in
data
)
# Test string filter
# Test string filter
view
=
CustomModelView
(
Model1
,
db
.
session
,
view
=
CustomModelView
(
Model1
,
db
.
session
,
...
@@ -1103,9 +1109,11 @@ def test_column_filters():
...
@@ -1103,9 +1109,11 @@ def test_column_filters():
rv
=
client
.
get
(
'/admin/_relation_test/?flt1_0=test1_val_1'
)
rv
=
client
.
get
(
'/admin/_relation_test/?flt1_0=test1_val_1'
)
data
=
rv
.
data
.
decode
(
'utf-8'
)
data
=
rv
.
data
.
decode
(
'utf-8'
)
ok_
(
'test1_val_1'
in
data
)
ok_
(
'test1_val_1'
in
data
)
ok_
(
'test1_val_2'
not
in
data
)
ok_
(
'test1_val_2'
not
in
data
)
def
test_url_args
():
def
test_url_args
():
app
,
db
,
admin
=
setup
()
app
,
db
,
admin
=
setup
()
...
@@ -1680,3 +1688,123 @@ def test_safe_redirect():
...
@@ -1680,3 +1688,123 @@ def test_safe_redirect():
assert_true
(
rv
.
location
.
startswith
(
'http://localhost/admin/model1/edit/'
))
assert_true
(
rv
.
location
.
startswith
(
'http://localhost/admin/model1/edit/'
))
assert_true
(
'url=
%2
Fadmin
%2
Fmodel1
%2
F'
in
rv
.
location
)
assert_true
(
'url=
%2
Fadmin
%2
Fmodel1
%2
F'
in
rv
.
location
)
assert_true
(
'id=2'
in
rv
.
location
)
assert_true
(
'id=2'
in
rv
.
location
)
def
test_simple_list_pager
():
app
,
db
,
admin
=
setup
()
Model1
,
_
=
create_models
(
db
)
db
.
create_all
()
class
TestModelView
(
CustomModelView
):
simple_list_pager
=
True
def
get_count_query
(
self
):
assert
False
view
=
TestModelView
(
Model1
,
db
.
session
)
admin
.
add_view
(
view
)
count
,
data
=
view
.
get_list
(
0
,
None
,
None
,
None
,
None
)
assert_true
(
count
is
None
)
def
test_advanced_joins
():
app
,
db
,
admin
=
setup
()
class
Model1
(
db
.
Model
):
id
=
db
.
Column
(
db
.
Integer
,
primary_key
=
True
)
val1
=
db
.
Column
(
db
.
String
(
20
))
test
=
db
.
Column
(
db
.
String
(
20
))
class
Model2
(
db
.
Model
):
id
=
db
.
Column
(
db
.
Integer
,
primary_key
=
True
)
val2
=
db
.
Column
(
db
.
String
(
20
))
model1_id
=
db
.
Column
(
db
.
Integer
,
db
.
ForeignKey
(
Model1
.
id
))
model1
=
db
.
relationship
(
Model1
,
backref
=
'model2'
)
class
Model3
(
db
.
Model
):
id
=
db
.
Column
(
db
.
Integer
,
primary_key
=
True
)
val2
=
db
.
Column
(
db
.
String
(
20
))
model2_id
=
db
.
Column
(
db
.
Integer
,
db
.
ForeignKey
(
Model2
.
id
))
model2
=
db
.
relationship
(
Model2
,
backref
=
'model3'
)
view1
=
CustomModelView
(
Model1
,
db
.
session
)
admin
.
add_view
(
view1
)
view2
=
CustomModelView
(
Model2
,
db
.
session
)
admin
.
add_view
(
view2
)
view3
=
CustomModelView
(
Model3
,
db
.
session
)
admin
.
add_view
(
view3
)
# Test joins
attr
,
path
=
view2
.
_get_field_with_path
(
'model1.val1'
)
eq_
(
attr
,
Model1
.
val1
)
eq_
(
path
,
[
Model2
.
model1
])
attr
,
path
=
view1
.
_get_field_with_path
(
'model2.val2'
)
eq_
(
attr
,
Model2
.
val2
)
eq_
(
id
(
path
[
0
]),
id
(
Model1
.
model2
))
attr
,
path
=
view3
.
_get_field_with_path
(
'model2.model1.val1'
)
eq_
(
attr
,
Model1
.
val1
)
eq_
(
path
,
[
Model3
.
model2
,
Model2
.
model1
])
# Test how joins are applied
query
=
view3
.
get_query
()
joins
=
{}
q1
,
joins
,
alias
=
view3
.
_apply_path_joins
(
query
,
joins
,
path
)
ok_
((
True
,
Model3
.
model2
)
in
joins
)
ok_
((
True
,
Model2
.
model1
)
in
joins
)
ok_
(
alias
is
not
None
)
# Check if another join would use same path
attr
,
path
=
view2
.
_get_field_with_path
(
'model1.test'
)
q2
,
joins
,
alias
=
view2
.
_apply_path_joins
(
query
,
joins
,
path
)
eq_
(
len
(
joins
),
2
)
for
p
in
q2
.
_join_entities
:
ok_
(
p
in
q1
.
_join_entities
)
ok_
(
alias
is
not
None
)
# Check if normal properties are supported by _get_field_with_path
attr
,
path
=
view2
.
_get_field_with_path
(
Model1
.
test
)
eq_
(
attr
,
Model1
.
test
)
eq_
(
path
,
[
Model1
.
__table__
])
q3
,
joins
,
alias
=
view2
.
_apply_path_joins
(
view2
.
get_query
(),
joins
,
path
)
eq_
(
len
(
joins
),
3
)
ok_
(
alias
is
None
)
def
test_multipath_joins
():
app
,
db
,
admin
=
setup
()
class
Model1
(
db
.
Model
):
id
=
db
.
Column
(
db
.
Integer
,
primary_key
=
True
)
val1
=
db
.
Column
(
db
.
String
(
20
))
test
=
db
.
Column
(
db
.
String
(
20
))
class
Model2
(
db
.
Model
):
id
=
db
.
Column
(
db
.
Integer
,
primary_key
=
True
)
val2
=
db
.
Column
(
db
.
String
(
20
))
first_id
=
db
.
Column
(
db
.
Integer
,
db
.
ForeignKey
(
Model1
.
id
))
first
=
db
.
relationship
(
Model1
,
backref
=
'first'
,
foreign_keys
=
[
first_id
])
second_id
=
db
.
Column
(
db
.
Integer
,
db
.
ForeignKey
(
Model1
.
id
))
second
=
db
.
relationship
(
Model1
,
backref
=
'second'
,
foreign_keys
=
[
second_id
])
db
.
create_all
()
view
=
CustomModelView
(
Model2
,
db
.
session
,
filters
=
[
'first.test'
])
admin
.
add_view
(
view
)
client
=
app
.
test_client
()
rv
=
client
.
get
(
'/admin/model2/'
)
eq_
(
rv
.
status_code
,
200
)
flask_admin/tests/test_base.py
View file @
207d23fd
...
@@ -76,7 +76,7 @@ def test_baseview_defaults():
...
@@ -76,7 +76,7 @@ def test_baseview_defaults():
view
=
MockView
()
view
=
MockView
()
eq_
(
view
.
name
,
None
)
eq_
(
view
.
name
,
None
)
eq_
(
view
.
category
,
None
)
eq_
(
view
.
category
,
None
)
eq_
(
view
.
endpoint
,
None
)
eq_
(
view
.
endpoint
,
'mockview'
)
eq_
(
view
.
url
,
None
)
eq_
(
view
.
url
,
None
)
eq_
(
view
.
static_folder
,
None
)
eq_
(
view
.
static_folder
,
None
)
eq_
(
view
.
admin
,
None
)
eq_
(
view
.
admin
,
None
)
...
@@ -388,3 +388,12 @@ def test_menu_links():
...
@@ -388,3 +388,12 @@ def test_menu_links():
def
check_class_name
():
def
check_class_name
():
view
=
MockView
()
view
=
MockView
()
eq_
(
view
.
name
,
'Mock View'
)
eq_
(
view
.
name
,
'Mock View'
)
def
check_endpoint
():
class
CustomView
(
MockView
):
def
_get_endpoint
(
self
,
endpoint
):
return
'admin.'
+
super
(
CustomView
,
self
)
.
_get_endpoint
(
endpoint
)
view
=
CustomView
()
eq_
(
view
.
endpoint
,
'admin.customview'
)
flask_admin/tests/test_form_upload.py
View file @
207d23fd
...
@@ -40,6 +40,9 @@ def test_upload_field():
...
@@ -40,6 +40,9 @@ def test_upload_field():
class
TestForm
(
form
.
BaseForm
):
class
TestForm
(
form
.
BaseForm
):
upload
=
form
.
FileUploadField
(
'Upload'
,
base_path
=
path
)
upload
=
form
.
FileUploadField
(
'Upload'
,
base_path
=
path
)
class
TestNoOverWriteForm
(
form
.
BaseForm
):
upload
=
form
.
FileUploadField
(
'Upload'
,
base_path
=
path
,
allow_overwrite
=
False
)
class
Dummy
(
object
):
class
Dummy
(
object
):
pass
pass
...
@@ -74,6 +77,7 @@ def test_upload_field():
...
@@ -74,6 +77,7 @@ def test_upload_field():
# Check delete
# Check delete
with
app
.
test_request_context
(
method
=
'POST'
,
data
=
{
'_upload-delete'
:
'checked'
}):
with
app
.
test_request_context
(
method
=
'POST'
,
data
=
{
'_upload-delete'
:
'checked'
}):
my_form
=
TestForm
(
helpers
.
get_form_data
())
my_form
=
TestForm
(
helpers
.
get_form_data
())
ok_
(
my_form
.
validate
())
ok_
(
my_form
.
validate
())
...
@@ -83,6 +87,24 @@ def test_upload_field():
...
@@ -83,6 +87,24 @@ def test_upload_field():
ok_
(
not
op
.
exists
(
op
.
join
(
path
,
'test2.txt'
)))
ok_
(
not
op
.
exists
(
op
.
join
(
path
,
'test2.txt'
)))
# Check overwrite
_remove_testfiles
()
my_form_ow
=
TestNoOverWriteForm
()
with
app
.
test_request_context
(
method
=
'POST'
,
data
=
{
'upload'
:
(
BytesIO
(
b
'Hullo'
),
'test1.txt'
)}):
my_form_ow
=
TestNoOverWriteForm
(
helpers
.
get_form_data
())
ok_
(
my_form_ow
.
validate
())
my_form_ow
.
populate_obj
(
dummy
)
eq_
(
dummy
.
upload
,
'test1.txt'
)
ok_
(
op
.
exists
(
op
.
join
(
path
,
'test1.txt'
)))
with
app
.
test_request_context
(
method
=
'POST'
,
data
=
{
'upload'
:
(
BytesIO
(
b
'Hullo'
),
'test1.txt'
)}):
my_form_ow
=
TestNoOverWriteForm
(
helpers
.
get_form_data
())
ok_
(
not
my_form_ow
.
validate
())
_remove_testfiles
()
def
test_image_upload_field
():
def
test_image_upload_field
():
app
=
Flask
(
__name__
)
app
=
Flask
(
__name__
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment