Commit 207d23fd authored by Petrus J.v.Rensburg's avatar Petrus J.v.Rensburg

Merge branch 'master' into examples

parents 2d3d1d63 6b6fe519
Changelog
=========
1.2.0
-----
* Codebase was migrated to Flask-Admin GitHub organization
* Automatically inject Flask-WTF CSRF token to internal Flask-Admin forms
* MapBox v4 support for GeoAlchemy
* Updated translations with help of CrowdIn
* Show warning if field was ignored in form rendering rules
* Simple AppEngine backend
* Optional support for Font Awesome in templates and menus
* Bug fixes
1.1.0
-----
......@@ -43,21 +55,3 @@ Highlights:
* Support for newer wtforms versions
* `form_rules` property that affects both create and edit forms
* Lots of bugfixes
1.0.7
-----
Full change log and feature walkthrough can be found `here <http://mrjoes.github.io/2013/10/21/flask-admin-107.html>`_.
Highlights:
* Python 3 support
* AJAX-based foreign-key data loading for all backends
* New, optional, rule-based form rendering engine
* MongoEngine fixes and features: GridFS support, nested subdocument configuration and much more
* Greatly improved and more configurable inline models
* New WTForms fields and widgets
* `form_extra_columns` allows adding custom columns to the form declaratively
* Redis cli
* SQLAlchemy backend can handle inherited models with multiple PKs
* Lots of bug fixes
......@@ -101,7 +101,8 @@ class FileView(sqla.ModelView):
form_args = {
'path': {
'label': 'File',
'base_path': file_path
'base_path': file_path,
'allow_overwrite': False
}
}
......
......@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView):
# List of columns that can be sorted. For 'user' column, use User.username as
# a column.
column_sortable_list = ('title', ('user', User.username), 'date')
column_sortable_list = ('title', ('user', 'user.username'), 'date')
# Rename 'title' columns to 'Post Title' in list view
column_labels = dict(title='Post Title')
......
__version__ = '1.1.1-dev'
__version__ = '1.2.0'
__author__ = 'Serge S. Koval'
__email__ = 'serge.koval+github@gmail.com'
......
......@@ -188,7 +188,7 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
"""
self.name = name
self.category = category
self.endpoint = endpoint
self.endpoint = self._get_endpoint(endpoint)
self.url = url
self.static_folder = static_folder
self.static_url_path = static_url_path
......@@ -206,6 +206,16 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
if self._default_view is None:
raise Exception(u'Attempted to instantiate admin view %s without default view' % self.__class__.__name__)
def _get_endpoint(self, endpoint):
"""
Generate Flask endpoint name. By default converts class name to lower case if endpoint is
not explicitly provided.
"""
if endpoint:
return endpoint
return self.__class__.__name__.lower()
def create_blueprint(self, admin):
"""
Create Flask blueprint.
......@@ -213,10 +223,6 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
# Store admin instance
self.admin = admin
# If endpoint name is not provided, get it from the class name
if self.endpoint is None:
self.endpoint = self.__class__.__name__.lower()
# If the static_url_path is not provided, use the admin's
if not self.static_url_path:
self.static_url_path = admin.static_url_path
......@@ -234,15 +240,13 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
if not self.url.startswith('/'):
self.url = '%s/%s' % (self.admin.url, self.url)
# If we're working from the root of the site, set prefix to None
if self.url == '/':
self.url = None
# prevent admin static files from conflicting with flask static files
if not self.static_url_path:
self.static_folder='static'
self.static_url_path='/static/admin'
self.static_folder = 'static'
self.static_url_path = '/static/admin'
# If name is not povided, use capitalized endpoint name
if self.name is None:
......
......@@ -484,7 +484,7 @@ class ModelView(BaseModelView):
query = self._search(query, search)
# Get count
count = query.count()
count = query.count() if not self.simple_list_pager else None
# Sorting
if sort_column:
......@@ -592,7 +592,7 @@ class ModelView(BaseModelView):
return False
else:
self.after_model_delete(model)
return True
......
......@@ -339,7 +339,7 @@ class ModelView(BaseModelView):
query = f.apply(query, f.clean(value))
# Get count
count = query.count()
count = query.count() if not self.simple_list_pager else None
# Apply sorting
if sort_column is not None:
......@@ -417,7 +417,7 @@ class ModelView(BaseModelView):
return False
else:
self.after_model_delete(model)
return True
# Default model actions
......@@ -443,6 +443,7 @@ class ModelView(BaseModelView):
query = self.model.select().filter(model_pk << ids)
for m in query:
self.on_model_delete(m)
m.delete_instance(recursive=True)
count += 1
......
......@@ -222,7 +222,7 @@ class ModelView(BaseModelView):
query = self._search(query, search)
# Get count
count = self.coll.find(query).count()
count = self.coll.find(query).count() if not self.simple_list_pager else None
# Sorting
sort_by = None
......@@ -337,7 +337,7 @@ class ModelView(BaseModelView):
return False
else:
self.after_model_delete(model)
return True
# Default model actions
......
import warnings
import time
import datetime
from flask_admin.babel import lazy_gettext
from flask_admin.model import filters
from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_
class BaseSQLAFilter(filters.BaseFilter):
"""
Base SQLAlchemy filter.
......@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter):
self.column = column
def get_column(self, alias):
return self.column if alias is None else getattr(alias, self.column.key)
def apply(self, query, value, alias=None):
return super(self, BaseSQLAFilter).apply(query, value)
# Common filters
class FilterEqual(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column == value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) == value)
def operation(self):
return lazy_gettext('equals')
class FilterNotEqual(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column != value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) != value)
def operation(self):
return lazy_gettext('not equal')
class FilterLike(BaseSQLAFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value)
return query.filter(self.column.ilike(stmt))
return query.filter(self.get_column(alias).ilike(stmt))
def operation(self):
return lazy_gettext('contains')
class FilterNotLike(BaseSQLAFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value)
return query.filter(~self.column.ilike(stmt))
return query.filter(~self.get_column(alias).ilike(stmt))
def operation(self):
return lazy_gettext('not contains')
class FilterGreater(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column > value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) > value)
def operation(self):
return lazy_gettext('greater than')
class FilterSmaller(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column < value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) < value)
def operation(self):
return lazy_gettext('smaller than')
class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
if value == '1':
return query.filter(self.column == None)
return query.filter(self.get_column(alias) == None)
else:
return query.filter(self.column != None)
return query.filter(self.get_column(alias) != None)
def operation(self):
return lazy_gettext('empty')
......@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter):
def clean(self, value):
return [v.strip() for v in value.split(',') if v.strip()]
def apply(self, query, value):
return query.filter(self.column.in_(value))
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias).in_(value))
def operation(self):
return lazy_gettext('in list')
class FilterNotInList(FilterInList):
def apply(self, query, value):
def apply(self, query, value, alias=None):
# NOT IN can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(~self.column.in_(value), self.column == None))
column = self.get_column(alias)
return query.filter(or_(~column.in_(value), column == None))
def operation(self):
return lazy_gettext('not in list')
......@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
options,
data_type='daterangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class DateNotBetweenFilter(DateBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
# ~between() isn't possible until sqlalchemy 1.0.0
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
options,
data_type='datetimerangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
options,
data_type='timerangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class TimeNotBetweenFilter(TimeBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......
import logging
import warnings
import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload, aliased
from sqlalchemy.sql.expression import desc
from sqlalchemy import Column, Boolean, func, or_
from sqlalchemy import Boolean, Table, func, or_
from sqlalchemy.exc import IntegrityError
from flask import flash
......@@ -276,7 +277,6 @@ class ModelView(BaseModelView):
self.session = session
self._search_fields = None
self._search_joins = []
self._filter_joins = dict()
......@@ -322,43 +322,92 @@ class ModelView(BaseModelView):
return field.property.columns
def _get_field_with_path(self, name):
join_tables = []
"""
Resolve property by name and figure out its join path.
if isinstance(name, string_types):
model = self.model
Join path might contain both properties and tables.
"""
path = []
model = self.model
# For strings, resolve path
if isinstance(name, string_types):
for attribute in name.split('.'):
value = getattr(model, attribute)
if (hasattr(value, 'property') and
hasattr(value.property, 'direction')):
hasattr(value.property, 'direction')):
model = value.property.mapper.class_
table = model.__table__
if self._need_join(table):
join_tables.append(table)
path.append(value)
attr = value
else:
attr = name
# determine joins if Table.column (relation object) is given
if isinstance(name, InstrumentedAttribute):
columns = self._get_columns_for_field(name)
# Determine joins if table.column (relation object) is provided
if isinstance(attr, InstrumentedAttribute):
columns = self._get_columns_for_field(attr)
if len(columns) > 1:
raise Exception('Can only handle one column for %s' % name)
column = columns[0]
# TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
if self._need_join(column.table):
join_tables.append(column.table)
path.append(column.table)
return join_tables, attr
return attr, path
def _need_join(self, table):
"""
Check if join to a table is necessary.
"""
return table not in self.model._sa_class_manager.mapper.tables
def _apply_path_joins(self, query, joins, path, inner_join=True):
"""
Apply join path to the query.
:param query:
Query to add joins to
:param joins:
List of current joins. Used to avoid joining on same relationship more than once
:param path:
Path to be joined
:param fn:
Join function
"""
last = None
if path:
for item in path:
key = (inner_join, item)
alias = joins.get(key)
if key not in joins:
if not isinstance(item, Table):
alias = aliased(item.property.mapper.class_)
fn = query.join if inner_join else query.outerjoin
if last is None:
query = fn(item) if alias is None else fn(alias, item)
else:
prop = getattr(last, item.key)
query = fn(prop) if alias is None else fn(alias, prop)
joins[key] = alias
last = alias
return query, joins, last
# Scaffolding
def scaffold_pk(self):
"""
......@@ -453,19 +502,19 @@ class ModelView(BaseModelView):
for c in self.column_sortable_list:
if isinstance(c, tuple):
join_tables, column = self._get_field_with_path(c[1])
column, path = self._get_field_with_path(c[1])
column_name = c[0]
elif isinstance(c, InstrumentedAttribute):
join_tables, column = self._get_field_with_path(c)
column, path = self._get_field_with_path(c)
column_name = str(c)
else:
join_tables, column = self._get_field_with_path(c)
column, path = self._get_field_with_path(c)
column_name = c
result[column_name] = column
if join_tables:
self._sortable_joins[column_name] = join_tables
if path:
self._sortable_joins[column_name] = path
return result
......@@ -479,26 +528,15 @@ class ModelView(BaseModelView):
"""
if self.column_searchable_list:
self._search_fields = []
self._search_joins = []
joins = set()
for p in self.column_searchable_list:
join_tables, attr = self._get_field_with_path(p)
attr, joins = self._get_field_with_path(p)
if not attr:
raise Exception('Failed to find field for search field: %s' % p)
for column in self._get_columns_for_field(attr):
column_type = type(column.type).__name__
self._search_fields.append(column)
# Store joins, avoid duplicates
for table in join_tables:
if table.name not in joins:
self._search_joins.append(table)
joins.add(table.name)
self._search_fields.append((column, joins))
return bool(self.column_searchable_list)
......@@ -507,7 +545,7 @@ class ModelView(BaseModelView):
Return list of enabled filters
"""
join_tables, attr = self._get_field_with_path(name)
attr, joins = self._get_field_with_path(name)
if attr is None:
raise Exception('Failed to find field for filter: %s' % name)
......@@ -535,10 +573,11 @@ class ModelView(BaseModelView):
if flt:
table = column.table
if join_tables:
self._filter_joins[table.name] = join_tables
if joins:
self._filter_joins[column] = joins
elif self._need_join(table):
self._filter_joins[table.name] = [table]
self._filter_joins[column] = [table]
filters.extend(flt)
return filters
......@@ -563,9 +602,6 @@ class ModelView(BaseModelView):
type_name = type(column.type).__name__
if join_tables:
self._filter_joins[column.table.name] = join_tables
flt = self.filter_converter.convert(
type_name,
column,
......@@ -573,8 +609,10 @@ class ModelView(BaseModelView):
options=self.column_choices.get(name),
)
if flt and not join_tables and self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table]
if joins:
self._filter_joins[column] = joins
elif self._need_join(column.table):
self._filter_joins[column] = [column.table]
return flt
......@@ -583,7 +621,7 @@ class ModelView(BaseModelView):
column = filter.column
if self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table]
self._filter_joins[column] = [column.table]
return filter
......@@ -707,27 +745,25 @@ class ModelView(BaseModelView):
:param query:
Query
:param joins:
Joins set
:pram joins:
Current joins
:param sort_joins:
Sort joins (properties or tables)
:param sort_field:
Sort field
:param sort_desc:
Ascending or descending
"""
# TODO: Preprocessing for joins
# Handle joins
if sort_joins:
for table in sort_joins:
if table.name not in joins:
query = query.outerjoin(table)
if sort_field is not None:
# Handle joins
query, joins, alias = self._apply_path_joins(query, joins, sort_joins, inner_join=False)
joins.add(table.name)
column = sort_field if alias is None else getattr(alias, sort_field.key)
if sort_field is not None:
if sort_desc:
query = query.order_by(desc(sort_field))
query = query.order_by(desc(column))
else:
query = query.order_by(sort_field)
query = query.order_by(column)
return query, joins
......@@ -737,12 +773,112 @@ class ModelView(BaseModelView):
if order is not None:
field, direction = order
join_tables, attr = self._get_field_with_path(field)
attr, joins = self._get_field_with_path(field)
return join_tables, attr, direction
return attr, joins, direction
return None
def _apply_sorting(self, query, joins, sort_column, sort_desc):
if sort_column is not None:
if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_field, sort_joins, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
return query, joins
def _apply_search(self, query, count_query, joins, count_joins, search):
"""
Apply search to a query.
"""
terms = search.split(' ')
for term in terms:
if not term:
continue
stmt = tools.parse_like_term(term)
filter_stmt = []
count_filter_stmt = []
for field, path in self._search_fields:
query, joins, alias = self._apply_path_joins(query, joins, path, inner_join=False)
count_alias = None
if count_query is not None:
count_query, count_joins, count_alias = self._apply_path_joins(count_query,
count_joins,
path,
inner_join=False)
column = field if alias is None else getattr(alias, field.key)
filter_stmt.append(column.ilike(stmt))
if count_filter_stmt is not None:
column = field if count_alias is None else getattr(count_alias, field.key)
count_filter_stmt.append(column.ilike(stmt))
query = query.filter(or_(*filter_stmt))
if count_query is not None:
count_query = count_query.filter(or_(*count_filter_stmt))
return query, count_query, joins, count_joins
def _apply_filters(self, query, count_query, joins, count_joins, filters):
for idx, flt_name, value in filters:
flt = self._filters[idx]
alias = None
count_alias = None
# Figure out joins
if isinstance(flt, sqla_filters.BaseSQLAFilter):
path = self._filter_joins.get(flt.column, [])
query, joins, alias = self._apply_path_joins(query, joins, path, inner_join=False)
if count_query is not None:
count_query, count_joins, count_alias = self._apply_path_joins(
count_query,
count_joins,
path,
inner_join=False)
# Clean value .clean() and apply the filter
clean_value = flt.clean(value)
try:
query = flt.apply(query, clean_value, alias)
except TypeError:
spec = inspect.getargspec(flt.apply)
if len(spec.args) == 2:
warnings.warn('Please update your custom filter %s to include additional `alias` parameter.' % repr(flt))
else:
raise
query = flt.apply(query, clean_value)
if count_query is not None:
try:
count_query = flt.apply(count_query, clean_value, count_alias)
except TypeError:
count_query = flt.apply(count_query, clean_value)
return query, count_query, joins, count_joins
def get_list(self, page, sort_column, sort_desc, search, filters, execute=True):
"""
Return models from the database.
......@@ -761,84 +897,45 @@ class ModelView(BaseModelView):
List of filter tuples
"""
# Will contain names of joined tables to avoid duplicate joins
joins = set()
# Will contain join paths with optional aliased object
joins = {}
count_joins = {}
query = self.get_query()
count_query = self.get_count_query()
count_query = self.get_count_query() if not self.simple_list_pager else None
# Ignore eager-loaded relations (prevent unnecessary joins)
# TODO: Separate join detection for query and count query?
if hasattr(query, '_join_entities'):
for entity in query._join_entities:
for table in entity.tables:
joins.add(table.name)
joins[table] = None
# Apply search criteria
if self._search_supported and search:
# Apply search-related joins
if self._search_joins:
for table in self._search_joins:
if table.name not in joins:
query = query.outerjoin(table)
count_query = count_query.outerjoin(table)
joins.add(table.name)
# Apply terms
terms = search.split(' ')
for term in terms:
if not term:
continue
stmt = tools.parse_like_term(term)
filter_stmt = [c.ilike(stmt) for c in self._search_fields]
query = query.filter(or_(*filter_stmt))
count_query = count_query.filter(or_(*filter_stmt))
query, count_query, joins, count_joins = self._apply_search(query,
count_query,
joins,
count_joins,
search)
# Apply filters
if filters and self._filters:
for idx, flt_name, value in filters:
flt = self._filters[idx]
query, count_query, joins, count_joins = self._apply_filters(query,
count_query,
joins,
count_joins,
filters)
# Figure out joins
if isinstance(flt, sqla_filters.BaseSQLAFilter):
tbl = flt.column.table.name
join_tables = self._filter_joins.get(tbl, [])
for table in join_tables:
if table.name not in joins:
query = query.join(table)
count_query = count_query.join(table)
joins.add(table.name)
# turn into python format with .clean() and apply filter
query = flt.apply(query, flt.clean(value))
count_query = flt.apply(count_query, flt.clean(value))
# Calculate number of rows
count = count_query.scalar()
# Calculate number of rows if necessary
count = count_query.scalar() if count_query else None
# Auto join
for j in self._auto_joins:
query = query.options(joinedload(j))
# Sorting
if sort_column is not None:
if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_joins, sort_field, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
query, joins = self._apply_sorting(query, joins, sort_column, sort_desc)
# Pagination
if page is not None:
......@@ -944,7 +1041,7 @@ class ModelView(BaseModelView):
return False
else:
self.after_model_delete(model)
return True
# Default model actions
......
......@@ -51,12 +51,21 @@ class FileUploadInput(object):
template = self.data_template if field.data else self.empty_template
if field.errors:
template = self.empty_template
if field.data and isinstance(field.data, FileStorage):
value = field.data.filename
else:
value = field.data or ''
return HTMLString(template % {
'text': html_params(type='text',
readonly='readonly',
value=field.data,
value=value,
name=field.name),
'file': html_params(type='file',
value=value,
**kwargs),
'marker': '_%s-delete' % field.name
})
......@@ -122,7 +131,7 @@ class FileUploadField(fields.StringField):
def __init__(self, label=None, validators=None,
base_path=None, relative_path=None,
namegen=None, allowed_extensions=None,
permission=0o666,
permission=0o666, allow_overwrite=True,
**kwargs):
"""
Constructor.
......@@ -154,6 +163,11 @@ class FileUploadField(fields.StringField):
:param allowed_extensions:
List of allowed extensions. If not provided, will allow any file.
:param allow_overwrite:
Whether to overwrite existing files in upload directory. Defaults to `True`.
.. versionadded:: 1.1.1
The `allow_overwrite` parameter was added.
"""
self.base_path = base_path
self.relative_path = relative_path
......@@ -161,6 +175,7 @@ class FileUploadField(fields.StringField):
self.namegen = namegen or namegen_filename
self.allowed_extensions = allowed_extensions
self.permission = permission
self._allow_overwrite = allow_overwrite
self._should_delete = False
......@@ -188,6 +203,11 @@ class FileUploadField(fields.StringField):
def pre_validate(self, form):
if self._is_uploaded_file(self.data) and not self.is_file_allowed(self.data.filename):
raise ValidationError(gettext('Invalid file extension'))
# Handle overwriting existing content
if not self._is_uploaded_file(self.data):
return
if self._allow_overwrite == False and os.path.exists(self._get_path(self.data.filename)):
raise ValidationError(gettext('File "%s" already exists.' % self.data.filename))
def process(self, formdata, data=unset_value):
if formdata:
......@@ -253,6 +273,9 @@ class FileUploadField(fields.StringField):
if not op.exists(op.dirname(path)):
os.makedirs(os.path.dirname(path), self.permission | 0o111)
if self._allow_overwrite == False and os.path.exists(path):
raise ValueError(gettext('File "%s" already exists.' % path))
data.save(path)
return filename
......
......@@ -15,7 +15,7 @@ from flask_admin.form import BaseForm, FormOpts, rules
from flask_admin.model import filters, typefmt
from flask_admin.actions import ActionsMixin
from flask_admin.helpers import (get_form_data, validate_form_on_submit,
get_redirect_target, flash_errors)
get_redirect_target, flash_errors)
from flask_admin.tools import rec_getattr
from flask_admin._backwards import ObsoleteAttr
from flask_admin._compat import iteritems, OrderedDict, as_unicode
......@@ -321,6 +321,12 @@ class BaseModelView(BaseView, ActionsMixin):
Controls if the primary key should be displayed in the list view.
"""
simple_list_pager = False
"""
Enable or disable simple list pager.
If enabled, model interface would not run count query and will only show prev/next pager buttons.
"""
form = None
"""
Form class. Override if you want to use custom form for your model.
......@@ -563,28 +569,30 @@ class BaseModelView(BaseView, ActionsMixin):
:param menu_icon_value:
Icon glyph name or URL, depending on `menu_icon_type` setting
"""
self.model = model
# If name not provided, it is model name
if name is None:
name = '%s' % self._prettify_class_name(model.__name__)
# If endpoint not provided, it is model name
if endpoint is None:
endpoint = model.__name__.lower()
super(BaseModelView, self).__init__(name, category, endpoint, url, static_folder,
menu_class_name=menu_class_name,
menu_icon_type=menu_icon_type,
menu_icon_value=menu_icon_value)
self.model = model
# Actions
self.init_actions()
# Scaffolding
self._refresh_cache()
# Endpoint
def _get_endpoint(self, endpoint):
if endpoint:
return super(BaseModelView, self)._get_endpoint(endpoint)
return self.model.__name__.lower()
# Caching
def _refresh_forms_cache(self):
# Forms
......@@ -617,7 +625,7 @@ class BaseModelView(BaseView, ActionsMixin):
self._filter_groups[flt.name].append({
'index': i,
'arg': self.get_filter_arg(i, flt),
'operation': as_unicode(flt.operation()),
'operation': flt.operation(),
'options': flt.get_options(self) or None,
'type': flt.data_type
})
......@@ -852,6 +860,27 @@ class BaseModelView(BaseView, ActionsMixin):
else:
return str(index)
def _get_filter_groups(self):
"""
Returns non-lazy version of filter strings
"""
if self._filter_groups:
results = OrderedDict()
for key, value in iteritems(self._filter_groups):
items = []
for item in value:
copy = dict(item)
copy['operation'] = as_unicode(copy['operation'])
items.append(copy)
results[key] = items
return results
return None
# Form helpers
def scaffold_form(self):
"""
......@@ -1018,7 +1047,7 @@ class BaseModelView(BaseView, ActionsMixin):
missing_fields.append(field.name)
return missing_fields
def _show_missing_fields_warning(self, text):
warnings.warn(text)
......@@ -1200,7 +1229,7 @@ class BaseModelView(BaseView, ActionsMixin):
By default do nothing.
"""
pass
def after_model_delete(self, model):
"""
Perform some actions after a model was deleted and
......@@ -1214,7 +1243,7 @@ class BaseModelView(BaseView, ActionsMixin):
:param model:
Model that was deleted
"""
pass
pass
def on_form_prefill (self, form, id):
"""
......@@ -1463,9 +1492,12 @@ class BaseModelView(BaseView, ActionsMixin):
view_args.search, view_args.filters)
# Calculate number of pages
num_pages = count // self.page_size
if count % self.page_size != 0:
num_pages += 1
if count is not None:
num_pages = count // self.page_size
if count % self.page_size != 0:
num_pages += 1
else:
num_pages = None
# Various URL generation helpers
def pager_url(p):
......@@ -1508,6 +1540,7 @@ class BaseModelView(BaseView, ActionsMixin):
pager_url=pager_url,
num_pages=num_pages,
page=view_args.page,
page_size=self.page_size,
# Sorting
sort_column=view_args.sort,
......@@ -1521,7 +1554,7 @@ class BaseModelView(BaseView, ActionsMixin):
# Filters
filters=self._filters,
filter_groups=self._filter_groups,
filter_groups=self._get_filter_groups(),
active_filters=view_args.filters,
# Actions
......
......@@ -7,9 +7,9 @@
{% elif icon_type == 'fa' %}
<i class="fa fa-{{ icon_value }}"></i>
{% elif icon_type == 'image' %}
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image"></img>
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image">
{% elif icon_type == 'image-url' %}
<img src="item.icon_value" alt="menu image"></img>
<img src="item.icon_value" alt="menu image">
{% endif %}
{% endif %}
{%- endmacro %}
......
......@@ -76,6 +76,31 @@
{% endif %}
{%- endmacro %}
{% macro simple_pager(page, have_next, generator) -%}
<div class="pagination">
<ul>
{% if page > 0 %}
<li>
<a href="{{ generator(page - 1) }}">&lt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(0) }}">&lt;</a>
</li>
{% endif %}
{% if have_next %}
<li>
<a href="{{ generator(page + 1) }}">&gt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(page) }}">&gt;</a>
</li>
{% endif %}
</ul>
</div>
{%- endmacro %}
{# ---------------------- Forms -------------------------- #}
{% macro render_field(form, field, kwargs={}, caller=None) %}
{% set direct_error = h.is_field_error(field.errors) %}
......
......@@ -13,7 +13,7 @@
{% block model_menu_bar %}
<ul class="nav nav-tabs actions-nav">
<li class="active">
<a href="javascript:void(0)">{{ _gettext('List') }} ({{ count }})</a>
<a href="javascript:void(0)">{{ _gettext('List') }}{% if count %} ({{ count }}){% endif %}</a>
</li>
{% if admin_view.can_create %}
<li>
......@@ -110,7 +110,11 @@
<form class="icon" method="POST" action="{{ get_url('.delete_view') }}">
{{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.url(value=return_url) }}
{% if delete_form.csrf_token %}
{{ delete_form.csrf_token }}
{% elif csrf_token %}
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>
{% endif %}
<button onclick="return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');" title="{{ _gettext('Delete record') }}">
<i class="fa fa-trash icon-trash"></i>
</button>
......@@ -147,7 +151,13 @@
</tr>
{% endfor %}
</table>
{% block list_pager %}
{% if num_pages is not none %}
{{ lib.pager(page, num_pages, pager_url) }}
{% else %}
{{ lib.simple_pager(page, data|length == page_size, pager_url) }}
{% endif %}
{% endblock %}
{% endblock %}
{{ actionlib.form(actions, get_url('.action_view')) }}
......
......@@ -7,9 +7,9 @@
{% elif icon_type == 'fa' %}
<i class="fa {{ icon_value }}"></i>
{% elif icon_type == 'image' %}
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image"></img>
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image">
{% elif icon_type == 'image-url' %}
<img src="item.icon_value" alt="menu image"></img>
<img src="item.icon_value" alt="menu image">
{% endif %}
{% endif %}
{%- endmacro %}
......
......@@ -74,6 +74,29 @@
{% endif %}
{%- endmacro %}
{% macro simple_pager(page, have_next, generator) -%}
<ul class="pagination">
{% if page > 0 %}
<li>
<a href="{{ generator(page - 1) }}">&lt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(0) }}">&lt;</a>
</li>
{% endif %}
{% if have_next %}
<li>
<a href="{{ generator(page + 1) }}">&gt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(page) }}">&gt;</a>
</li>
{% endif %}
</ul>
{%- endmacro %}
{# ---------------------- Forms -------------------------- #}
{% macro render_field(form, field, kwargs={}, caller=None) %}
{% set direct_error = h.is_field_error(field.errors) %}
......
......@@ -13,7 +13,7 @@
{% block model_menu_bar %}
<ul class="nav nav-tabs actions-nav">
<li class="active">
<a href="javascript:void(0)">{{ _gettext('List') }} ({{ count }})</a>
<a href="javascript:void(0)">{{ _gettext('List') }}{% if count %} ({{ count }}){% endif %}</a>
</li>
{% if admin_view.can_create %}
<li>
......@@ -110,7 +110,11 @@
<form class="icon" method="POST" action="{{ get_url('.delete_view') }}">
{{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.url(value=return_url) }}
{% if delete_form.csrf_token %}
{{ delete_form.csrf_token }}
{% elif csrf_token %}
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>
{% endif %}
<button onclick="return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');" title="Delete record">
<span class="fa fa-trash glyphicon glyphicon-trash"></span>
</button>
......@@ -146,7 +150,13 @@
</tr>
{% endfor %}
</table>
{% block list_pager %}
{% if num_pages is not none %}
{{ lib.pager(page, num_pages, pager_url) }}
{% else %}
{{ lib.simple_pager(page, data|length == page_size, pager_url) }}
{% endif %}
{% endblock %}
{% endblock %}
{{ actionlib.form(actions, get_url('.action_view')) }}
......
......@@ -948,3 +948,20 @@ def test_form_args_embeddeddoc():
eq_(form.timestamp.label.text, 'Last Updated Time')
# This is the failure
eq_(form.info.label.text, 'Information')
def test_simple_list_pager():
app, db, admin = setup()
Model1, _ = create_models(db)
class TestModelView(CustomModelView):
simple_list_pager = True
def get_count_query(self):
assert False
view = TestModelView(Model1)
admin.add_view(view)
count, data = view.get_list(0, None, None, None, None)
ok_(count is None)
......@@ -12,6 +12,7 @@ from . import setup
from datetime import datetime, time, date
class CustomModelView(ModelView):
def __init__(self, model, session,
name=None, category=None, endpoint=None, url=None,
......@@ -259,10 +260,11 @@ def test_column_searchable_list():
eq_(view._search_supported, True)
eq_(len(view._search_fields), 2)
ok_(isinstance(view._search_fields[0], db.Column))
ok_(isinstance(view._search_fields[1], db.Column))
eq_(view._search_fields[0].name, 'string_field')
eq_(view._search_fields[1].name, 'int_field')
ok_(isinstance(view._search_fields[0][0], db.Column))
ok_(isinstance(view._search_fields[1][0], db.Column))
eq_(view._search_fields[0][0].name, 'string_field')
eq_(view._search_fields[1][0].name, 'int_field')
db.session.add(Model2('model1-test', 5000))
db.session.add(Model2('model2-test', 9000))
......@@ -417,6 +419,8 @@ def test_column_filters():
)
admin.add_view(view)
client = app.test_client()
eq_(len(view._filters), 7)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']],
......@@ -515,21 +519,23 @@ def test_column_filters():
fill_db(db, Model1, Model2)
client = app.test_client()
# Test equals
rv = client.get('/admin/model1/?flt0_0=test1_val_1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
# the filter value is always in "data"
# need to check a different column than test1 for the expected row
ok_('test2_val_1' in data)
ok_('test1_val_2' not in data)
# Test NOT IN filter
rv = client.get('/admin/model1/?flt0_6=test1_val_1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test2_val_1' not in data)
ok_('test1_val_2' in data)
ok_('test2_val_1' not in data)
# Test string filter
view = CustomModelView(Model1, db.session,
......@@ -1103,9 +1109,11 @@ def test_column_filters():
rv = client.get('/admin/_relation_test/?flt1_0=test1_val_1')
data = rv.data.decode('utf-8')
ok_('test1_val_1' in data)
ok_('test1_val_2' not in data)
def test_url_args():
app, db, admin = setup()
......@@ -1680,3 +1688,123 @@ def test_safe_redirect():
assert_true(rv.location.startswith('http://localhost/admin/model1/edit/'))
assert_true('url=%2Fadmin%2Fmodel1%2F' in rv.location)
assert_true('id=2' in rv.location)
def test_simple_list_pager():
app, db, admin = setup()
Model1, _ = create_models(db)
db.create_all()
class TestModelView(CustomModelView):
simple_list_pager = True
def get_count_query(self):
assert False
view = TestModelView(Model1, db.session)
admin.add_view(view)
count, data = view.get_list(0, None, None, None, None)
assert_true(count is None)
def test_advanced_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
model1 = db.relationship(Model1, backref='model2')
class Model3(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model2_id = db.Column(db.Integer, db.ForeignKey(Model2.id))
model2 = db.relationship(Model2, backref='model3')
view1 = CustomModelView(Model1, db.session)
admin.add_view(view1)
view2 = CustomModelView(Model2, db.session)
admin.add_view(view2)
view3 = CustomModelView(Model3, db.session)
admin.add_view(view3)
# Test joins
attr, path = view2._get_field_with_path('model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model2.model1])
attr, path = view1._get_field_with_path('model2.val2')
eq_(attr, Model2.val2)
eq_(id(path[0]), id(Model1.model2))
attr, path = view3._get_field_with_path('model2.model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model3.model2, Model2.model1])
# Test how joins are applied
query = view3.get_query()
joins = {}
q1, joins, alias = view3._apply_path_joins(query, joins, path)
ok_((True, Model3.model2) in joins)
ok_((True, Model2.model1) in joins)
ok_(alias is not None)
# Check if another join would use same path
attr, path = view2._get_field_with_path('model1.test')
q2, joins, alias = view2._apply_path_joins(query, joins, path)
eq_(len(joins), 2)
for p in q2._join_entities:
ok_(p in q1._join_entities)
ok_(alias is not None)
# Check if normal properties are supported by _get_field_with_path
attr, path = view2._get_field_with_path(Model1.test)
eq_(attr, Model1.test)
eq_(path, [Model1.__table__])
q3, joins, alias = view2._apply_path_joins(view2.get_query(), joins, path)
eq_(len(joins), 3)
ok_(alias is None)
def test_multipath_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
first_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
first = db.relationship(Model1, backref='first', foreign_keys=[first_id])
second_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
second = db.relationship(Model1, backref='second', foreign_keys=[second_id])
db.create_all()
view = CustomModelView(Model2, db.session, filters=['first.test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model2/')
eq_(rv.status_code, 200)
......@@ -76,7 +76,7 @@ def test_baseview_defaults():
view = MockView()
eq_(view.name, None)
eq_(view.category, None)
eq_(view.endpoint, None)
eq_(view.endpoint, 'mockview')
eq_(view.url, None)
eq_(view.static_folder, None)
eq_(view.admin, None)
......@@ -388,3 +388,12 @@ def test_menu_links():
def check_class_name():
view = MockView()
eq_(view.name, 'Mock View')
def check_endpoint():
class CustomView(MockView):
def _get_endpoint(self, endpoint):
return 'admin.' + super(CustomView, self)._get_endpoint(endpoint)
view = CustomView()
eq_(view.endpoint, 'admin.customview')
......@@ -40,6 +40,9 @@ def test_upload_field():
class TestForm(form.BaseForm):
upload = form.FileUploadField('Upload', base_path=path)
class TestNoOverWriteForm(form.BaseForm):
upload = form.FileUploadField('Upload', base_path=path, allow_overwrite=False)
class Dummy(object):
pass
......@@ -74,6 +77,7 @@ def test_upload_field():
# Check delete
with app.test_request_context(method='POST', data={'_upload-delete': 'checked'}):
my_form = TestForm(helpers.get_form_data())
ok_(my_form.validate())
......@@ -83,6 +87,24 @@ def test_upload_field():
ok_(not op.exists(op.join(path, 'test2.txt')))
# Check overwrite
_remove_testfiles()
my_form_ow = TestNoOverWriteForm()
with app.test_request_context(method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
my_form_ow = TestNoOverWriteForm(helpers.get_form_data())
ok_(my_form_ow.validate())
my_form_ow.populate_obj(dummy)
eq_(dummy.upload, 'test1.txt')
ok_(op.exists(op.join(path, 'test1.txt')))
with app.test_request_context(method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
my_form_ow = TestNoOverWriteForm(helpers.get_form_data())
ok_(not my_form_ow.validate())
_remove_testfiles()
def test_image_upload_field():
app = Flask(__name__)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment