Commit 207d23fd authored by Petrus J.v.Rensburg's avatar Petrus J.v.Rensburg

Merge branch 'master' into examples

parents 2d3d1d63 6b6fe519
Changelog
=========
1.2.0
-----
* Codebase was migrated to Flask-Admin GitHub organization
* Automatically inject Flask-WTF CSRF token to internal Flask-Admin forms
* MapBox v4 support for GeoAlchemy
* Updated translations with help of CrowdIn
* Show warning if field was ignored in form rendering rules
* Simple AppEngine backend
* Optional support for Font Awesome in templates and menus
* Bug fixes
1.1.0
-----
......@@ -43,21 +55,3 @@ Highlights:
* Support for newer wtforms versions
* `form_rules` property that affects both create and edit forms
* Lots of bugfixes
1.0.7
-----
Full change log and feature walkthrough can be found `here <http://mrjoes.github.io/2013/10/21/flask-admin-107.html>`_.
Highlights:
* Python 3 support
* AJAX-based foreign-key data loading for all backends
* New, optional, rule-based form rendering engine
* MongoEngine fixes and features: GridFS support, nested subdocument configuration and much more
* Greatly improved and more configurable inline models
* New WTForms fields and widgets
* `form_extra_columns` allows adding custom columns to the form declaratively
* Redis cli
* SQLAlchemy backend can handle inherited models with multiple PKs
* Lots of bug fixes
......@@ -101,7 +101,8 @@ class FileView(sqla.ModelView):
form_args = {
'path': {
'label': 'File',
'base_path': file_path
'base_path': file_path,
'allow_overwrite': False
}
}
......
......@@ -107,7 +107,7 @@ class PostAdmin(sqla.ModelView):
# List of columns that can be sorted. For 'user' column, use User.username as
# a column.
column_sortable_list = ('title', ('user', User.username), 'date')
column_sortable_list = ('title', ('user', 'user.username'), 'date')
# Rename 'title' columns to 'Post Title' in list view
column_labels = dict(title='Post Title')
......
__version__ = '1.1.1-dev'
__version__ = '1.2.0'
__author__ = 'Serge S. Koval'
__email__ = 'serge.koval+github@gmail.com'
......
......@@ -188,7 +188,7 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
"""
self.name = name
self.category = category
self.endpoint = endpoint
self.endpoint = self._get_endpoint(endpoint)
self.url = url
self.static_folder = static_folder
self.static_url_path = static_url_path
......@@ -206,6 +206,16 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
if self._default_view is None:
raise Exception(u'Attempted to instantiate admin view %s without default view' % self.__class__.__name__)
def _get_endpoint(self, endpoint):
"""
Generate Flask endpoint name. By default converts class name to lower case if endpoint is
not explicitly provided.
"""
if endpoint:
return endpoint
return self.__class__.__name__.lower()
def create_blueprint(self, admin):
"""
Create Flask blueprint.
......@@ -213,10 +223,6 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
# Store admin instance
self.admin = admin
# If endpoint name is not provided, get it from the class name
if self.endpoint is None:
self.endpoint = self.__class__.__name__.lower()
# If the static_url_path is not provided, use the admin's
if not self.static_url_path:
self.static_url_path = admin.static_url_path
......@@ -234,15 +240,13 @@ class BaseView(with_metaclass(AdminViewMeta, BaseViewClass)):
if not self.url.startswith('/'):
self.url = '%s/%s' % (self.admin.url, self.url)
# If we're working from the root of the site, set prefix to None
if self.url == '/':
self.url = None
# prevent admin static files from conflicting with flask static files
if not self.static_url_path:
self.static_folder='static'
self.static_url_path='/static/admin'
self.static_folder = 'static'
self.static_url_path = '/static/admin'
# If name is not povided, use capitalized endpoint name
if self.name is None:
......
......@@ -484,7 +484,7 @@ class ModelView(BaseModelView):
query = self._search(query, search)
# Get count
count = query.count()
count = query.count() if not self.simple_list_pager else None
# Sorting
if sort_column:
......@@ -592,7 +592,7 @@ class ModelView(BaseModelView):
return False
else:
self.after_model_delete(model)
return True
......
......@@ -339,7 +339,7 @@ class ModelView(BaseModelView):
query = f.apply(query, f.clean(value))
# Get count
count = query.count()
count = query.count() if not self.simple_list_pager else None
# Apply sorting
if sort_column is not None:
......@@ -417,7 +417,7 @@ class ModelView(BaseModelView):
return False
else:
self.after_model_delete(model)
return True
# Default model actions
......@@ -443,6 +443,7 @@ class ModelView(BaseModelView):
query = self.model.select().filter(model_pk << ids)
for m in query:
self.on_model_delete(m)
m.delete_instance(recursive=True)
count += 1
......
......@@ -222,7 +222,7 @@ class ModelView(BaseModelView):
query = self._search(query, search)
# Get count
count = self.coll.find(query).count()
count = self.coll.find(query).count() if not self.simple_list_pager else None
# Sorting
sort_by = None
......@@ -337,7 +337,7 @@ class ModelView(BaseModelView):
return False
else:
self.after_model_delete(model)
return True
# Default model actions
......
import warnings
import time
import datetime
from flask_admin.babel import lazy_gettext
from flask_admin.model import filters
from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_
class BaseSQLAFilter(filters.BaseFilter):
"""
Base SQLAlchemy filter.
......@@ -28,64 +25,70 @@ class BaseSQLAFilter(filters.BaseFilter):
self.column = column
def get_column(self, alias):
return self.column if alias is None else getattr(alias, self.column.key)
def apply(self, query, value, alias=None):
return super(self, BaseSQLAFilter).apply(query, value)
# Common filters
class FilterEqual(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column == value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) == value)
def operation(self):
return lazy_gettext('equals')
class FilterNotEqual(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column != value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) != value)
def operation(self):
return lazy_gettext('not equal')
class FilterLike(BaseSQLAFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value)
return query.filter(self.column.ilike(stmt))
return query.filter(self.get_column(alias).ilike(stmt))
def operation(self):
return lazy_gettext('contains')
class FilterNotLike(BaseSQLAFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
stmt = tools.parse_like_term(value)
return query.filter(~self.column.ilike(stmt))
return query.filter(~self.get_column(alias).ilike(stmt))
def operation(self):
return lazy_gettext('not contains')
class FilterGreater(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column > value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) > value)
def operation(self):
return lazy_gettext('greater than')
class FilterSmaller(BaseSQLAFilter):
def apply(self, query, value):
return query.filter(self.column < value)
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias) < value)
def operation(self):
return lazy_gettext('smaller than')
class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
if value == '1':
return query.filter(self.column == None)
return query.filter(self.get_column(alias) == None)
else:
return query.filter(self.column != None)
return query.filter(self.get_column(alias) != None)
def operation(self):
return lazy_gettext('empty')
......@@ -98,17 +101,18 @@ class FilterInList(BaseSQLAFilter):
def clean(self, value):
return [v.strip() for v in value.split(',') if v.strip()]
def apply(self, query, value):
return query.filter(self.column.in_(value))
def apply(self, query, value, alias=None):
return query.filter(self.get_column(alias).in_(value))
def operation(self):
return lazy_gettext('in list')
class FilterNotInList(FilterInList):
def apply(self, query, value):
def apply(self, query, value, alias=None):
# NOT IN can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(~self.column.in_(value), self.column == None))
column = self.get_column(alias)
return query.filter(or_(~column.in_(value), column == None))
def operation(self):
return lazy_gettext('not in list')
......@@ -194,16 +198,16 @@ class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
options,
data_type='daterangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class DateNotBetweenFilter(DateBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
# ~between() isn't possible until sqlalchemy 1.0.0
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......@@ -232,15 +236,15 @@ class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
options,
data_type='datetimerangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......@@ -269,15 +273,15 @@ class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
options,
data_type='timerangepicker')
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(self.column.between(start, end))
return query.filter(self.get_column(alias).between(start, end))
class TimeNotBetweenFilter(TimeBetweenFilter):
def apply(self, query, value):
def apply(self, query, value, alias=None):
start, end = value
return query.filter(not_(self.column.between(start, end)))
return query.filter(not_(self.get_column(alias).between(start, end)))
def operation(self):
return lazy_gettext('not between')
......
This diff is collapsed.
......@@ -51,12 +51,21 @@ class FileUploadInput(object):
template = self.data_template if field.data else self.empty_template
if field.errors:
template = self.empty_template
if field.data and isinstance(field.data, FileStorage):
value = field.data.filename
else:
value = field.data or ''
return HTMLString(template % {
'text': html_params(type='text',
readonly='readonly',
value=field.data,
value=value,
name=field.name),
'file': html_params(type='file',
value=value,
**kwargs),
'marker': '_%s-delete' % field.name
})
......@@ -122,7 +131,7 @@ class FileUploadField(fields.StringField):
def __init__(self, label=None, validators=None,
base_path=None, relative_path=None,
namegen=None, allowed_extensions=None,
permission=0o666,
permission=0o666, allow_overwrite=True,
**kwargs):
"""
Constructor.
......@@ -154,6 +163,11 @@ class FileUploadField(fields.StringField):
:param allowed_extensions:
List of allowed extensions. If not provided, will allow any file.
:param allow_overwrite:
Whether to overwrite existing files in upload directory. Defaults to `True`.
.. versionadded:: 1.1.1
The `allow_overwrite` parameter was added.
"""
self.base_path = base_path
self.relative_path = relative_path
......@@ -161,6 +175,7 @@ class FileUploadField(fields.StringField):
self.namegen = namegen or namegen_filename
self.allowed_extensions = allowed_extensions
self.permission = permission
self._allow_overwrite = allow_overwrite
self._should_delete = False
......@@ -188,6 +203,11 @@ class FileUploadField(fields.StringField):
def pre_validate(self, form):
if self._is_uploaded_file(self.data) and not self.is_file_allowed(self.data.filename):
raise ValidationError(gettext('Invalid file extension'))
# Handle overwriting existing content
if not self._is_uploaded_file(self.data):
return
if self._allow_overwrite == False and os.path.exists(self._get_path(self.data.filename)):
raise ValidationError(gettext('File "%s" already exists.' % self.data.filename))
def process(self, formdata, data=unset_value):
if formdata:
......@@ -253,6 +273,9 @@ class FileUploadField(fields.StringField):
if not op.exists(op.dirname(path)):
os.makedirs(os.path.dirname(path), self.permission | 0o111)
if self._allow_overwrite == False and os.path.exists(path):
raise ValueError(gettext('File "%s" already exists.' % path))
data.save(path)
return filename
......
......@@ -15,7 +15,7 @@ from flask_admin.form import BaseForm, FormOpts, rules
from flask_admin.model import filters, typefmt
from flask_admin.actions import ActionsMixin
from flask_admin.helpers import (get_form_data, validate_form_on_submit,
get_redirect_target, flash_errors)
get_redirect_target, flash_errors)
from flask_admin.tools import rec_getattr
from flask_admin._backwards import ObsoleteAttr
from flask_admin._compat import iteritems, OrderedDict, as_unicode
......@@ -321,6 +321,12 @@ class BaseModelView(BaseView, ActionsMixin):
Controls if the primary key should be displayed in the list view.
"""
simple_list_pager = False
"""
Enable or disable simple list pager.
If enabled, model interface would not run count query and will only show prev/next pager buttons.
"""
form = None
"""
Form class. Override if you want to use custom form for your model.
......@@ -563,28 +569,30 @@ class BaseModelView(BaseView, ActionsMixin):
:param menu_icon_value:
Icon glyph name or URL, depending on `menu_icon_type` setting
"""
self.model = model
# If name not provided, it is model name
if name is None:
name = '%s' % self._prettify_class_name(model.__name__)
# If endpoint not provided, it is model name
if endpoint is None:
endpoint = model.__name__.lower()
super(BaseModelView, self).__init__(name, category, endpoint, url, static_folder,
menu_class_name=menu_class_name,
menu_icon_type=menu_icon_type,
menu_icon_value=menu_icon_value)
self.model = model
# Actions
self.init_actions()
# Scaffolding
self._refresh_cache()
# Endpoint
def _get_endpoint(self, endpoint):
if endpoint:
return super(BaseModelView, self)._get_endpoint(endpoint)
return self.model.__name__.lower()
# Caching
def _refresh_forms_cache(self):
# Forms
......@@ -617,7 +625,7 @@ class BaseModelView(BaseView, ActionsMixin):
self._filter_groups[flt.name].append({
'index': i,
'arg': self.get_filter_arg(i, flt),
'operation': as_unicode(flt.operation()),
'operation': flt.operation(),
'options': flt.get_options(self) or None,
'type': flt.data_type
})
......@@ -852,6 +860,27 @@ class BaseModelView(BaseView, ActionsMixin):
else:
return str(index)
def _get_filter_groups(self):
"""
Returns non-lazy version of filter strings
"""
if self._filter_groups:
results = OrderedDict()
for key, value in iteritems(self._filter_groups):
items = []
for item in value:
copy = dict(item)
copy['operation'] = as_unicode(copy['operation'])
items.append(copy)
results[key] = items
return results
return None
# Form helpers
def scaffold_form(self):
"""
......@@ -1018,7 +1047,7 @@ class BaseModelView(BaseView, ActionsMixin):
missing_fields.append(field.name)
return missing_fields
def _show_missing_fields_warning(self, text):
warnings.warn(text)
......@@ -1200,7 +1229,7 @@ class BaseModelView(BaseView, ActionsMixin):
By default do nothing.
"""
pass
def after_model_delete(self, model):
"""
Perform some actions after a model was deleted and
......@@ -1214,7 +1243,7 @@ class BaseModelView(BaseView, ActionsMixin):
:param model:
Model that was deleted
"""
pass
pass
def on_form_prefill (self, form, id):
"""
......@@ -1463,9 +1492,12 @@ class BaseModelView(BaseView, ActionsMixin):
view_args.search, view_args.filters)
# Calculate number of pages
num_pages = count // self.page_size
if count % self.page_size != 0:
num_pages += 1
if count is not None:
num_pages = count // self.page_size
if count % self.page_size != 0:
num_pages += 1
else:
num_pages = None
# Various URL generation helpers
def pager_url(p):
......@@ -1508,6 +1540,7 @@ class BaseModelView(BaseView, ActionsMixin):
pager_url=pager_url,
num_pages=num_pages,
page=view_args.page,
page_size=self.page_size,
# Sorting
sort_column=view_args.sort,
......@@ -1521,7 +1554,7 @@ class BaseModelView(BaseView, ActionsMixin):
# Filters
filters=self._filters,
filter_groups=self._filter_groups,
filter_groups=self._get_filter_groups(),
active_filters=view_args.filters,
# Actions
......
......@@ -7,9 +7,9 @@
{% elif icon_type == 'fa' %}
<i class="fa fa-{{ icon_value }}"></i>
{% elif icon_type == 'image' %}
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image"></img>
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image">
{% elif icon_type == 'image-url' %}
<img src="item.icon_value" alt="menu image"></img>
<img src="item.icon_value" alt="menu image">
{% endif %}
{% endif %}
{%- endmacro %}
......
......@@ -76,6 +76,31 @@
{% endif %}
{%- endmacro %}
{% macro simple_pager(page, have_next, generator) -%}
<div class="pagination">
<ul>
{% if page > 0 %}
<li>
<a href="{{ generator(page - 1) }}">&lt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(0) }}">&lt;</a>
</li>
{% endif %}
{% if have_next %}
<li>
<a href="{{ generator(page + 1) }}">&gt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(page) }}">&gt;</a>
</li>
{% endif %}
</ul>
</div>
{%- endmacro %}
{# ---------------------- Forms -------------------------- #}
{% macro render_field(form, field, kwargs={}, caller=None) %}
{% set direct_error = h.is_field_error(field.errors) %}
......
......@@ -13,7 +13,7 @@
{% block model_menu_bar %}
<ul class="nav nav-tabs actions-nav">
<li class="active">
<a href="javascript:void(0)">{{ _gettext('List') }} ({{ count }})</a>
<a href="javascript:void(0)">{{ _gettext('List') }}{% if count %} ({{ count }}){% endif %}</a>
</li>
{% if admin_view.can_create %}
<li>
......@@ -110,7 +110,11 @@
<form class="icon" method="POST" action="{{ get_url('.delete_view') }}">
{{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.url(value=return_url) }}
{% if delete_form.csrf_token %}
{{ delete_form.csrf_token }}
{% elif csrf_token %}
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>
{% endif %}
<button onclick="return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');" title="{{ _gettext('Delete record') }}">
<i class="fa fa-trash icon-trash"></i>
</button>
......@@ -147,7 +151,13 @@
</tr>
{% endfor %}
</table>
{% block list_pager %}
{% if num_pages is not none %}
{{ lib.pager(page, num_pages, pager_url) }}
{% else %}
{{ lib.simple_pager(page, data|length == page_size, pager_url) }}
{% endif %}
{% endblock %}
{% endblock %}
{{ actionlib.form(actions, get_url('.action_view')) }}
......
......@@ -7,9 +7,9 @@
{% elif icon_type == 'fa' %}
<i class="fa {{ icon_value }}"></i>
{% elif icon_type == 'image' %}
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image"></img>
<img src="{{ url_for('static', filename=icon_value) }}" alt="menu image">
{% elif icon_type == 'image-url' %}
<img src="item.icon_value" alt="menu image"></img>
<img src="item.icon_value" alt="menu image">
{% endif %}
{% endif %}
{%- endmacro %}
......
......@@ -74,6 +74,29 @@
{% endif %}
{%- endmacro %}
{% macro simple_pager(page, have_next, generator) -%}
<ul class="pagination">
{% if page > 0 %}
<li>
<a href="{{ generator(page - 1) }}">&lt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(0) }}">&lt;</a>
</li>
{% endif %}
{% if have_next %}
<li>
<a href="{{ generator(page + 1) }}">&gt;</a>
</li>
{% else %}
<li class="disabled">
<a href="{{ generator(page) }}">&gt;</a>
</li>
{% endif %}
</ul>
{%- endmacro %}
{# ---------------------- Forms -------------------------- #}
{% macro render_field(form, field, kwargs={}, caller=None) %}
{% set direct_error = h.is_field_error(field.errors) %}
......
......@@ -13,7 +13,7 @@
{% block model_menu_bar %}
<ul class="nav nav-tabs actions-nav">
<li class="active">
<a href="javascript:void(0)">{{ _gettext('List') }} ({{ count }})</a>
<a href="javascript:void(0)">{{ _gettext('List') }}{% if count %} ({{ count }}){% endif %}</a>
</li>
{% if admin_view.can_create %}
<li>
......@@ -110,7 +110,11 @@
<form class="icon" method="POST" action="{{ get_url('.delete_view') }}">
{{ delete_form.id(value=get_pk_value(row)) }}
{{ delete_form.url(value=return_url) }}
{% if delete_form.csrf_token %}
{{ delete_form.csrf_token }}
{% elif csrf_token %}
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>
{% endif %}
<button onclick="return confirm('{{ _gettext('Are you sure you want to delete this record?') }}');" title="Delete record">
<span class="fa fa-trash glyphicon glyphicon-trash"></span>
</button>
......@@ -146,7 +150,13 @@
</tr>
{% endfor %}
</table>
{% block list_pager %}
{% if num_pages is not none %}
{{ lib.pager(page, num_pages, pager_url) }}
{% else %}
{{ lib.simple_pager(page, data|length == page_size, pager_url) }}
{% endif %}
{% endblock %}
{% endblock %}
{{ actionlib.form(actions, get_url('.action_view')) }}
......
......@@ -948,3 +948,20 @@ def test_form_args_embeddeddoc():
eq_(form.timestamp.label.text, 'Last Updated Time')
# This is the failure
eq_(form.info.label.text, 'Information')
def test_simple_list_pager():
app, db, admin = setup()
Model1, _ = create_models(db)
class TestModelView(CustomModelView):
simple_list_pager = True
def get_count_query(self):
assert False
view = TestModelView(Model1)
admin.add_view(view)
count, data = view.get_list(0, None, None, None, None)
ok_(count is None)
......@@ -12,6 +12,7 @@ from . import setup
from datetime import datetime, time, date
class CustomModelView(ModelView):
def __init__(self, model, session,
name=None, category=None, endpoint=None, url=None,
......@@ -259,10 +260,11 @@ def test_column_searchable_list():
eq_(view._search_supported, True)
eq_(len(view._search_fields), 2)
ok_(isinstance(view._search_fields[0], db.Column))
ok_(isinstance(view._search_fields[1], db.Column))
eq_(view._search_fields[0].name, 'string_field')
eq_(view._search_fields[1].name, 'int_field')
ok_(isinstance(view._search_fields[0][0], db.Column))
ok_(isinstance(view._search_fields[1][0], db.Column))
eq_(view._search_fields[0][0].name, 'string_field')
eq_(view._search_fields[1][0].name, 'int_field')
db.session.add(Model2('model1-test', 5000))
db.session.add(Model2('model2-test', 9000))
......@@ -417,6 +419,8 @@ def test_column_filters():
)
admin.add_view(view)
client = app.test_client()
eq_(len(view._filters), 7)
eq_([(f['index'], f['operation']) for f in view._filter_groups[u'Test1']],
......@@ -515,21 +519,23 @@ def test_column_filters():
fill_db(db, Model1, Model2)
client = app.test_client()
# Test equals
rv = client.get('/admin/model1/?flt0_0=test1_val_1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
# the filter value is always in "data"
# need to check a different column than test1 for the expected row
ok_('test2_val_1' in data)
ok_('test1_val_2' not in data)
# Test NOT IN filter
rv = client.get('/admin/model1/?flt0_6=test1_val_1')
eq_(rv.status_code, 200)
data = rv.data.decode('utf-8')
ok_('test2_val_1' not in data)
ok_('test1_val_2' in data)
ok_('test2_val_1' not in data)
# Test string filter
view = CustomModelView(Model1, db.session,
......@@ -1103,9 +1109,11 @@ def test_column_filters():
rv = client.get('/admin/_relation_test/?flt1_0=test1_val_1')
data = rv.data.decode('utf-8')
ok_('test1_val_1' in data)
ok_('test1_val_2' not in data)
def test_url_args():
app, db, admin = setup()
......@@ -1680,3 +1688,123 @@ def test_safe_redirect():
assert_true(rv.location.startswith('http://localhost/admin/model1/edit/'))
assert_true('url=%2Fadmin%2Fmodel1%2F' in rv.location)
assert_true('id=2' in rv.location)
def test_simple_list_pager():
app, db, admin = setup()
Model1, _ = create_models(db)
db.create_all()
class TestModelView(CustomModelView):
simple_list_pager = True
def get_count_query(self):
assert False
view = TestModelView(Model1, db.session)
admin.add_view(view)
count, data = view.get_list(0, None, None, None, None)
assert_true(count is None)
def test_advanced_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
model1 = db.relationship(Model1, backref='model2')
class Model3(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
model2_id = db.Column(db.Integer, db.ForeignKey(Model2.id))
model2 = db.relationship(Model2, backref='model3')
view1 = CustomModelView(Model1, db.session)
admin.add_view(view1)
view2 = CustomModelView(Model2, db.session)
admin.add_view(view2)
view3 = CustomModelView(Model3, db.session)
admin.add_view(view3)
# Test joins
attr, path = view2._get_field_with_path('model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model2.model1])
attr, path = view1._get_field_with_path('model2.val2')
eq_(attr, Model2.val2)
eq_(id(path[0]), id(Model1.model2))
attr, path = view3._get_field_with_path('model2.model1.val1')
eq_(attr, Model1.val1)
eq_(path, [Model3.model2, Model2.model1])
# Test how joins are applied
query = view3.get_query()
joins = {}
q1, joins, alias = view3._apply_path_joins(query, joins, path)
ok_((True, Model3.model2) in joins)
ok_((True, Model2.model1) in joins)
ok_(alias is not None)
# Check if another join would use same path
attr, path = view2._get_field_with_path('model1.test')
q2, joins, alias = view2._apply_path_joins(query, joins, path)
eq_(len(joins), 2)
for p in q2._join_entities:
ok_(p in q1._join_entities)
ok_(alias is not None)
# Check if normal properties are supported by _get_field_with_path
attr, path = view2._get_field_with_path(Model1.test)
eq_(attr, Model1.test)
eq_(path, [Model1.__table__])
q3, joins, alias = view2._apply_path_joins(view2.get_query(), joins, path)
eq_(len(joins), 3)
ok_(alias is None)
def test_multipath_joins():
app, db, admin = setup()
class Model1(db.Model):
id = db.Column(db.Integer, primary_key=True)
val1 = db.Column(db.String(20))
test = db.Column(db.String(20))
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
val2 = db.Column(db.String(20))
first_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
first = db.relationship(Model1, backref='first', foreign_keys=[first_id])
second_id = db.Column(db.Integer, db.ForeignKey(Model1.id))
second = db.relationship(Model1, backref='second', foreign_keys=[second_id])
db.create_all()
view = CustomModelView(Model2, db.session, filters=['first.test'])
admin.add_view(view)
client = app.test_client()
rv = client.get('/admin/model2/')
eq_(rv.status_code, 200)
......@@ -76,7 +76,7 @@ def test_baseview_defaults():
view = MockView()
eq_(view.name, None)
eq_(view.category, None)
eq_(view.endpoint, None)
eq_(view.endpoint, 'mockview')
eq_(view.url, None)
eq_(view.static_folder, None)
eq_(view.admin, None)
......@@ -388,3 +388,12 @@ def test_menu_links():
def check_class_name():
view = MockView()
eq_(view.name, 'Mock View')
def check_endpoint():
class CustomView(MockView):
def _get_endpoint(self, endpoint):
return 'admin.' + super(CustomView, self)._get_endpoint(endpoint)
view = CustomView()
eq_(view.endpoint, 'admin.customview')
......@@ -40,6 +40,9 @@ def test_upload_field():
class TestForm(form.BaseForm):
upload = form.FileUploadField('Upload', base_path=path)
class TestNoOverWriteForm(form.BaseForm):
upload = form.FileUploadField('Upload', base_path=path, allow_overwrite=False)
class Dummy(object):
pass
......@@ -74,6 +77,7 @@ def test_upload_field():
# Check delete
with app.test_request_context(method='POST', data={'_upload-delete': 'checked'}):
my_form = TestForm(helpers.get_form_data())
ok_(my_form.validate())
......@@ -83,6 +87,24 @@ def test_upload_field():
ok_(not op.exists(op.join(path, 'test2.txt')))
# Check overwrite
_remove_testfiles()
my_form_ow = TestNoOverWriteForm()
with app.test_request_context(method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
my_form_ow = TestNoOverWriteForm(helpers.get_form_data())
ok_(my_form_ow.validate())
my_form_ow.populate_obj(dummy)
eq_(dummy.upload, 'test1.txt')
ok_(op.exists(op.join(path, 'test1.txt')))
with app.test_request_context(method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
my_form_ow = TestNoOverWriteForm(helpers.get_form_data())
ok_(not my_form_ow.validate())
_remove_testfiles()
def test_image_upload_field():
app = Flask(__name__)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment