Commit 207d23fd authored by Petrus J.v.Rensburg's avatar Petrus J.v.Rensburg

Merge branch 'master' into examples

parents 2d3d1d63 6b6fe519
Changelog Changelog
========= =========
1.2.0
-----
* Codebase was migrated to Flask-Admin GitHub organization
* Automatically inject Flask-WTF CSRF token to internal Flask-Admin forms
* MapBox v4 support for GeoAlchemy
* Updated translations with help of CrowdIn
* Show warning if field was ignored in form rendering rules
* Simple AppEngine backend
* Optional support for Font Awesome in templates and menus
* Bug fixes
1.1.0 1.1.0
----- -----
...@@ -43,21 +55,3 @@ Highlights: ...@@ -43,21 +55,3 @@ Highlights:
* Support for newer wtforms versions * Support for newer wtforms versions
* `form_rules` property that affects both create and edit forms * `form_rules` property that affects both create and edit forms
* Lots of bugfixes * Lots of bugfixes
1.0.7
-----
Full change log and feature walkthrough can be found `here <http://mrjoes.github.io/2013/10/21/flask-admin-107.html>`_.
Highlights:
* Python 3 support
* AJAX-based foreign-key data loading for all backends
* New, optional, rule-based form rendering engine
* MongoEngine fixes and features: GridFS support, nested subdocument configuration and much more
* Greatly improved and more configurable inline models
* New WTForms fields and widgets
* `form_extra_columns` allows adding custom columns to the form declaratively
* Redis cli
* SQLAlchemy backend can handle inherited models with multiple PKs
* Lots of bug fixes
...@@ -101,7 +101,8 @@ class FileView(sqla.ModelView): ...@@ -101,7 +101,8 @@ class FileView(sqla.ModelView):
form_args = { form_args = {
'path': { 'path': {
'label': 'File', 'label': 'File',
'base_path': file_path 'base_path': file_path,
'allow_overwrite': False
} }
} }
......
...@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView): ...@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView):
# List of columns that can be sorted. For 'user' column, use User.username as # List of columns that can be sorted. For 'user' column, use User.username as
# a column. # a column.
column_sortable_list = ('title', ('user', User.username), 'date') column_sortable_list = ('title', ('user', 'user.username'), 'date')
# Rename 'title' columns to 'Post Title' in list view # Rename 'title' columns to 'Post Title' in list view
column_labels = dict(title='Post Title') column_labels = dict(title='Post Title')
......
__version__ = '1.1.1-dev' __version__ = '1.2.0'
__author__ = 'Serge S. Koval' __author__ = 'Serge S. Koval'
__email__ = 'serge.koval+github@gmail.com' __email__ = 'serge.koval+github@gmail.com'
......
...@@ -188,7 +188,7 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)): ...@@ -188,7 +188,7 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
""" """
self.name = name self.name = name
self.category = category self.category = category
self.endpoint = endpoint self.endpoint = self._get_endpoint(endpoint)
self.url = url self.url = url
self.static_folder = static_folder self.static_folder = static_folder
self.static_url_path = static_url_path self.static_url_path = static_url_path
...@@ -206,6 +206,16 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)): ...@@ -206,6 +206,16 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
if self._default_view is None: if self._default_view is None:
raise Exception(u'Attempted to instantiate admin view %s without default view' % self.__class__.__name__) raise Exception(u'Attempted to instantiate admin view %s without default view' % self.__class__.__name__)
def _get_endpoint(self, endpoint):
"""
Generate Flask endpoint name. By default converts class name to lower case if endpoint is
not explicitly provided.
"""
if endpoint:
return endpoint
return self.__class__.__name__.lower()
def create_blueprint(self, admin): def create_blueprint(self, admin):
""" """
Create Flask blueprint. Create Flask blueprint.
...@@ -213,10 +223,6 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)): ...@@ -213,10 +223,6 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
# Store admin instance # Store admin instance
self.admin = admin self.admin = admin
# If endpoint name is not provided, get it from the class name
if self.endpoint is None:
self.endpoint = self.__class__.__name__.lower()
# If the static_url_path is not provided, use the admin's # If the static_url_path is not provided, use the admin's
if not self.static_url_path: if not self.static_url_path:
self.static_url_path = admin.static_url_path self.static_url_path = admin.static_url_path
...@@ -234,15 +240,13 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)): ...@@ -234,15 +240,13 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
if not self.url.startswith('/'): if not self.url.startswith('/'):
self.url = '%s/%s' % (self.admin.url, self.url) self.url = '%s/%s' % (self.admin.url, self.url)
# If we're working from the root of the site, set prefix to None # If we're working from the root of the site, set prefix to None
if self.url == '/': if self.url == '/':
self.url = None self.url = None
# prevent admin static files from conflicting with flask static files # prevent admin static files from conflicting with flask static files
if not self.static_url_path: if not self.static_url_path:
self.static_folder='static' self.static_folder = 'static'
self.static_url_path='/static/admin' self.static_url_path = '/static/admin'
# If name is not povided, use capitalized endpoint name # If name is not povided, use capitalized endpoint name
if self.name is None: if self.name is None:
......
...@@ -484,7 +484,7 @@ class ModelView(BaseModelView): ...@@ -484,7 +484,7 @@ class ModelView(BaseModelView):
query = self._search(query, search) query = self._search(query, search)
# Get count # Get count
count = query.count() count = query.count() if not self.simple_list_pager else None
# Sorting # Sorting
if sort_column: if sort_column:
...@@ -592,7 +592,7 @@ class ModelView(BaseModelView): ...@@ -592,7 +592,7 @@ class ModelView(BaseModelView):
return False return False
else: else:
self.after_model_delete(model) self.after_model_delete(model)
return True return True
......
...@@ -339,7 +339,7 @@ class ModelView(BaseModelView): ...@@ -339,7 +339,7 @@ class ModelView(BaseModelView):
query = f.apply(query, f.clean(value)) query = f.apply(query, f.clean(value))
# Get count # Get count
count = query.count() count = query.count() if not self.simple_list_pager else None
# Apply sorting # Apply sorting
if sort_column is not None: if sort_column is not None:
...@@ -417,7 +417,7 @@ class ModelView(BaseModelView): ...@@ -417,7 +417,7 @@ class ModelView(BaseModelView):
return False return False
else: else:
self.after_model_delete(model) self.after_model_delete(model)
return True return True
# Default model actions # Default model actions
...@@ -443,6 +443,7 @@ class ModelView(BaseModelView): ...@@ -443,6 +443,7 @@ class ModelView(BaseModelView):
query = self.model.select().filter(model_pk << ids) query = self.model.select().filter(model_pk << ids)
for m in query: for m in query:
self.on_model_delete(m)
m.delete_instance(recursive=True) m.delete_instance(recursive=True)
count += 1 count += 1
......
...@@ -222,7 +222,7 @@ class ModelView(BaseModelView): ...@@ -222,7 +222,7 @@ class ModelView(BaseModelView):
query = self._search(query, search) query = self._search(query, search)
# Get count # Get count
count = self.coll.find(query).count() count = self.coll.find(query).count() if not self.simple_list_pager else None
# Sorting # Sorting
sort_by = None sort_by = None
...@@ -337,7 +337,7 @@ class ModelView(BaseModelView): ...@@ -337,7 +337,7 @@ class ModelView(BaseModelView):
return False return False
else: else:
self.after_model_delete(model) self.after_model_delete(model)
return True return True
# Default model actions # Default model actions
......
import warnings
import time
import datetime
from flask_admin.babel import lazy_gettext from flask_admin.babel import lazy_gettext
from flask_admin.model import filters from flask_admin.model import filters
from flask_admin.contrib.sqla import tools from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_ from sqlalchemy.sql import not_, or_
class BaseSQLAFilter(filters.BaseFilter): class BaseSQLAFilter(filters.BaseFilter):
""" """
Base SQLAlchemy filter. Base SQLAlchemy filter.
...@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter): ...@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter):
self.column = column self.column = column
def get_column(self, alias):
return self.column if alias is None else getattr(alias, self.column.key)
def apply(self, query, value, alias=None):
return super(self, BaseSQLAFilter).apply(query, value)
# Common filters # Common filters
class FilterEqual(BaseSQLAFilter): class FilterEqual(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column == value) return query.filter(self.get_column(alias) == value)
def operation(self): def operation(self):
return lazy_gettext('equals') return lazy_gettext('equals')
class FilterNotEqual(BaseSQLAFilter): class FilterNotEqual(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column != value) return query.filter(self.get_column(alias) != value)
def operation(self): def operation(self):
return lazy_gettext('not equal') return lazy_gettext('not equal')
class FilterLike(BaseSQLAFilter): class FilterLike(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value) stmt = tools.parse_like_term(value)
return query.filter(self.column.ilike(stmt)) return query.filter(self.get_column(alias).ilike(stmt))
def operation(self): def operation(self):
return lazy_gettext('contains') return lazy_gettext('contains')
class FilterNotLike(BaseSQLAFilter): class FilterNotLike(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value) stmt = tools.parse_like_term(value)
return query.filter(~self.column.ilike(stmt)) return query.filter(~self.get_column(alias).ilike(stmt))
def operation(self): def operation(self):
return lazy_gettext('not contains') return lazy_gettext('not contains')
class FilterGreater(BaseSQLAFilter): class FilterGreater(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column > value) return query.filter(self.get_column(alias) > value)
def operation(self): def operation(self):
return lazy_gettext('greater than') return lazy_gettext('greater than')
class FilterSmaller(BaseSQLAFilter): class FilterSmaller(BaseSQLAFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column < value) return query.filter(self.get_column(alias) < value)
def operation(self): def operation(self):
return lazy_gettext('smaller than') return lazy_gettext('smaller than')
class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter): class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
if value == '1': if value == '1':
return query.filter(self.column == None) return query.filter(self.get_column(alias) == None)
else: else:
return query.filter(self.column != None) return query.filter(self.get_column(alias) != None)
def operation(self): def operation(self):
return lazy_gettext('empty') return lazy_gettext('empty')
...@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter): ...@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter):
def clean(self, value): def clean(self, value):
return [v.strip() for v in value.split(',') if v.strip()] return [v.strip() for v in value.split(',') if v.strip()]
def apply(self, query, value): def apply(self, query, value, alias=None):
return query.filter(self.column.in_(value)) return query.filter(self.get_column(alias).in_(value))
def operation(self): def operation(self):
return lazy_gettext('in list') return lazy_gettext('in list')
class FilterNotInList(FilterInList): class FilterNotInList(FilterInList):
def apply(self, query, value): def apply(self, query, value, alias=None):
# NOT IN can exclude NULL values, so "or_ == None" needed to be added # NOT IN can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(~self.column.in_(value), self.column == None)) column = self.get_column(alias)
return query.filter(or_(~column.in_(value), column == None))
def operation(self): def operation(self):
return lazy_gettext('not in list') return lazy_gettext('not in list')
...@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter): ...@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
options, options,
data_type='daterangepicker') data_type='daterangepicker')
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(self.column.between(start, end)) return query.filter(self.get_column(alias).between(start, end))
class DateNotBetweenFilter(DateBetweenFilter): class DateNotBetweenFilter(DateBetweenFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
# ~between() isn't possible until sqlalchemy 1.0.0 # ~between() isn't possible until sqlalchemy 1.0.0
return query.filter(not_(self.column.between(start, end))) return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self): def operation(self):
return lazy_gettext('not between') return lazy_gettext('not between')
...@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter): ...@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
options, options,
data_type='datetimerangepicker') data_type='datetimerangepicker')
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(self.column.between(start, end)) return query.filter(self.get_column(alias).between(start, end))
class DateTimeNotBetweenFilter(DateTimeBetweenFilter): class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(not_(self.column.between(start, end))) return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self): def operation(self):
return lazy_gettext('not between') return lazy_gettext('not between')
...@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter): ...@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
options, options,
data_type='timerangepicker') data_type='timerangepicker')
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(self.column.between(start, end)) return query.filter(self.get_column(alias).between(start, end))
class TimeNotBetweenFilter(TimeBetweenFilter): class TimeNotBetweenFilter(TimeBetweenFilter):
def apply(self, query, value): def apply(self, query, value, alias=None):
start, end = value start, end = value
return query.filter(not_(self.column.between(start, end))) return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self): def operation(self):
return lazy_gettext('not between') return lazy_gettext('not between')
......
import logging import logging
import warnings import warnings
import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload, aliased
from sqlalchemy.sql.expression import desc from sqlalchemy.sql.expression import desc
from sqlalchemy import Column, Boolean, func, or_ from sqlalchemy import Boolean, Table, func, or_
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from flask import flash from flask import flash
...@@ -276,7 +277,6 @@ class ModelView(BaseModelView): ...@@ -276,7 +277,6 @@ class ModelView(BaseModelView):
self.session = session self.session = session
self._search_fields = None self._search_fields = None
self._search_joins = []
self._filter_joins = dict() self._filter_joins = dict()
...@@ -322,43 +322,92 @@ class ModelView(BaseModelView): ...@@ -322,43 +322,92 @@ class ModelView(BaseModelView):
return field.property.columns return field.property.columns
def _get_field_with_path(self, name): def _get_field_with_path(self, name):
join_tables = [] """
Resolve property by name and figure out its join path.
if isinstance(name, string_types): Join path might contain both properties and tables.
model = self.model """
path = []
model = self.model
# For strings, resolve path
if isinstance(name, string_types):
for attribute in name.split('.'): for attribute in name.split('.'):
value = getattr(model, attribute) value = getattr(model, attribute)
if (hasattr(value, 'property') and if (hasattr(value, 'property') and
hasattr(value.property, 'direction')): hasattr(value.property, 'direction')):
model = value.property.mapper.class_ model = value.property.mapper.class_
table = model.__table__ table = model.__table__
if self._need_join(table): if self._need_join(table):
join_tables.append(table) path.append(value)
attr = value attr = value
else: else:
attr = name attr = name
# determine joins if Table.column (relation object) is given # Determine joins if table.column (relation object) is provided
if isinstance(name, InstrumentedAttribute): if isinstance(attr, InstrumentedAttribute):
columns = self._get_columns_for_field(name) columns = self._get_columns_for_field(attr)
if len(columns) > 1: if len(columns) > 1:
raise Exception('Can only handle one column for %s' % name) raise Exception('Can only handle one column for %s' % name)
column = columns[0] column = columns[0]
# TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
if self._need_join(column.table): if self._need_join(column.table):
join_tables.append(column.table) path.append(column.table)
return join_tables, attr return attr, path
def _need_join(self, table): def _need_join(self, table):
"""
Check if join to a table is necessary.
"""
return table not in self.model._sa_class_manager.mapper.tables return table not in self.model._sa_class_manager.mapper.tables
def _apply_path_joins(self, query, joins, path, inner_join=True):
"""
Apply join path to the query.
:param query:
Query to add joins to
:param joins:
List of current joins. Used to avoid joining on same relationship more than once
:param path:
Path to be joined
:param fn:
Join function
"""
last = None
if path:
for item in path:
key = (inner_join, item)
alias = joins.get(key)
if key not in joins:
if not isinstance(item, Table):
alias = aliased(item.property.mapper.class_)
fn = query.join if inner_join else query.outerjoin
if last is None:
query = fn(item) if alias is None else fn(alias, item)
else:
prop = getattr(last, item.key)
query = fn(prop) if alias is None else fn(alias, prop)
joins[key] = alias
last = alias
return query, joins, last
# Scaffolding # Scaffolding
def scaffold_pk(self): def scaffold_pk(self):
""" """
...@@ -453,19 +502,19 @@ class ModelView(BaseModelView): ...@@ -453,19 +502,19 @@ class ModelView(BaseModelView):
for c in self.column_sortable_list: for c in self.column_sortable_list:
if isinstance(c, tuple): if isinstance(c, tuple):
join_tables, column = self._get_field_with_path(c[1]) column, path = self._get_field_with_path(c[1])
column_name = c[0] column_name = c[0]
elif isinstance(c, InstrumentedAttribute): elif isinstance(c, InstrumentedAttribute):
join_tables, column = self._get_field_with_path(c) column, path = self._get_field_with_path(c)
column_name = str(c) column_name = str(c)
else: else:
join_tables, column = self._get_field_with_path(c) column, path = self._get_field_with_path(c)
column_name = c column_name = c
result[column_name] = column result[column_name] = column
if join_tables: if path:
self._sortable_joins[column_name] = join_tables self._sortable_joins[column_name] = path
return result return result
...@@ -479,26 +528,15 @@ class ModelView(BaseModelView): ...@@ -479,26 +528,15 @@ class ModelView(BaseModelView):
""" """
if self.column_searchable_list: if self.column_searchable_list:
self._search_fields = [] self._search_fields = []
self._search_joins = []
joins = set()
for p in self.column_searchable_list: for p in self.column_searchable_list:
join_tables, attr = self._get_field_with_path(p) attr, joins = self._get_field_with_path(p)
if not attr: if not attr:
raise Exception('Failed to find field for search field: %s' % p) raise Exception('Failed to find field for search field: %s' % p)
for column in self._get_columns_for_field(attr): for column in self._get_columns_for_field(attr):
column_type = type(column.type).__name__ self._search_fields.append((column, joins))
self._search_fields.append(column)
# Store joins, avoid duplicates
for table in join_tables:
if table.name not in joins:
self._search_joins.append(table)
joins.add(table.name)
return bool(self.column_searchable_list) return bool(self.column_searchable_list)
...@@ -507,7 +545,7 @@ class ModelView(BaseModelView): ...@@ -507,7 +545,7 @@ class ModelView(BaseModelView):
Return list of enabled filters Return list of enabled filters
""" """
join_tables, attr = self._get_field_with_path(name) attr, joins = self._get_field_with_path(name)
if attr is None: if attr is None:
raise Exception('Failed to find field for filter: %s' % name) raise Exception('Failed to find field for filter: %s' % name)
...@@ -535,10 +573,11 @@ class ModelView(BaseModelView): ...@@ -535,10 +573,11 @@ class ModelView(BaseModelView):
if flt: if flt:
table = column.table table = column.table
if join_tables: if joins:
self._filter_joins[table.name] = join_tables self._filter_joins[column] = joins
elif self._need_join(table): elif self._need_join(table):
self._filter_joins[table.name] = [table] self._filter_joins[column] = [table]
filters.extend(flt) filters.extend(flt)
return filters return filters
...@@ -563,9 +602,6 @@ class ModelView(BaseModelView): ...@@ -563,9 +602,6 @@ class ModelView(BaseModelView):
type_name = type(column.type).__name__ type_name = type(column.type).__name__
if join_tables:
self._filter_joins[column.table.name] = join_tables
flt = self.filter_converter.convert( flt = self.filter_converter.convert(
type_name, type_name,
column, column,
...@@ -573,8 +609,10 @@ class ModelView(BaseModelView): ...@@ -573,8 +609,10 @@ class ModelView(BaseModelView):
options=self.column_choices.get(name), options=self.column_choices.get(name),
) )
if flt and not join_tables and self._need_join(column.table): if joins:
self._filter_joins[column.table.name] = [column.table] self._filter_joins[column] = joins
elif self._need_join(column.table):
self._filter_joins[column] = [column.table]
return flt return flt
...@@ -583,7 +621,7 @@ class ModelView(BaseModelView): ...@@ -583,7 +621,7 @@ class ModelView(BaseModelView):
column = filter.column column = filter.column
if self._need_join(column.table): if self._need_join(column.table):
self._filter_joins[column.table.name] = [column.table] self._filter_joins[column] = [column.table]
return filter return filter
...@@ -707,27 +745,25 @@ class ModelView(BaseModelView): ...@@ -707,27 +745,25 @@ class ModelView(BaseModelView):
:param query: :param query:
Query Query
:param joins: :pram joins:
Joins set Current joins
:param sort_joins:
Sort joins (properties or tables)
:param sort_field: :param sort_field:
Sort field Sort field
:param sort_desc: :param sort_desc:
Ascending or descending Ascending or descending
""" """
# TODO: Preprocessing for joins if sort_field is not None:
# Handle joins # Handle joins
if sort_joins: query, joins, alias = self._apply_path_joins(query, joins, sort_joins, inner_join=False)
for table in sort_joins:
if table.name not in joins:
query = query.outerjoin(table)
joins.add(table.name) column = sort_field if alias is None else getattr(alias, sort_field.key)
if sort_field is not None:
if sort_desc: if sort_desc:
query = query.order_by(desc(sort_field)) query = query.order_by(desc(column))
else: else:
query = query.order_by(sort_field) query = query.order_by(column)
return query, joins return query, joins
...@@ -737,12 +773,112 @@ class ModelView(BaseModelView): ...@@ -737,12 +773,112 @@ class ModelView(BaseModelView):
if order is not None: if order is not None:
field, direction = order field, direction = order
join_tables, attr = self._get_field_with_path(field) attr, joins = self._get_field_with_path(field)
return join_tables, attr, direction return attr, joins, direction
return None return None
def _apply_sorting(self, query, joins, sort_column, sort_desc):
if sort_column is not None:
if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_field, sort_joins, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
return query, joins
def _apply_search(self, query, count_query, joins, count_joins, search):
"""
Apply search to a query.
"""
terms = search.split(' ')
for term in terms:
if not term:
continue
stmt = tools.parse_like_term(term)
filter_stmt = []
count_filter_stmt = []
for field, path in self._search_fields:
query, joins, alias = self._apply_path_joins(query, joins, path, inner_join=False)
count_alias = None
if count_query is not None:
count_query, count_joins, count_alias = self._apply_path_joins(count_query,
count_joins,
path,
inner_join=False)
column = field if alias is None else getattr(alias, field.key)
filter_stmt.append(column.ilike(stmt))
if count_filter_stmt is not None:
column = field if count_alias is None else getattr(count_alias, field.key)
count_filter_stmt.append(column.ilike(stmt))
query = query.filter(or_(*filter_stmt))
if count_query is not None:
count_query = count_query.filter(or_(*count_filter_stmt))
return query, count_query, joins, count_joins
def _apply_filters(self, query, count_query, joins, count_joins, filters):
for idx, flt_name, value in filters:
flt = self._filters[idx]
alias = None
count_alias = None
# Figure out joins
if isinstance(flt, sqla_filters.BaseSQLAFilter):
path = self._filter_joins.get(flt.column, [])
query, joins, alias = self._apply_path_joins(query, joins, path, inner_join=False)
if count_query is not None:
count_query, count_joins, count_alias = self._apply_path_joins(
count_query,
count_joins,
path,
inner_join=False)
# Clean value .clean() and apply the filter
clean_value = flt.clean(value)
try:
query = flt.apply(query, clean_value, alias)
except TypeError:
spec = inspect.getargspec(flt.apply)
if len(spec.args) == 2:
warnings.warn('Please update your custom filter %s to include additional `alias` parameter.' % repr(flt))
else:
raise
query = flt.apply(query, clean_value)
if count_query is not None:
try:
count_query = flt.apply(count_query, clean_value, count_alias)
except TypeError:
count_query = flt.apply(count_query, clean_value)
return query, count_query, joins, count_joins
def get_list(self, page, sort_column, sort_desc, search, filters, execute=True): def get_list(self, page, sort_column, sort_desc, search, filters, execute=True):
""" """
Return models from the database. Return models from the database.
...@@ -761,84 +897,45 @@ class ModelView(BaseModelView): ...@@ -761,84 +897,45 @@ class ModelView(BaseModelView):
List of filter tuples List of filter tuples
""" """
# Will contain names of joined tables to avoid duplicate joins # Will contain join paths with optional aliased object
joins = set() joins = {}
count_joins = {}
query = self.get_query() query = self.get_query()
count_query = self.get_count_query() count_query = self.get_count_query() if not self.simple_list_pager else None
# Ignore eager-loaded relations (prevent unnecessary joins) # Ignore eager-loaded relations (prevent unnecessary joins)
# TODO: Separate join detection for query and count query? # TODO: Separate join detection for query and count query?
if hasattr(query, '_join_entities'): if hasattr(query, '_join_entities'):
for entity in query._join_entities: for entity in query._join_entities:
for table in entity.tables: for table in entity.tables:
joins.add(table.name) joins[table] = None
# Apply search criteria # Apply search criteria
if self._search_supported and search: if self._search_supported and search:
# Apply search-related joins query, count_query, joins, count_joins = self._apply_search(query,
if self._search_joins: count_query,
for table in self._search_joins: joins,
if table.name not in joins: count_joins,
query = query.outerjoin(table) search)
count_query = count_query.outerjoin(table)
joins.add(table.name)
# Apply terms
terms = search.split(' ')
for term in terms:
if not term:
continue
stmt = tools.parse_like_term(term)
filter_stmt = [c.ilike(stmt) for c in self._search_fields]
query = query.filter(or_(*filter_stmt))
count_query = count_query.filter(or_(*filter_stmt))
# Apply filters # Apply filters
if filters and self._filters: if filters and self._filters:
for idx, flt_name, value in filters: query, count_query, joins, count_joins = self._apply_filters(query,
flt = self._filters[idx] count_query,
joins,
count_joins,
filters)
# Figure out joins # Calculate number of rows if necessary
if isinstance(flt, sqla_filters.BaseSQLAFilter): count = count_query.scalar() if count_query else None
tbl = flt.column.table.name
join_tables = self._filter_joins.get(tbl, [])
for table in join_tables:
if table.name not in joins:
query = query.join(table)
count_query = count_query.join(table)
joins.add(table.name)
# turn into python format with .clean() and apply filter
query = flt.apply(query, flt.clean(value))
count_query = flt.apply(count_query, flt.clean(value))
# Calculate number of rows
count = count_query.scalar()
# Auto join # Auto join
for j in self._auto_joins: for j in self._auto_joins:
query = query.options(joinedload(j)) query = query.options(joinedload(j))
# Sorting # Sorting
if sort_column is not None: query, joins = self._apply_sorting(query, joins, sort_column, sort_desc)
if sort_column in self._sortable_columns:
sort_field = self._sortable_columns[sort_column]
sort_joins = self._sortable_joins.get(sort_column)
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
else:
order = self._get_default_order()
if order:
sort_joins, sort_field, sort_desc = order
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
# Pagination # Pagination
if page is not None: if page is not None:
...@@ -944,7 +1041,7 @@ class ModelView(BaseModelView): ...@@ -944,7 +1041,7 @@ class ModelView(BaseModelView):
return False return False
else: else:
self.after_model_delete(model) self.after_model_delete(model)
return True return True
# Default model actions # Default model actions
......
...@@ -51,12 +51,21 @@ class FileUploadInput(object): ...@@ -51,12 +51,21 @@ class FileUploadInput(object):
template = self.data_template if field.data else self.empty_template template = self.data_template if field.data else self.empty_template
if field.errors:
template = self.empty_template
if field.data and isinstance(field.data, FileStorage):
value = field.data.filename
else:
value = field.data or ''
return HTMLString(template % { return HTMLString(template % {
'text': html_params(type='text', 'text': html_params(type='text',
readonly='readonly', readonly='readonly',
value=field.data, value=value,
name=field.name), name=field.name),
'file': html_params(type='file', 'file': html_params(type='file',
value=value,
**kwargs), **kwargs),
'marker': '_%s-delete' % field.name 'marker': '_%s-delete' % field.name
}) })
...@@ -122,7 +131,7 @@ class FileUploadField(fields.StringField): ...@@ -122,7 +131,7 @@ class FileUploadField(fields.StringField):
def __init__(self, label=None, validators=None, def __init__(self, label=None, validators=None,
base_path=None, relative_path=None, base_path=None, relative_path=None,
namegen=None, allowed_extensions=None, namegen=None, allowed_extensions=None,
permission=0o666, permission=0o666, allow_overwrite=True,
**kwargs): **kwargs):
""" """
Constructor. Constructor.
...@@ -154,6 +163,11 @@ class FileUploadField(fields.StringField): ...@@ -154,6 +163,11 @@ class FileUploadField(fields.StringField):
:param allowed_extensions: :param allowed_extensions:
List of allowed extensions. If not provided, will allow any file. List of allowed extensions. If not provided, will allow any file.
:param allow_overwrite:
Whether to overwrite existing files in upload directory. Defaults to `True`.
.. versionadded:: 1.1.1
The `allow_overwrite` parameter was added.
""" """
self.base_path = base_path self.base_path = base_path
self.relative_path = relative_path self.relative_path = relative_path
...@@ -161,6 +175,7 @@ class FileUploadField(fields.StringField): ...@@ -161,6 +175,7 @@ class FileUploadField(fields.StringField):
self.namegen = namegen or namegen_filename self.namegen = namegen or namegen_filename
self.allowed_extensions = allowed_extensions self.allowed_extensions = allowed_extensions
self.permission = permission self.permission = permission
self._allow_overwrite = allow_overwrite
self._should_delete = False self._should_delete = False
...@@ -188,6 +203,11 @@ class FileUploadField(fields.StringField): ...@@ -188,6 +203,11 @@ class FileUploadField(fields.StringField):
def pre_validate(self, form): def pre_validate(self, form):
if self._is_uploaded_file(self.data) and not self.is_file_allowed(self.data.filename): if self._is_uploaded_file(self.data) and not self.is_file_allowed(self.data.filename):
raise ValidationError(gettext('Invalid file extension')) raise ValidationError(gettext('Invalid file extension'))
# Handle overwriting existing content
if not self._is_uploaded_file(self.data):
return
if self._allow_overwrite == False and os.path.exists(self._get_path(self.data.filename)):
raise ValidationError(gettext('File "%s" already exists.' % self.data.filename))
def process(self, formdata, data=unset_value): def process(self, formdata, data=unset_value):
if formdata: if formdata:
...@@ -253,6 +273,9 @@ class FileUploadField(fields.StringField): ...@@ -253,6 +273,9 @@ class FileUploadField(fields.StringField):
if not op.exists(op.dirname(path)): if not op.exists(op.dirname(path)):
os.makedirs(os.path.dirname(path), self.permission | 0o111) os.makedirs(os.path.dirname(path), self.permission | 0o111)
if self._allow_overwrite == False and os.path.exists(path):
raise ValueError(gettext('File "%s" already exists.' % path))
data.save(path) data.save(path)
return filename return filename
......
...@@ -15,7 +15,7 @@ from flask_admin.form import BaseForm, FormOpts, rules ...@@ -15,7 +15,7 @@ from flask_admin.form import BaseForm, FormOpts, rules
from flask_admin.model import filters, typefmt from flask_admin.model import filters, typefmt
from flask_admin.actions import ActionsMixin from flask_admin.actions import ActionsMixin
from flask_admin.helpers import (get_form_data, validate_form_on_submit, from flask_admin.helpers import (get_form_data, validate_form_on_submit,
get_redirect_target, flash_errors) get_redirect_target, flash_errors)
from flask_admin.tools import rec_getattr from flask_admin.tools import rec_getattr
from flask_admin._backwards import ObsoleteAttr from flask_admin._backwards import ObsoleteAttr
from flask_admin._compat import iteritems, OrderedDict, as_unicode from flask_admin._compat import iteritems, OrderedDict, as_unicode
...@@ -321,6 +321,12 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -321,6 +321,12 @@ class BaseModelView(BaseView, ActionsMixin):
Controls if the primary key should be displayed in the list view. Controls if the primary key should be displayed in the list view.
""" """
simple_list_pager = False
"""
Enable or disable simple list pager.
If enabled, model interface would not run count query and will only show prev/next pager buttons.
"""
form = None form = None
""" """
Form class. Override if you want to use custom form for your model. Form class. Override if you want to use custom form for your model.
...@@ -563,28 +569,30 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -563,28 +569,30 @@ class BaseModelView(BaseView, ActionsMixin):
:param menu_icon_value: :param menu_icon_value:
Icon glyph name or URL, depending on `menu_icon_type` setting Icon glyph name or URL, depending on `menu_icon_type` setting
""" """
self.model = model
# If name not provided, it is model name # If name not provided, it is model name
if name is None: if name is None:
name = '%s' % self._prettify_class_name(model.__name__) name = '%s' % self._prettify_class_name(model.__name__)
# If endpoint not provided, it is model name
if endpoint is None:
endpoint = model.__name__.lower()
super(BaseModelView, self).__init__(name, category, endpoint, url, static_folder, super(BaseModelView, self).__init__(name, category, endpoint, url, static_folder,
menu_class_name=menu_class_name, menu_class_name=menu_class_name,
menu_icon_type=menu_icon_type, menu_icon_type=menu_icon_type,
menu_icon_value=menu_icon_value) menu_icon_value=menu_icon_value)
self.model = model
# Actions # Actions
self.init_actions() self.init_actions()
# Scaffolding # Scaffolding
self._refresh_cache() self._refresh_cache()
# Endpoint
def _get_endpoint(self, endpoint):
if endpoint:
return super(BaseModelView, self)._get_endpoint(endpoint)
return self.model.__name__.lower()
# Caching # Caching
def _refresh_forms_cache(self): def _refresh_forms_cache(self):
# Forms # Forms
...@@ -617,7 +625,7 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -617,7 +625,7 @@ class BaseModelView(BaseView, ActionsMixin):
self._filter_groups[flt.name].append({ self._filter_groups[flt.name].append({
'index': i, 'index': i,
'arg': self.get_filter_arg(i, flt), 'arg': self.get_filter_arg(i, flt),
'operation': as_unicode(flt.operation()), 'operation': flt.operation(),
'options': flt.get_options(self) or None, 'options': flt.get_options(self) or None,
'type': flt.data_type 'type': flt.data_type
}) })
...@@ -852,6 +860,27 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -852,6 +860,27 @@ class BaseModelView(BaseView, ActionsMixin):
else: else:
return str(index) return str(index)
def _get_filter_groups(self):
"""
Returns non-lazy version of filter strings
"""
if self._filter_groups:
results = OrderedDict()
for key, value in iteritems(self._filter_groups):
items = []
for item in value:
copy = dict(item)
copy['operation'] = as_unicode(copy['operation'])
items.append(copy)
results[key] = items
return results
return None
# Form helpers # Form helpers
def scaffold_form(self): def scaffold_form(self):
""" """
...@@ -1018,7 +1047,7 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1018,7 +1047,7 @@ class BaseModelView(BaseView, ActionsMixin):
missing_fields.append(field.name) missing_fields.append(field.name)
return missing_fields return missing_fields
def _show_missing_fields_warning(self, text): def _show_missing_fields_warning(self, text):
warnings.warn(text) warnings.warn(text)
...@@ -1200,7 +1229,7 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1200,7 +1229,7 @@ class BaseModelView(BaseView, ActionsMixin):
By default do nothing. By default do nothing.
""" """
pass pass
def after_model_delete(self, model): def after_model_delete(self, model):
""" """
Perform some actions after a model was deleted and Perform some actions after a model was deleted and
...@@ -1214,7 +1243,7 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1214,7 +1243,7 @@ class BaseModelView(BaseView, ActionsMixin):
:param model: :param model:
Model that was deleted Model that was deleted
""" """
pass pass
def on_form_prefill (self, form, id): def on_form_prefill (self, form, id):
""" """
...@@ -1463,9 +1492,12 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1463,9 +1492,12 @@ class BaseModelView(BaseView, ActionsMixin):
view_args.search, view_args.filters) view_args.search, view_args.filters)
# Calculate number of pages # Calculate number of pages
num_pages = count // self.page_size if count is not None:
if count % self.page_size != 0: num_pages = count // self.page_size
num_pages += 1 if count % self.page_size != 0:
num_pages += 1
else:
num_pages = None
# Various URL generation helpers # Various URL generation helpers
def pager_url(p): def pager_url(p):
...@@ -1508,6 +1540,7 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1508,6 +1540,7 @@ class BaseModelView(BaseView, ActionsMixin):
pager_url=pager_url, pager_url=pager_url,
num_pages=num_pages, num_pages=num_pages,
page=view_args.page, page=view_args.page,
page_size=self.page_size,
# Sorting # Sorting
sort_column=view_args.sort, sort_column=view_args.sort,
...@@ -1521,7 +1554,7 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1521,7 +1554,7 @@ class BaseModelView(BaseView, ActionsMixin):
# Filters # Filters
filters=self._filters, filters=self._filters,
filter_groups=self._filter_groups, filter_groups=self._get_filter_groups(),
active_filters=view_args.filters, active_filters=view_args.filters,
# Actions # Actions
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
{% elif icon_type == 'fa' %} {% elif icon_type == 'fa' %}
<i class="fa fa-{{ icon_value }}"></i> <i class="fa fa-{{ icon_value }}"></i>
{% elif icon_type == 'image' %} {% elif icon_type == 'image' %}
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image"></img> <img src="{{ url_for('static', filename=icon_value) }}" alt="menu image">
{% elif icon_type == 'image-url' %} {% elif icon_type == 'image-url' %}
<img src="item.icon_value" alt="menu image"></img> <img src="item.icon_value" alt="menu image">
{% endif %} {% endif %}
{% endif %} {% endif %}
{%- endmacro %} {%- endmacro %}
......
...@@ -76,6 +76,31 @@ ...@@ -76,6 +76,31 @@
{% endif %} {% endif %}
{%- endmacro %} {%- endmacro %}
{% macro simple_pager(page, have_next, generator) -%}
<div class="pagination">
<ul>
{% if page > 0 %}
<li>
<a href="{{ generator(page - 1) }}">&lt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(0) }}">&lt;</a>
</li>
{% endif %}
{% if have_next %}
<li>
<a href="{{ generator(page + 1) }}">&gt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(page) }}">&gt;</a>
</li>
{% endif %}
</ul>
</div>
{%- endmacro %}
{# ---------------------- Forms -------------------------- #} {# ---------------------- Forms -------------------------- #}
{% macro render_field(form, field, kwargs={}, caller=None) %} {% macro render_field(form, field, kwargs={}, caller=None) %}
{% set direct_error = h.is_field_error(field.errors) %} {% set direct_error = h.is_field_error(field.errors) %}
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
{% block model_menu_bar %} {% block model_menu_bar %}
<ul class="nav nav-tabs actions-nav"> <ul class="nav nav-tabs actions-nav">
<li class="active"> <li class="active">
<a href="javascript:void(0)">{{ _gettext('List') }} ({{ count }})</a> <a href="javascript:void(0)">{{ _gettext('List') }}{% if count %} ({{ count }}){% endif %}</a>
</li> </li>
{% if admin_view.can_create %} {% if admin_view.can_create %}
<li> <li>
...@@ -110,7 +110,11 @@ ...@@ -110,7 +110,11 @@
<form class="icon" method="POST" action="{{ get_url('.delete_view') }}"> <form class="icon" method="POST" action="{{ get_url('.delete_view') }}">
{{ delete_form.id(value=get_pk_value(row)) }} {{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.url(value=return_url) }} {{ delete_form.url(value=return_url) }}
{% if delete_form.csrf_token %}
{{ delete_form.csrf_token }} {{ delete_form.csrf_token }}
{% elif csrf_token %}
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>
{% endif %}
<button onclick="return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');" title="{{ _gettext('Delete record') }}"> <button onclick="return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');" title="{{ _gettext('Delete record') }}">
<i class="fa fa-trash icon-trash"></i> <i class="fa fa-trash icon-trash"></i>
</button> </button>
...@@ -147,7 +151,13 @@ ...@@ -147,7 +151,13 @@
</tr> </tr>
{% endfor %} {% endfor %}
</table> </table>
{% block list_pager %}
{% if num_pages is not none %}
{{ lib.pager(page, num_pages, pager_url) }} {{ lib.pager(page, num_pages, pager_url) }}
{% else %}
{{ lib.simple_pager(page, data|length == page_size, pager_url) }}
{% endif %}
{% endblock %}
{% endblock %} {% endblock %}
{{ actionlib.form(actions, get_url('.action_view')) }} {{ actionlib.form(actions, get_url('.action_view')) }}
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
{% elif icon_type == 'fa' %} {% elif icon_type == 'fa' %}
<i class="fa {{ icon_value }}"></i> <i class="fa {{ icon_value }}"></i>
{% elif icon_type == 'image' %} {% elif icon_type == 'image' %}
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image"></img> <img src="{{ url_for('static', filename=icon_value) }}" alt="menu image">
{% elif icon_type == 'image-url' %} {% elif icon_type == 'image-url' %}
<img src="item.icon_value" alt="menu image"></img> <img src="item.icon_value" alt="menu image">
{% endif %} {% endif %}
{% endif %} {% endif %}
{%- endmacro %} {%- endmacro %}
......
...@@ -74,6 +74,29 @@ ...@@ -74,6 +74,29 @@
{% endif %} {% endif %}
{%- endmacro %} {%- endmacro %}
{% macro simple_pager(page, have_next, generator) -%}
<ul class="pagination">
{% if page > 0 %}
<li>
<a href="{{ generator(page - 1) }}">&lt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(0) }}">&lt;</a>
</li>
{% endif %}
{% if have_next %}
<li>
<a href="{{ generator(page + 1) }}">&gt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(page) }}">&gt;</a>
</li>
{% endif %}
</ul>
{%- endmacro %}
{# ---------------------- Forms -------------------------- #} {# ---------------------- Forms -------------------------- #}
{% macro render_field(form, field, kwargs={}, caller=None) %} {% macro render_field(form, field, kwargs={}, caller=None) %}
{% set direct_error = h.is_field_error(field.errors) %} {% set direct_error = h.is_field_error(field.errors) %}
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
{% block model_menu_bar %} {% block model_menu_bar %}
<ul class="nav nav-tabs actions-nav"> <ul class="nav nav-tabs actions-nav">
<li class="active"> <li class="active">
<a href="javascript:void(0)">{{ _gettext('List') }} ({{ count }})</a> <a href="javascript:void(0)">{{ _gettext('List') }}{% if count %} ({{ count }}){% endif %}</a>
</li> </li>
{% if admin_view.can_create %} {% if admin_view.can_create %}
<li> <li>
...@@ -110,7 +110,11 @@ ...@@ -110,7 +110,11 @@
<form class="icon" method="POST" action="{{ get_url('.delete_view') }}"> <form class="icon" method="POST" action="{{ get_url('.delete_view') }}">
{{ delete_form.id(value=get_pk_value(row)) }} {{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.url(value=return_url) }} {{ delete_form.url(value=return_url) }}
{% if delete_form.csrf_token %}
{{ delete_form.csrf_token }} {{ delete_form.csrf_token }}
{% elif csrf_token %}
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>
{% endif %}
<button onclick="return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');" title="Delete record"> <button onclick="return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');" title="Delete record">
<span class="fa fa-trash glyphicon glyphicon-trash"></span> <span class="fa fa-trash glyphicon glyphicon-trash"></span>
</button> </button>
...@@ -146,7 +150,13 @@ ...@@ -146,7 +150,13 @@
</tr> </tr>
{% endfor %} {% endfor %}
</table> </table>
{% block list_pager %}
{% if num_pages is not none %}
{{ lib.pager(page, num_pages, pager_url) }} {{ lib.pager(page, num_pages, pager_url) }}
{% else %}
{{ lib.simple_pager(page, data|length == page_size, pager_url) }}
{% endif %}
{% endblock %}
{% endblock %} {% endblock %}
{{ actionlib.form(actions, get_url('.action_view')) }} {{ actionlib.form(actions, get_url('.action_view')) }}
......
...@@ -948,3 +948,20 @@ def test_form_args_embeddeddoc(): ...@@ -948,3 +948,20 @@ def test_form_args_embeddeddoc():
eq_(form.timestamp.label.text, 'Last Updated Time') eq_(form.timestamp.label.text, 'Last Updated Time')
# This is the failure # This is the failure
eq_(form.info.label.text, 'Information') eq_(form.info.label.text, 'Information')
def test_simple_list_pager():
app, db, admin = setup()
Model1, _ = create_models(db)
class TestModelView(CustomModelView):
simple_list_pager = True
def get_count_query(self):
assert False
view = TestModelView(Model1)
admin.add_view(view)
count, data = view.get_list(0, None, None, None, None)
ok_(count is None)
...@@ -12,6 +12,7 @@ from . import setup ...@@ -12,6 +12,7 @@ from . import setup
from datetime import datetime, time, date from datetime import datetime, time, date
class CustomModelView(ModelView): class CustomModelView(ModelView):
def __init__(self, model, session, def __init__(self, model, session,
name=None, category=None, endpoint=None, url=None, name=None, category=None, endpoint=None, url=None,
...@@ -259,10 +260,11 @@ def test_column_searchable_list(): ...@@ -259,10 +260,11 @@ def test_column_searchable_list():
eq_(view._search_supported, True) eq_(view._search_supported, True)
eq_(len(view._search_fields), 2) eq_(len(view._search_fields), 2)
ok_(isinstance(view._search_fields[0], db.Column))
ok_(isinstance(view._search_fields[1], db.Column)) ok_(isinstance(view._search_fields[0][0], db.Column))
eq_(view._search_fields[0].name, 'string_field') ok_(isinstance(view._search_fields[1][0], db.Column))
eq_(view._search_fields[1].name, 'int_field') eq_(view._search_fields[0][0].name, 'string_field')
eq_(view._search_fields[1][0].name, 'int_field')
db.session.add(Model2('model1-test', 5000)) db.session.add(Model2('model1-test', 5000))
db.session.add(Model2('model2-test', 9000)) db.session.add(Model2('model2-test', 9000))
...@@ -417,6 +419,8 @@ def test_column_filters(): ...@@ -417,6 +419,8 @@ def test_column_filters():
) )
admin.add_view(view) admin.add_view(view)
client = app.test_client()
eq_(len(view._filters), 7) eq_(len(view._filters), 7)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']], eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']],
...@@ -515,21 +519,23 @@ def test_column_filters(): ...@@ -515,21 +519,23 @@ def test_column_filters():
fill_db(db, Model1, Model2) fill_db(db, Model1, Model2)
client = app.test_client() # Test equals
rv = client.get('/admin/model1/?flt0_0=test1_val_1') rv = client.get('/admin/model1/?flt0_0=test1_val_1')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
# the filter value is always in "data" # the filter value is always in "data"
# need to check a different column than test1 for the expected row # need to check a different column than test1 for the expected row
ok_('test2_val_1' in data) ok_('test2_val_1' in data)
ok_('test1_val_2' not in data) ok_('test1_val_2' not in data)
# Test NOT IN filter
rv = client.get('/admin/model1/?flt0_6=test1_val_1') rv = client.get('/admin/model1/?flt0_6=test1_val_1')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('test2_val_1' not in data)
ok_('test1_val_2' in data) ok_('test1_val_2' in data)
ok_('test2_val_1' not in data)
# Test string filter # Test string filter
view = CustomModelView(Model1, db.session, view = CustomModelView(Model1, db.session,
...@@ -1103,9 +1109,11 @@ def test_column_filters(): ...@@ -1103,9 +1109,11 @@ def test_column_filters():
rv = client.get('/admin/_relation_test/?flt1_0=test1_val_1') rv = client.get('/admin/_relation_test/?flt1_0=test1_val_1')
data = rv.data.decode('utf-8') data = rv.data.decode('utf-8')
ok_('test1_val_1' in data) ok_('test1_val_1' in data)
ok_('test1_val_2' not in data) ok_('test1_val_2' not in data)
def test_url_args(): def test_url_args():
app, db, admin = setup() app, db, admin = setup()
...@@ -1680,3 +1688,123 @@ def test_safe_redirect(): ...@@ -1680,3 +1688,123 @@ def test_safe_redirect():
assert_true(rv.location.startswith('http://localhost/admin/model1/edit/')) assert_true(rv.location.startswith('http://localhost/admin/model1/edit/'))
assert_true('url=%2Fadmin%2Fmodel1%2F' in rv.location) assert_true('url=%2Fadmin%2Fmodel1%2F' in rv.location)
assert_true('id=2' in rv.location) assert_true('id=2' in rv.location)
def test_simple_list_pager():
app, db, admin = setup()
Model1, _ = create_models(db)
db.create_all()
class TestModelView(CustomModelView):
simple_list_pager = True
def get_count_query(self):
assert False
view = TestModelView(Model1, db.session)
admin.add_view(view)
count, data = view.get_list(0, None, None, None, None)
assert_true(count is None)
def test_advanced_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
model1 = db.relationship(Model1, backref='model2')
class Model3(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model2_id = db.Column(db.Integer, db.ForeignKey(Model2.id))
model2 = db.relationship(Model2, backref='model3')
view1 = CustomModelView(Model1, db.session)
admin.add_view(view1)
view2 = CustomModelView(Model2, db.session)
admin.add_view(view2)
view3 = CustomModelView(Model3, db.session)
admin.add_view(view3)
# Test joins
attr, path = view2._get_field_with_path('model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model2.model1])
attr, path = view1._get_field_with_path('model2.val2')
eq_(attr, Model2.val2)
eq_(id(path[0]), id(Model1.model2))
attr, path = view3._get_field_with_path('model2.model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model3.model2, Model2.model1])
# Test how joins are applied
query = view3.get_query()
joins = {}
q1, joins, alias = view3._apply_path_joins(query, joins, path)
ok_((True, Model3.model2) in joins)
ok_((True, Model2.model1) in joins)
ok_(alias is not None)
# Check if another join would use same path
attr, path = view2._get_field_with_path('model1.test')
q2, joins, alias = view2._apply_path_joins(query, joins, path)
eq_(len(joins), 2)
for p in q2._join_entities:
ok_(p in q1._join_entities)
ok_(alias is not None)
# Check if normal properties are supported by _get_field_with_path
attr, path = view2._get_field_with_path(Model1.test)
eq_(attr, Model1.test)
eq_(path, [Model1.__table__])
q3, joins, alias = view2._apply_path_joins(view2.get_query(), joins, path)
eq_(len(joins), 3)
ok_(alias is None)
def test_multipath_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
first_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
first = db.relationship(Model1, backref='first', foreign_keys=[first_id])
second_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
second = db.relationship(Model1, backref='second', foreign_keys=[second_id])
db.create_all()
view = CustomModelView(Model2, db.session, filters=['first.test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model2/')
eq_(rv.status_code, 200)
...@@ -76,7 +76,7 @@ def test_baseview_defaults(): ...@@ -76,7 +76,7 @@ def test_baseview_defaults():
view = MockView() view = MockView()
eq_(view.name, None) eq_(view.name, None)
eq_(view.category, None) eq_(view.category, None)
eq_(view.endpoint, None) eq_(view.endpoint, 'mockview')
eq_(view.url, None) eq_(view.url, None)
eq_(view.static_folder, None) eq_(view.static_folder, None)
eq_(view.admin, None) eq_(view.admin, None)
...@@ -388,3 +388,12 @@ def test_menu_links(): ...@@ -388,3 +388,12 @@ def test_menu_links():
def check_class_name(): def check_class_name():
view = MockView() view = MockView()
eq_(view.name, 'Mock View') eq_(view.name, 'Mock View')
def check_endpoint():
class CustomView(MockView):
def _get_endpoint(self, endpoint):
return 'admin.' + super(CustomView, self)._get_endpoint(endpoint)
view = CustomView()
eq_(view.endpoint, 'admin.customview')
...@@ -40,6 +40,9 @@ def test_upload_field(): ...@@ -40,6 +40,9 @@ def test_upload_field():
class TestForm(form.BaseForm): class TestForm(form.BaseForm):
upload = form.FileUploadField('Upload', base_path=path) upload = form.FileUploadField('Upload', base_path=path)
class TestNoOverWriteForm(form.BaseForm):
upload = form.FileUploadField('Upload', base_path=path, allow_overwrite=False)
class Dummy(object): class Dummy(object):
pass pass
...@@ -74,6 +77,7 @@ def test_upload_field(): ...@@ -74,6 +77,7 @@ def test_upload_field():
# Check delete # Check delete
with app.test_request_context(method='POST', data={'_upload-delete': 'checked'}): with app.test_request_context(method='POST', data={'_upload-delete': 'checked'}):
my_form = TestForm(helpers.get_form_data()) my_form = TestForm(helpers.get_form_data())
ok_(my_form.validate()) ok_(my_form.validate())
...@@ -83,6 +87,24 @@ def test_upload_field(): ...@@ -83,6 +87,24 @@ def test_upload_field():
ok_(not op.exists(op.join(path, 'test2.txt'))) ok_(not op.exists(op.join(path, 'test2.txt')))
# Check overwrite
_remove_testfiles()
my_form_ow = TestNoOverWriteForm()
with app.test_request_context(method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
my_form_ow = TestNoOverWriteForm(helpers.get_form_data())
ok_(my_form_ow.validate())
my_form_ow.populate_obj(dummy)
eq_(dummy.upload, 'test1.txt')
ok_(op.exists(op.join(path, 'test1.txt')))
with app.test_request_context(method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
my_form_ow = TestNoOverWriteForm(helpers.get_form_data())
ok_(not my_form_ow.validate())
_remove_testfiles()
def test_image_upload_field(): def test_image_upload_field():
app = Flask(__name__) app = Flask(__name__)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment