Unverified Commit 3e558465 authored by Petrus Janse van Rensburg's avatar Petrus Janse van Rensburg Committed by GitHub

Merge pull request #1742 from flask-admin/sqlalchemy-utils-types

Sqlalchemy utils types
parents d29796b6 ce6bdc16
...@@ -5,6 +5,7 @@ Next release ...@@ -5,6 +5,7 @@ Next release
----- -----
* Fix display of inline x-editable boolean fields on list view * Fix display of inline x-editable boolean fields on list view
* Add support for several SQLAlchemy-Utils data types
1.5.3 1.5.3
----- -----
......
import os import os
import os.path as op import os.path as op
from flask import Flask from flask import Flask, Markup
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import composite
import uuid
import random
import string
from wtforms import validators from wtforms import validators
...@@ -13,6 +17,13 @@ from flask_admin.contrib.sqla import filters ...@@ -13,6 +17,13 @@ from flask_admin.contrib.sqla import filters
from flask_admin.contrib.sqla.form import InlineModelConverter from flask_admin.contrib.sqla.form import InlineModelConverter
from flask_admin.contrib.sqla.fields import InlineModelFormList from flask_admin.contrib.sqla.fields import InlineModelFormList
from flask_admin.contrib.sqla.filters import BaseSQLAFilter, FilterEqual from flask_admin.contrib.sqla.filters import BaseSQLAFilter, FilterEqual
from flask_admin.babel import gettext
from sqlalchemy_utils import ChoiceType, EmailType, UUIDType, URLType, CurrencyType, Currency
from colour import Color
from sqlalchemy_utils import ColorType, ArrowType, IPAddressType, TimezoneType
import arrow
import enum
# Create application # Create application
...@@ -31,30 +42,57 @@ app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + app.config['DATABASE_FILE ...@@ -31,30 +42,57 @@ app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + app.config['DATABASE_FILE
app.config['SQLALCHEMY_ECHO'] = True app.config['SQLALCHEMY_ECHO'] = True
db = SQLAlchemy(app) db = SQLAlchemy(app)
class EnumChoices(enum.Enum):
first = 1
second = 2
AVAILABLE_USER_TYPES = [
(u'admin', u'Admin'),
(u'content-writer', u'Content writer'),
(u'editor', u'Editor'),
(u'regular-user', u'Regular user'),
]
# Create models # Create models
class User(db.Model): class User(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(UUIDType(binary=False), default=uuid.uuid4, primary_key=True)
# use a regular string field, for which we can specify a list of available choices later on
type = db.Column(db.String(100))
# fixed choices can be handled in a number of different ways:
enum_choice_field = db.Column(db.Enum(EnumChoices), nullable=True)
sqla_utils_choice_field = db.Column(ChoiceType(AVAILABLE_USER_TYPES), nullable=True)
sqla_utils_enum_choice_field = db.Column(ChoiceType(EnumChoices, impl=db.Integer()), nullable=True)
first_name = db.Column(db.String(100)) first_name = db.Column(db.String(100))
last_name = db.Column(db.String(100)) last_name = db.Column(db.String(100))
email = db.Column(db.String(120), unique=True)
pets = db.relationship('Pet', backref='owner')
def __str__(self): # some sqlalchemy_utils data types (see https://sqlalchemy-utils.readthedocs.io/)
return "{}, {}".format(self.last_name, self.first_name) email = db.Column(EmailType, unique=True, nullable=False)
website = db.Column(URLType)
ip_address = db.Column(IPAddressType)
currency = db.Column(CurrencyType, nullable=True, default=None)
timezone = db.Column(TimezoneType(backend='pytz'))
def __repr__(self): dialling_code = db.Column(db.Integer())
return "{}: {}".format(self.id, self.__str__()) local_phone_number = db.Column(db.String(10))
featured_post_id = db.Column(db.Integer, db.ForeignKey('post.id'))
featured_post = db.relationship('Post', foreign_keys=[featured_post_id])
class Pet(db.Model): @hybrid_property
id = db.Column(db.Integer, primary_key=True) def phone_number(self):
name = db.Column(db.String(50), nullable=False) if self.dialling_code and self.local_phone_number:
person_id = db.Column(db.Integer, db.ForeignKey('user.id')) number = str(self.local_phone_number)
available = db.Column(db.Boolean) return "+{} ({}){} {} {}".format(self.dialling_code, number[0], number[1:3], number[3:6], number[6::])
return
def __str__(self): def __str__(self):
return self.name return "{}, {}".format(self.last_name, self.first_name)
def __repr__(self):
return "{}: {}".format(self.id, self.__str__())
# Create M2M table # Create M2M table
...@@ -70,9 +108,12 @@ class Post(db.Model): ...@@ -70,9 +108,12 @@ class Post(db.Model):
text = db.Column(db.Text, nullable=False) text = db.Column(db.Text, nullable=False)
date = db.Column(db.Date) date = db.Column(db.Date)
user_id = db.Column(db.Integer(), db.ForeignKey(User.id)) # some sqlalchemy_utils data types (see https://sqlalchemy-utils.readthedocs.io/)
user = db.relationship(User, backref='posts') background_color = db.Column(ColorType)
created_at = db.Column(ArrowType, default=arrow.utcnow())
user_id = db.Column(UUIDType(binary=False), db.ForeignKey(User.id))
user = db.relationship(User, foreign_keys=[user_id], backref='posts')
tags = db.relationship('Tag', secondary=post_tags_table) tags = db.relationship('Tag', secondary=post_tags_table)
def __str__(self): def __str__(self):
...@@ -81,25 +122,12 @@ class Post(db.Model): ...@@ -81,25 +122,12 @@ class Post(db.Model):
class Tag(db.Model): class Tag(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.Unicode(64)) name = db.Column(db.Unicode(64), unique=True)
def __str__(self): def __str__(self):
return "{}".format(self.name) return "{}".format(self.name)
class UserInfo(db.Model):
id = db.Column(db.Integer, primary_key=True)
key = db.Column(db.String(64), nullable=False)
value = db.Column(db.String(64))
user_id = db.Column(db.Integer(), db.ForeignKey(User.id))
user = db.relationship(User, backref='info')
def __str__(self):
return "{} - {}".format(self.key, self.value)
class Tree(db.Model): class Tree(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(64)) name = db.Column(db.String(64))
...@@ -110,17 +138,6 @@ class Tree(db.Model): ...@@ -110,17 +138,6 @@ class Tree(db.Model):
return "{}".format(self.name) return "{}".format(self.name)
class Screen(db.Model):
__tablename__ = 'screen'
id = db.Column(db.Integer, primary_key=True)
width = db.Column(db.Integer, nullable=False)
height = db.Column(db.Integer, nullable=False)
@hybrid_property
def number_of_pixels(self):
return self.width * self.height
# Flask views # Flask views
@app.route('/') @app.route('/')
def index(): def index():
...@@ -140,58 +157,109 @@ class FilterLastNameBrown(BaseSQLAFilter): ...@@ -140,58 +157,109 @@ class FilterLastNameBrown(BaseSQLAFilter):
# Customized User model admin # Customized User model admin
inline_form_options = { def phone_number_formatter (view, context, model, name):
'form_label': "Info item", return Markup("<nobr>{}</nobr>".format(model.phone_number)) if model.phone_number else None
'form_columns': ['id', 'key', 'value'],
'form_args': None, def is_numberic_validator(form, field):
'form_extra_fields': None, if field.data and not field.data.isdigit():
} raise validators.ValidationError(gettext('Only numbers are allowed.'))
class UserAdmin(sqla.ModelView): class UserAdmin(sqla.ModelView):
can_view_details = True # show a modal dialog with records details
action_disallowed_list = ['delete', ] action_disallowed_list = ['delete', ]
column_display_pk = True
form_choices = {
'type': AVAILABLE_USER_TYPES,
}
form_args = {
'dialling_code': {'label': 'Dialling code'},
'local_phone_number': {
'label': 'Phone number',
'validators': [is_numberic_validator]
},
}
form_widget_args = {
'id':{
'readonly':True
}
}
column_list = [ column_list = [
'type',
'last_name',
'first_name',
'email',
'ip_address',
'currency',
'timezone',
'phone_number',
]
column_searchable_list = [
'first_name',
'last_name',
'email',
]
column_editable_list = ['type', 'currency', 'timezone']
column_details_list = [
'id', 'id',
'featured_post',
'website',
'enum_choice_field',
'sqla_utils_choice_field',
'sqla_utils_enum_choice_field',
] + column_list
form_columns = [
'id',
'type',
'featured_post',
'enum_choice_field',
'sqla_utils_choice_field',
'sqla_utils_enum_choice_field',
'last_name', 'last_name',
'first_name', 'first_name',
'email', 'email',
'pets', 'website',
'dialling_code',
'local_phone_number',
] ]
column_auto_select_related = True
column_default_sort = [('last_name', False), ('first_name', False)] # sort on multiple columns column_default_sort = [('last_name', False), ('first_name', False)] # sort on multiple columns
# custom filter: each filter in the list is a filter operation (equals, not equals, etc) # custom filter: each filter in the list is a filter operation (equals, not equals, etc)
# filters with the same name will appear as operations under the same filter # filters with the same name will appear as operations under the same filter
column_filters = [ column_filters = [
'first_name',
FilterEqual(column=User.last_name, name='Last Name'), FilterEqual(column=User.last_name, name='Last Name'),
FilterLastNameBrown(column=User.last_name, name='Last Name', FilterLastNameBrown(column=User.last_name, name='Last Name',
options=(('1', 'Yes'), ('0', 'No'))) options=(('1', 'Yes'), ('0', 'No'))),
'email',
'ip_address',
'currency',
'timezone',
] ]
inline_models = [(UserInfo, inline_form_options), ] column_formatters = {'phone_number': phone_number_formatter}
# setup create & edit forms so that only 'available' pets can be selected # setup create & edit forms so that only posts created by this user can be selected as 'featured'
def create_form(self): def create_form(self):
return self._use_filtered_parent( return self._filtered_posts(
super(UserAdmin, self).create_form() super(UserAdmin, self).create_form()
) )
def edit_form(self, obj): def edit_form(self, obj):
return self._use_filtered_parent( return self._filtered_posts(
super(UserAdmin, self).edit_form(obj) super(UserAdmin, self).edit_form(obj)
) )
def _use_filtered_parent(self, form): def _filtered_posts(self, form):
form.pets.query_factory = self._get_parent_list form.featured_post.query_factory = lambda: Post.query.filter(Post.user_id == form._obj.id).all()
return form return form
def _get_parent_list(self):
# only show available pets in the form
return Pet.query.filter_by(available=True).all()
# Customized Post model admin # Customized Post model admin
class PostAdmin(sqla.ModelView): class PostAdmin(sqla.ModelView):
column_list = ['id', 'user', 'title', 'date', 'tags'] column_display_pk = True
column_list = ['id', 'user', 'title', 'date', 'tags', 'background_color', 'created_at',]
column_editable_list = ['background_color', ]
column_default_sort = ('date', True) column_default_sort = ('date', True)
column_sortable_list = [ column_sortable_list = [
'id', 'id',
...@@ -213,6 +281,8 @@ class PostAdmin(sqla.ModelView): ...@@ -213,6 +281,8 @@ class PostAdmin(sqla.ModelView):
'user.last_name': 'last name', 'user.last_name': 'last name',
} }
column_filters = [ column_filters = [
'background_color',
'created_at',
'user', 'user',
'title', 'title',
'date', 'date',
...@@ -224,10 +294,10 @@ class PostAdmin(sqla.ModelView): ...@@ -224,10 +294,10 @@ class PostAdmin(sqla.ModelView):
export_types = ['csv', 'xls'] export_types = ['csv', 'xls']
# Pass arguments to WTForms. In this case, change label for text field to # Pass arguments to WTForms. In this case, change label for text field to
# be 'Big Text' and add required() validator. # be 'Big Text' and add DataRequired() validator.
form_args = dict( form_args = {
text=dict(label='Big Text', validators=[validators.required()]) 'text': dict(label='Big Text', validators=[validators.DataRequired()])
) }
form_ajax_refs = { form_ajax_refs = {
'user': { 'user': {
...@@ -250,14 +320,6 @@ class TreeView(sqla.ModelView): ...@@ -250,14 +320,6 @@ class TreeView(sqla.ModelView):
form_excluded_columns = ['children', ] form_excluded_columns = ['children', ]
class ScreenView(sqla.ModelView):
column_list = ['id', 'width', 'height', 'number_of_pixels'] # not that 'number_of_pixels' is a hybrid property, not a field
column_sortable_list = ['id', 'width', 'height', 'number_of_pixels']
# Flask-admin can automatically detect the relevant filters for hybrid properties.
column_filters = ('number_of_pixels', )
# Create admin # Create admin
admin = admin.Admin(app, name='Example: SQLAlchemy', template_mode='bootstrap3') admin = admin.Admin(app, name='Example: SQLAlchemy', template_mode='bootstrap3')
...@@ -265,14 +327,10 @@ admin = admin.Admin(app, name='Example: SQLAlchemy', template_mode='bootstrap3') ...@@ -265,14 +327,10 @@ admin = admin.Admin(app, name='Example: SQLAlchemy', template_mode='bootstrap3')
admin.add_view(UserAdmin(User, db.session)) admin.add_view(UserAdmin(User, db.session))
admin.add_view(sqla.ModelView(Tag, db.session)) admin.add_view(sqla.ModelView(Tag, db.session))
admin.add_view(PostAdmin(db.session)) admin.add_view(PostAdmin(db.session))
admin.add_view(sqla.ModelView(Pet, db.session, category="Other"))
admin.add_view(sqla.ModelView(UserInfo, db.session, category="Other"))
admin.add_view(TreeView(Tree, db.session, category="Other")) admin.add_view(TreeView(Tree, db.session, category="Other"))
admin.add_view(ScreenView(Screen, db.session, category="Other"))
admin.add_sub_category(name="Links", parent_name="Other") admin.add_sub_category(name="Links", parent_name="Other")
admin.add_link(MenuLink(name='Back Home', url='/', category='Links')) admin.add_link(MenuLink(name='Back Home', url='/', category='Links'))
admin.add_link(MenuLink(name='Google', url='http://www.google.com/', category='Links')) admin.add_link(MenuLink(name='External link', url='http://www.example.com/', category='Links'))
admin.add_link(MenuLink(name='Mozilla', url='http://mozilla.org/', category='Links'))
def build_sample_db(): def build_sample_db():
...@@ -298,13 +356,35 @@ def build_sample_db(): ...@@ -298,13 +356,35 @@ def build_sample_db():
'Ali', 'Mason', 'Mitchell', 'Rose', 'Davis', 'Davies', 'Rodriguez', 'Cox', 'Alexander' 'Ali', 'Mason', 'Mitchell', 'Rose', 'Davis', 'Davies', 'Rodriguez', 'Cox', 'Alexander'
] ]
countries = [
("ZA", "South Africa", 27, "ZAR", "Africa/Johannesburg"),
("BF", "Burkina Faso", 226, "XOF", "Africa/Ouagadougou"),
("US", "United States of America", 1, "USD", "America/New_York"),
("BR", "Brazil", 55, "BRL", "America/Sao_Paulo"),
("TZ", "Tanzania", 255, "TZS", "Africa/Dar_es_Salaam"),
("DE", "Germany", 49, "EUR", "Europe/Berlin"),
("CN", "China", 86, "CNY", "Asia/Shanghai"),
]
user_list = [] user_list = []
for i in range(len(first_names)): for i in range(len(first_names)):
user = User() user = User()
country = random.choice(countries)
user.type = random.choice(AVAILABLE_USER_TYPES)[0]
user.first_name = first_names[i] user.first_name = first_names[i]
user.last_name = last_names[i] user.last_name = last_names[i]
user.email = first_names[i].lower() + "@example.com" user.email = first_names[i].lower() + "@example.com"
user.info.append(UserInfo(key="foo", value="bar"))
user.website = "https://www.example.com"
user.ip_address = "127.0.0.1"
user.coutry = country[1]
user.currency = country[3]
user.timezone = country[4]
user.dialling_code = country[2]
user.local_phone_number = '0' + ''.join(random.choices('123456789', k=9))
user_list.append(user) user_list.append(user)
db.session.add(user) db.session.add(user)
...@@ -361,6 +441,7 @@ def build_sample_db(): ...@@ -361,6 +441,7 @@ def build_sample_db():
post.user = user post.user = user
post.title = entry['title'] post.title = entry['title']
post.text = entry['content'] post.text = entry['content']
post.background_color = random.choice(["#cccccc", "red", "lightblue", "#0f0"])
tmp = int(1000*random.random()) # random number between 0 and 1000: tmp = int(1000*random.random()) # random number between 0 and 1000:
post.date = datetime.datetime.now() - datetime.timedelta(days=tmp) post.date = datetime.datetime.now() - datetime.timedelta(days=tmp)
post.tags = random.sample(tag_list, 2) # select a couple of tags at random post.tags = random.sample(tag_list, 2) # select a couple of tags at random
...@@ -380,15 +461,6 @@ def build_sample_db(): ...@@ -380,15 +461,6 @@ def build_sample_db():
leaf.parent = branch leaf.parent = branch
db.session.add(leaf) db.session.add(leaf)
db.session.add(Pet(name='Dog', available=True))
db.session.add(Pet(name='Fish', available=True))
db.session.add(Pet(name='Cat', available=True))
db.session.add(Pet(name='Parrot', available=True))
db.session.add(Pet(name='Ocelot', available=False))
db.session.add(Screen(width=500, height=2000))
db.session.add(Screen(width=550, height=1900))
db.session.commit() db.session.commit()
return return
......
Flask Flask
Flask-Admin Flask-Admin
Flask-BabelEx
Flask-SQLAlchemy Flask-SQLAlchemy
tablib tablib
sqlalchemy_utils
arrow
colour
...@@ -2,6 +2,7 @@ from flask_admin.babel import lazy_gettext ...@@ -2,6 +2,7 @@ from flask_admin.babel import lazy_gettext
from flask_admin.model import filters from flask_admin.model import filters
from flask_admin.contrib.sqla import tools from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_ from sqlalchemy.sql import not_, or_
import enum
class BaseSQLAFilter(filters.BaseFilter): class BaseSQLAFilter(filters.BaseFilter):
...@@ -339,6 +340,102 @@ class EnumFilterNotInList(FilterNotInList): ...@@ -339,6 +340,102 @@ class EnumFilterNotInList(FilterNotInList):
return values return values
class ChoiceTypeEqualFilter(FilterEqual):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeEqualFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_type = None
# loop through choice 'values' to try and find an exact match
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if choice.name == user_query:
choice_type = choice.value
break
else:
for type, value in column.type.choices:
if value == user_query:
choice_type = type
break
if choice_type:
return query.filter(column == choice_type)
else:
return query.filter(column.in_([]))
class ChoiceTypeNotEqualFilter(FilterNotEqual):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeNotEqualFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_type = None
# loop through choice 'values' to try and find an exact match
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if choice.name == user_query:
choice_type = choice.value
break
else:
for type, value in column.type.choices:
if value == user_query:
choice_type = type
break
if choice_type:
# != can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(column != choice_type, column == None)) # noqa: E711
else:
return query
class ChoiceTypeLikeFilter(FilterLike):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeLikeFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_types = []
if user_query:
# loop through choice 'values' looking for matches
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if user_query.lower() in choice.name.lower():
choice_types.append(choice.value)
else:
for type, value in column.type.choices:
if user_query.lower() in value.lower():
choice_types.append(type)
if choice_types:
return query.filter(column.in_(choice_types))
else:
return query
class ChoiceTypeNotLikeFilter(FilterNotLike):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeNotLikeFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_types = []
if user_query:
# loop through choice 'values' looking for matches
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if user_query.lower() in choice.name.lower():
choice_types.append(choice.value)
else:
for type, value in column.type.choices:
if user_query.lower() in value.lower():
choice_types.append(type)
if choice_types:
# != can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(column.notin_(choice_types), column == None)) # noqa: E711
else:
return query
class UuidFilterEqual(FilterEqual, filters.BaseUuidFilter): class UuidFilterEqual(FilterEqual, filters.BaseUuidFilter):
pass pass
...@@ -359,6 +456,7 @@ class UuidFilterNotInList(filters.BaseUuidListFilter, FilterNotInList): ...@@ -359,6 +456,7 @@ class UuidFilterNotInList(filters.BaseUuidListFilter, FilterNotInList):
class FilterConverter(filters.BaseFilterConverter): class FilterConverter(filters.BaseFilterConverter):
strings = (FilterLike, FilterNotLike, FilterEqual, FilterNotEqual, strings = (FilterLike, FilterNotLike, FilterEqual, FilterNotEqual,
FilterEmpty, FilterInList, FilterNotInList) FilterEmpty, FilterInList, FilterNotInList)
string_key_filters = (FilterEqual, FilterNotEqual, FilterEmpty, FilterInList, FilterNotInList)
int_filters = (IntEqualFilter, IntNotEqualFilter, IntGreaterFilter, int_filters = (IntEqualFilter, IntNotEqualFilter, IntGreaterFilter,
IntSmallerFilter, FilterEmpty, IntInListFilter, IntSmallerFilter, FilterEmpty, IntInListFilter,
IntNotInListFilter) IntNotInListFilter)
...@@ -378,8 +476,11 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -378,8 +476,11 @@ class FilterConverter(filters.BaseFilterConverter):
time_filters = (TimeEqualFilter, TimeNotEqualFilter, TimeGreaterFilter, time_filters = (TimeEqualFilter, TimeNotEqualFilter, TimeGreaterFilter,
TimeSmallerFilter, TimeBetweenFilter, TimeNotBetweenFilter, TimeSmallerFilter, TimeBetweenFilter, TimeNotBetweenFilter,
FilterEmpty) FilterEmpty)
choice_type_filters = (ChoiceTypeEqualFilter, ChoiceTypeNotEqualFilter,
ChoiceTypeLikeFilter, ChoiceTypeNotLikeFilter, FilterEmpty)
uuid_filters = (UuidFilterEqual, UuidFilterNotEqual, FilterEmpty, uuid_filters = (UuidFilterEqual, UuidFilterNotEqual, FilterEmpty,
UuidFilterInList, UuidFilterNotInList) UuidFilterInList, UuidFilterNotInList)
arrow_type_filters = (DateTimeGreaterFilter, DateTimeSmallerFilter, FilterEmpty)
def convert(self, type_name, column, name, **kwargs): def convert(self, type_name, column, name, **kwargs):
filter_name = type_name.lower() filter_name = type_name.lower()
...@@ -391,10 +492,15 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -391,10 +492,15 @@ class FilterConverter(filters.BaseFilterConverter):
@filters.convert('string', 'char', 'unicode', 'varchar', 'tinytext', @filters.convert('string', 'char', 'unicode', 'varchar', 'tinytext',
'text', 'mediumtext', 'longtext', 'unicodetext', 'text', 'mediumtext', 'longtext', 'unicodetext',
'nchar', 'nvarchar', 'ntext', 'citext') 'nchar', 'nvarchar', 'ntext', 'citext', 'emailtype',
'URLType', 'IPAddressType')
def conv_string(self, column, name, **kwargs): def conv_string(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.strings] return [f(column, name, **kwargs) for f in self.strings]
@filters.convert('UUIDType', 'ColorType', 'TimezoneType', 'CurrencyType')
def conv_string_keys(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.string_key_filters]
@filters.convert('boolean', 'tinyint') @filters.convert('boolean', 'tinyint')
def conv_bool(self, column, name, **kwargs): def conv_bool(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.bool_filters] return [f(column, name, **kwargs) for f in self.bool_filters]
...@@ -420,6 +526,14 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -420,6 +526,14 @@ class FilterConverter(filters.BaseFilterConverter):
def conv_time(self, column, name, **kwargs): def conv_time(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.time_filters] return [f(column, name, **kwargs) for f in self.time_filters]
@filters.convert('ChoiceType')
def conv_sqla_utils_choice(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.choice_type_filters]
@filters.convert('ArrowType')
def conv_sqla_utils_arrow(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.arrow_type_filters]
@filters.convert('enum') @filters.convert('enum')
def conv_enum(self, column, name, options=None, **kwargs): def conv_enum(self, column, name, options=None, **kwargs):
if not options: if not options:
......
import warnings import warnings
from enum import Enum from enum import Enum, EnumMeta
from wtforms import fields, validators from wtforms import fields, validators
from sqlalchemy import Boolean, Column from sqlalchemy import Boolean, Column
...@@ -13,7 +13,7 @@ from flask_admin.model.helpers import prettify_name ...@@ -13,7 +13,7 @@ from flask_admin.model.helpers import prettify_name
from flask_admin._backwards import get_property from flask_admin._backwards import get_property
from flask_admin._compat import iteritems, text_type from flask_admin._compat import iteritems, text_type
from .validators import Unique from .validators import Unique, valid_currency, valid_color, TimeZoneValidator
from .fields import (QuerySelectField, QuerySelectMultipleField, from .fields import (QuerySelectField, QuerySelectMultipleField,
InlineModelFormList, InlineHstoreList, HstoreForm) InlineModelFormList, InlineHstoreList, HstoreForm)
from flask_admin.model.fields import InlineFormField from flask_admin.model.fields import InlineFormField
...@@ -243,9 +243,8 @@ class AdminModelConverter(ModelConverterBase): ...@@ -243,9 +243,8 @@ class AdminModelConverter(ModelConverterBase):
if override: if override:
return override(**kwargs) return override(**kwargs)
# Check choices # Check if a list of 'form_choices' are specified
form_choices = getattr(self.view, 'form_choices', None) form_choices = getattr(self.view, 'form_choices', None)
if mapper.class_ == self.view.model and form_choices: if mapper.class_ == self.view.model and form_choices:
choices = form_choices.get(prop.key) choices = form_choices.get(prop.key)
if choices: if choices:
...@@ -263,7 +262,6 @@ class AdminModelConverter(ModelConverterBase): ...@@ -263,7 +262,6 @@ class AdminModelConverter(ModelConverterBase):
return converter(model=model, mapper=mapper, prop=prop, return converter(model=model, mapper=mapper, prop=prop,
column=column, field_args=kwargs) column=column, field_args=kwargs)
return None return None
@classmethod @classmethod
...@@ -273,27 +271,52 @@ class AdminModelConverter(ModelConverterBase): ...@@ -273,27 +271,52 @@ class AdminModelConverter(ModelConverterBase):
@converts('String') # includes VARCHAR, CHAR, and Unicode @converts('String') # includes VARCHAR, CHAR, and Unicode
def conv_String(self, column, field_args, **extra): def conv_String(self, column, field_args, **extra):
if hasattr(column.type, 'enums'): if column.nullable:
accepted_values = list(column.type.enums) filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters
field_args['choices'] = [(f, f) for f in column.type.enums] self._string_common(column=column, field_args=field_args, **extra)
return fields.StringField(**field_args)
if column.nullable: @converts('sqlalchemy.sql.sqltypes.Enum')
field_args['allow_blank'] = column.nullable def convert_enum(self, column, field_args, **extra):
accepted_values.append(None) available_choices = [(f, f) for f in column.type.enums]
accepted_values = [key for key, val in available_choices]
field_args['validators'].append(validators.AnyOf(accepted_values)) if column.nullable:
field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else text_type(v) field_args['allow_blank'] = column.nullable
accepted_values.append(None)
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters
return form.Select2Field(**field_args) field_args['choices'] = available_choices
field_args['validators'].append(validators.AnyOf(accepted_values))
field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else text_type(v)
return form.Select2Field(**field_args)
@converts('sqlalchemy_utils.types.choice.ChoiceType')
def convert_choice_type(self, column, field_args, **extra):
available_choices = []
# choices can either be specified as an enum, or as a list of tuples
if isinstance(column.type.choices, EnumMeta):
available_choices = [(f.value, f.name) for f in column.type.choices]
else:
available_choices = column.type.choices
accepted_values = [key for key, val in available_choices]
if column.nullable: if column.nullable:
field_args['allow_blank'] = column.nullable
accepted_values.append(None)
filters = field_args.get('filters', []) filters = field_args.get('filters', [])
filters.append(lambda x: x or None) filters.append(lambda x: x or None)
field_args['filters'] = filters field_args['filters'] = filters
self._string_common(column=column, field_args=field_args, **extra) field_args['choices'] = available_choices
return fields.StringField(**field_args) field_args['validators'].append(validators.AnyOf(accepted_values))
field_args['coerce'] = choice_type_coerce_factory(column.type)
return form.Select2Field(**field_args)
@converts('Text', 'LargeBinary', 'Binary', 'CIText') # includes UnicodeText @converts('Text', 'LargeBinary', 'Binary', 'CIText') # includes UnicodeText
def conv_Text(self, field_args, **extra): def conv_Text(self, field_args, **extra):
...@@ -317,6 +340,44 @@ class AdminModelConverter(ModelConverterBase): ...@@ -317,6 +340,44 @@ class AdminModelConverter(ModelConverterBase):
def convert_time(self, field_args, **extra): def convert_time(self, field_args, **extra):
return form.TimeField(**field_args) return form.TimeField(**field_args)
@converts('sqlalchemy_utils.types.arrow.ArrowType')
def convert_arrow_time(self, field_args, **extra):
return form.DateTimeField(**field_args)
@converts('sqlalchemy_utils.types.email.EmailType')
def convert_email(self, field_args, **extra):
field_args['validators'].append(validators.Email())
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.url.URLType')
def convert_url(self, field_args, **extra):
field_args['validators'].append(validators.URL())
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.ip_address.IPAddressType')
def convert_ip_address(self, field_args, **extra):
field_args['validators'].append(validators.IPAddress())
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.color.ColorType')
def convert_color(self, field_args, **extra):
field_args['validators'].append(valid_color)
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.currency.CurrencyType')
def convert_currency(self, field_args, **extra):
field_args['validators'].append(valid_currency)
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.timezone.TimezoneType')
def convert_timezone(self, column, field_args, **extra):
field_args['validators'].append(TimeZoneValidator(coerce_function=column.type._coerce))
return fields.StringField(**field_args)
@converts('Integer') # includes BigInteger and SmallInteger @converts('Integer') # includes BigInteger and SmallInteger
def handle_integer_types(self, column, field_args, **extra): def handle_integer_types(self, column, field_args, **extra):
unsigned = getattr(column.type, 'unsigned', False) unsigned = getattr(column.type, 'unsigned', False)
...@@ -342,10 +403,12 @@ class AdminModelConverter(ModelConverterBase): ...@@ -342,10 +403,12 @@ class AdminModelConverter(ModelConverterBase):
field_args['validators'].append(validators.MacAddress()) field_args['validators'].append(validators.MacAddress())
return fields.StringField(**field_args) return fields.StringField(**field_args)
@converts('sqlalchemy.dialects.postgresql.base.UUID') @converts('sqlalchemy.dialects.postgresql.base.UUID',
'sqlalchemy_utils.types.uuid.UUIDType')
def conv_PGUuid(self, field_args, **extra): def conv_PGUuid(self, field_args, **extra):
field_args.setdefault('label', u'UUID') field_args.setdefault('label', u'UUID')
field_args['validators'].append(validators.UUID()) field_args['validators'].append(validators.UUID())
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args) return fields.StringField(**field_args)
@converts('sqlalchemy.dialects.postgresql.base.ARRAY', @converts('sqlalchemy.dialects.postgresql.base.ARRAY',
...@@ -363,6 +426,41 @@ class AdminModelConverter(ModelConverterBase): ...@@ -363,6 +426,41 @@ class AdminModelConverter(ModelConverterBase):
return form.JSONField(**field_args) return form.JSONField(**field_args)
def avoid_empty_strings(value):
"""
Return None if the incoming value is an empty string or whitespace.
"""
if value:
try:
value = value.strip()
except AttributeError:
# values are not always strings
pass
return value if value else None
def choice_type_coerce_factory(type_):
"""
Return a function to coerce a ChoiceType column, for use by Select2Field.
:param type_: ChoiceType object
"""
from sqlalchemy_utils import Choice
choices = type_.choices
if isinstance(choices, type) and issubclass(choices, Enum):
key, choice_cls = 'value', choices
else:
key, choice_cls = 'code', Choice
def choice_coerce(value):
if value is None:
return None
if isinstance(value, choice_cls):
return getattr(value, key)
return type_.python_type(value)
return choice_coerce
def _resolve_prop(prop): def _resolve_prop(prop):
""" """
Resolve proxied property Resolve proxied property
......
from sqlalchemy.ext.associationproxy import _AssociationList from sqlalchemy.ext.associationproxy import _AssociationList
from flask_admin.model.typefmt import BASE_FORMATTERS, list_formatter from flask_admin.model.typefmt import BASE_FORMATTERS, EXPORT_FORMATTERS, \
list_formatter
from sqlalchemy.orm.collections import InstrumentedList from sqlalchemy.orm.collections import InstrumentedList
def choice_formatter(view, choice):
"""
Return label of selected choice
see https://sqlalchemy-utils.readthedocs.io/
:param choice:
sqlalchemy_utils Choice, which has a `code` and a `value`
"""
return choice.value
def arrow_formatter(view, arrow_time):
"""
Return human-friendly string of the time relative to now.
see https://arrow.readthedocs.io/
:param arrow_time:
Arrow object for handling datetimes
"""
return arrow_time.humanize()
def arrow_export_formatter(view, arrow_time):
"""
Return string representation of Arrow object
see https://arrow.readthedocs.io/
:param arrow_time:
Arrow object for handling datetimes
"""
return arrow_time.format()
DEFAULT_FORMATTERS = BASE_FORMATTERS.copy() DEFAULT_FORMATTERS = BASE_FORMATTERS.copy()
EXPORT_FORMATTERS = EXPORT_FORMATTERS.copy()
DEFAULT_FORMATTERS.update({ DEFAULT_FORMATTERS.update({
InstrumentedList: list_formatter, InstrumentedList: list_formatter,
_AssociationList: list_formatter _AssociationList: list_formatter,
}) })
try:
from sqlalchemy_utils import Choice
DEFAULT_FORMATTERS[Choice] = choice_formatter
except ImportError:
pass
try:
from arrow import Arrow
DEFAULT_FORMATTERS[Arrow] = arrow_formatter
EXPORT_FORMATTERS[Arrow] = arrow_export_formatter
except ImportError:
pass
...@@ -66,3 +66,34 @@ class ItemsRequired(InputRequired): ...@@ -66,3 +66,34 @@ class ItemsRequired(InputRequired):
message = self.message message = self.message
raise ValidationError(message) raise ValidationError(message)
def valid_currency(form, field):
from sqlalchemy_utils import Currency
try:
Currency(field.data)
except (TypeError, ValueError):
raise ValidationError(field.gettext(u'Not a valid ISO currency code (e.g. USD, EUR, CNY).'))
def valid_color(form, field):
from colour import Color
try:
Color(field.data)
except (ValueError):
raise ValidationError(field.gettext(u'Not a valid color (e.g. "red", "#f00", "#ff0000").'))
class TimeZoneValidator(object):
"""
Tries to coerce a TimZone object from input data
"""
def __init__(self, coerce_function):
self.coerce_function = coerce_function
def __call__(self, form, field):
try:
self.coerce_function(str(field.data))
except Exception:
msg = u'Not a valid timezone (e.g. "America/New_York", "Africa/Johannesburg", "Asia/Singapore").'
raise ValidationError(field.gettext(msg))
...@@ -3,16 +3,20 @@ from nose.tools import eq_, ok_, raises, assert_true ...@@ -3,16 +3,20 @@ from nose.tools import eq_, ok_, raises, assert_true
from wtforms import fields, validators from wtforms import fields, validators
from flask_admin import form from flask_admin import form
from flask_admin.form.fields import Select2Field, DateTimeField
from flask_admin._compat import as_unicode from flask_admin._compat import as_unicode
from flask_admin._compat import iteritems from flask_admin._compat import iteritems
from flask_admin.contrib.sqla import ModelView, filters, tools from flask_admin.contrib.sqla import ModelView, filters, tools
from flask_babelex import Babel from flask_babelex import Babel
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy_utils import EmailType, ChoiceType, UUIDType, URLType, CurrencyType, ColorType, ArrowType, IPAddressType
from . import setup from . import setup
from datetime import datetime, time, date from datetime import datetime, time, date
import uuid
import enum
import arrow
class CustomModelView(ModelView): class CustomModelView(ModelView):
...@@ -24,13 +28,20 @@ class CustomModelView(ModelView): ...@@ -24,13 +28,20 @@ class CustomModelView(ModelView):
super(CustomModelView, self).__init__(model, session, name, category, super(CustomModelView, self).__init__(model, session, name, category,
endpoint, url) endpoint, url)
form_choices = {
'choice_field': [
('choice-1', 'One'),
('choice-2', 'Two')
]
}
def create_models(db): def create_models(db):
class Model1(db.Model): class Model1(db.Model):
def __init__(self, test1=None, test2=None, test3=None, test4=None, def __init__(self, test1=None, test2=None, test3=None, test4=None,
bool_field=False, date_field=None, time_field=None, bool_field=False, date_field=None, time_field=None,
datetime_field=None, enum_field=None): datetime_field=None, email_field=None,
choice_field=None, enum_field=None):
self.test1 = test1 self.test1 = test1
self.test2 = test2 self.test2 = test2
self.test3 = test3 self.test3 = test3
...@@ -39,19 +50,37 @@ def create_models(db): ...@@ -39,19 +50,37 @@ def create_models(db):
self.date_field = date_field self.date_field = date_field
self.time_field = time_field self.time_field = time_field
self.datetime_field = datetime_field self.datetime_field = datetime_field
self.email_field = email_field
self.choice_field = choice_field
self.enum_field = enum_field self.enum_field = enum_field
class EnumChoices(enum.Enum):
first = 1
second = 2
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
test1 = db.Column(db.String(20)) test1 = db.Column(db.String(20))
test2 = db.Column(db.Unicode(20)) test2 = db.Column(db.Unicode(20))
test3 = db.Column(db.Text) test3 = db.Column(db.Text)
test4 = db.Column(db.UnicodeText) test4 = db.Column(db.UnicodeText)
bool_field = db.Column(db.Boolean) bool_field = db.Column(db.Boolean)
enum_field = db.Column(db.Enum('model1_v1', 'model1_v2'), nullable=True)
date_field = db.Column(db.Date) date_field = db.Column(db.Date)
time_field = db.Column(db.Time) time_field = db.Column(db.Time)
datetime_field = db.Column(db.DateTime) datetime_field = db.Column(db.DateTime)
email_field = db.Column(EmailType)
enum_field = db.Column(db.Enum('model1_v1', 'model1_v2'), nullable=True)
choice_field = db.Column(db.String, nullable=True)
sqla_utils_choice = db.Column(ChoiceType([
('choice-1', u'First choice'),
('choice-2', u'Second choice')
]))
sqla_utils_enum = db.Column(ChoiceType(EnumChoices, impl=db.Integer()))
sqla_utils_arrow = db.Column(ArrowType, default=arrow.utcnow())
sqla_utils_uuid = db.Column(UUIDType(binary=False), default=uuid.uuid4)
sqla_utils_url = db.Column(URLType)
sqla_utils_ip_address = db.Column(IPAddressType)
sqla_utils_currency = db.Column(CurrencyType)
sqla_utils_color = db.Column(ColorType)
def __unicode__(self): def __unicode__(self):
return self.test1 return self.test1
...@@ -95,7 +124,7 @@ def fill_db(db, Model1, Model2): ...@@ -95,7 +124,7 @@ def fill_db(db, Model1, Model2):
model1_obj1 = Model1('test1_val_1', 'test2_val_1', bool_field=True) model1_obj1 = Model1('test1_val_1', 'test2_val_1', bool_field=True)
model1_obj2 = Model1('test1_val_2', 'test2_val_2', bool_field=False) model1_obj2 = Model1('test1_val_2', 'test2_val_2', bool_field=False)
model1_obj3 = Model1('test1_val_3', 'test2_val_3') model1_obj3 = Model1('test1_val_3', 'test2_val_3')
model1_obj4 = Model1('test1_val_4', 'test2_val_4') model1_obj4 = Model1('test1_val_4', 'test2_val_4', email_field="test@test.com", choice_field="choice-1")
model2_obj1 = Model2('test2_val_1', model1=model1_obj1, float_field=None) model2_obj1 = Model2('test2_val_1', model1=model1_obj1, float_field=None)
model2_obj2 = Model2('test2_val_2', model1=model1_obj2, float_field=None) model2_obj2 = Model2('test2_val_2', model1=model1_obj2, float_field=None)
...@@ -130,6 +159,7 @@ def test_model(): ...@@ -130,6 +159,7 @@ def test_model():
Model1, Model2 = create_models(db) Model1, Model2 = create_models(db)
view = CustomModelView(Model1, db.session) view = CustomModelView(Model1, db.session)
admin.add_view(view) admin.add_view(view)
eq_(view.model, Model1) eq_(view.model, Model1)
...@@ -153,32 +183,76 @@ def test_model(): ...@@ -153,32 +183,76 @@ def test_model():
eq_(view._create_form_class.test2.field_class, fields.StringField) eq_(view._create_form_class.test2.field_class, fields.StringField)
eq_(view._create_form_class.test3.field_class, fields.TextAreaField) eq_(view._create_form_class.test3.field_class, fields.TextAreaField)
eq_(view._create_form_class.test4.field_class, fields.TextAreaField) eq_(view._create_form_class.test4.field_class, fields.TextAreaField)
eq_(view._create_form_class.email_field.field_class, fields.StringField)
eq_(view._create_form_class.choice_field.field_class, Select2Field)
eq_(view._create_form_class.enum_field.field_class, Select2Field)
eq_(view._create_form_class.sqla_utils_choice.field_class, Select2Field)
eq_(view._create_form_class.sqla_utils_enum.field_class, Select2Field)
eq_(view._create_form_class.sqla_utils_arrow.field_class, DateTimeField)
eq_(view._create_form_class.sqla_utils_uuid.field_class, fields.StringField)
eq_(view._create_form_class.sqla_utils_url.field_class, fields.StringField)
eq_(view._create_form_class.sqla_utils_ip_address.field_class, fields.StringField)
eq_(view._create_form_class.sqla_utils_currency.field_class, fields.StringField)
eq_(view._create_form_class.sqla_utils_color.field_class, fields.StringField)
# Make some test clients # Make some test clients
client = app.test_client() client = app.test_client()
# check that we can retrieve a list view
rv = client.get('/admin/model1/') rv = client.get('/admin/model1/')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
# check that we can retrieve a 'create' view
rv = client.get('/admin/model1/new/') rv = client.get('/admin/model1/new/')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
rv = client.post('/admin/model1/new/', # create a new record
data=dict(test1='test1large', uuid_obj = uuid.uuid4()
test2='test2', rv = client.post(
time_field=time(0, 0, 0))) '/admin/model1/new/',
data=dict(
test1='test1large',
test2='test2',
time_field=time(0, 0, 0),
email_field="Test@TEST.com",
choice_field="choice-1",
enum_field='model1_v1',
sqla_utils_choice="choice-1",
sqla_utils_enum=1,
sqla_utils_arrow='2018-10-27 14:17:00',
sqla_utils_uuid=str(uuid_obj),
sqla_utils_url="http://www.example.com",
sqla_utils_ip_address='127.0.0.1',
sqla_utils_currency='USD',
sqla_utils_color='#f0f0f0',
)
)
eq_(rv.status_code, 302) eq_(rv.status_code, 302)
# check that the new record was persisted
model = db.session.query(Model1).first() model = db.session.query(Model1).first()
eq_(model.test1, u'test1large') eq_(model.test1, u'test1large')
eq_(model.test2, u'test2') eq_(model.test2, u'test2')
eq_(model.test3, u'') eq_(model.test3, u'')
eq_(model.test4, u'') eq_(model.test4, u'')
eq_(model.email_field, u'test@test.com')
eq_(model.choice_field, u'choice-1')
eq_(model.enum_field, u'model1_v1')
eq_(model.sqla_utils_choice, u'choice-1')
eq_(model.sqla_utils_enum.value, 1)
eq_(model.sqla_utils_arrow, arrow.get('2018-10-27 14:17:00'))
eq_(model.sqla_utils_uuid, uuid_obj)
eq_(model.sqla_utils_url, "http://www.example.com")
eq_(str(model.sqla_utils_ip_address), '127.0.0.1')
eq_(str(model.sqla_utils_currency), 'USD')
eq_(model.sqla_utils_color.hex, '#f0f0f0')
# check that the new record shows up on the list view
rv = client.get('/admin/model1/') rv = client.get('/admin/model1/')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
ok_(u'test1large' in rv.data.decode('utf-8')) ok_(u'test1large' in rv.data.decode('utf-8'))
# check that we can retrieve an edit view
url = '/admin/model1/edit/?id=%s' % model.id url = '/admin/model1/edit/?id=%s' % model.id
rv = client.get(url) rv = client.get(url)
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
...@@ -186,16 +260,44 @@ def test_model(): ...@@ -186,16 +260,44 @@ def test_model():
# verify that midnight does not show as blank # verify that midnight does not show as blank
ok_(u'00:00:00' in rv.data.decode('utf-8')) ok_(u'00:00:00' in rv.data.decode('utf-8'))
# edit the record
new_uuid_obj = uuid.uuid4()
rv = client.post(url, rv = client.post(url,
data=dict(test1='test1small', test2='test2large')) data=dict(test1='test1small',
test2='test2large',
email_field='Test2@TEST.com',
choice_field='__None',
enum_field='__None',
sqla_utils_choice='__None',
sqla_utils_enum='__None',
sqla_utils_arrow='',
sqla_utils_uuid=str(new_uuid_obj),
sqla_utils_url='',
sqla_utils_ip_address='',
sqla_utils_currency='',
sqla_utils_color='',
))
eq_(rv.status_code, 302) eq_(rv.status_code, 302)
# check that the changes were persisted
model = db.session.query(Model1).first() model = db.session.query(Model1).first()
eq_(model.test1, 'test1small') eq_(model.test1, 'test1small')
eq_(model.test2, 'test2large') eq_(model.test2, 'test2large')
eq_(model.test3, '') eq_(model.test3, '')
eq_(model.test4, '') eq_(model.test4, '')
eq_(model.email_field, u'test2@test.com')
eq_(model.choice_field, None)
eq_(model.enum_field, None)
eq_(model.sqla_utils_choice, None)
eq_(model.sqla_utils_enum, None)
eq_(model.sqla_utils_arrow, None)
eq_(model.sqla_utils_uuid, new_uuid_obj)
eq_(model.sqla_utils_url, None)
eq_(model.sqla_utils_ip_address, None)
eq_(model.sqla_utils_currency, None)
eq_(model.sqla_utils_color, None)
# check that the model can be deleted
url = '/admin/model1/delete/?id=%s' % model.id url = '/admin/model1/delete/?id=%s' % model.id
rv = client.post(url) rv = client.post(url)
eq_(rv.status_code, 302) eq_(rv.status_code, 302)
...@@ -279,13 +381,16 @@ def test_exclude_columns(): ...@@ -279,13 +381,16 @@ def test_exclude_columns():
view = CustomModelView( view = CustomModelView(
Model1, db.session, Model1, db.session,
column_exclude_list=['test2', 'test4', 'enum_field', 'date_field', 'time_field', 'datetime_field'] column_exclude_list=['test2', 'test4', 'enum_field', 'date_field', 'time_field', 'datetime_field',
'sqla_utils_choice', 'sqla_utils_enum', 'sqla_utils_arrow', 'sqla_utils_uuid',
'sqla_utils_url', 'sqla_utils_ip_address', 'sqla_utils_currency', 'sqla_utils_color']
) )
admin.add_view(view) admin.add_view(view)
eq_( eq_(
view._list_columns, view._list_columns,
[('test1', 'Test1'), ('test3', 'Test3'), ('bool_field', 'Bool Field')] [('test1', 'Test1'), ('test3', 'Test3'), ('bool_field', 'Bool Field'),
('email_field', 'Email Field'), ('choice_field', 'Choice Field')]
) )
client = app.test_client() client = app.test_client()
...@@ -640,13 +745,100 @@ def test_column_filters(): ...@@ -640,13 +745,100 @@ def test_column_filters():
) )
eq_( eq_(
[(f['index'], f['operation']) for f in view._filter_groups[u'Model1 / Enum Field']], [(f['index'], f['operation']) for f in view._filter_groups[u'Model1 / Date Field']],
[ [
(30, u'equals'), (30, u'equals'),
(31, u'not equal'), (31, u'not equal'),
(32, u'empty'), (32, u'greater than'),
(33, u'in list'), (33, u'smaller than'),
(34, u'not in list'), (34, u'between'),
(35, u'not between'),
(36, u'empty'),
]
)
eq_(
[(f['index'], f['operation']) for f in view._filter_groups[u'Model1 / Time Field']],
[
(37, u'equals'),
(38, u'not equal'),
(39, u'greater than'),
(40, u'smaller than'),
(41, u'between'),
(42, u'not between'),
(43, u'empty'),
]
)
eq_(
[(f['index'], f['operation']) for f in view._filter_groups[u'Model1 / Datetime Field']],
[
(44, u'equals'),
(45, u'not equal'),
(46, u'greater than'),
(47, u'smaller than'),
(48, u'between'),
(49, u'not between'),
(50, u'empty'),
]
)
eq_(
[(f['index'], f['operation']) for f in view._filter_groups[u'Model1 / Email Field']],
[
(51, u'contains'),
(52, u'not contains'),
(53, u'equals'),
(54, u'not equal'),
(55, u'empty'),
(56, u'in list'),
(57, u'not in list'),
]
)
eq_(
[(f['index'], f['operation']) for f in view._filter_groups[u'Model1 / Enum Field']],
[
(58, u'equals'),
(59, u'not equal'),
(60, u'empty'),
(61, u'in list'),
(62, u'not in list'),
]
)
eq_(
[(f['index'], f['operation']) for f in view._filter_groups[u'Model1 / Choice Field']],
[
(63, u'contains'),
(64, u'not contains'),
(65, u'equals'),
(66, u'not equal'),
(67, u'empty'),
(68, u'in list'),
(69, u'not in list'),
]
)
eq_(
[(f['index'], f['operation']) for f in view._filter_groups[u'Model1 / Sqla Utils Choice']],
[
(70, u'equals'),
(71, u'not equal'),
(72, u'contains'),
(73, u'not contains'),
(74, u'empty'),
]
)
eq_(
[(f['index'], f['operation']) for f in view._filter_groups[u'Model1 / Sqla Utils Enum']],
[
(75, u'equals'),
(76, u'not equal'),
(77, u'contains'),
(78, u'not contains'),
(79, u'empty'),
] ]
) )
...@@ -1502,9 +1694,20 @@ def test_form_columns(): ...@@ -1502,9 +1694,20 @@ def test_form_columns():
excluded_column = db.Column(db.String) excluded_column = db.Column(db.String)
class ChildModel(db.Model): class ChildModel(db.Model):
class EnumChoices(enum.Enum):
first = 1
second = 2
id = db.Column(db.String, primary_key=True) id = db.Column(db.String, primary_key=True)
model_id = db.Column(db.Integer, db.ForeignKey(Model.id)) model_id = db.Column(db.Integer, db.ForeignKey(Model.id))
model = db.relationship(Model, backref='backref') model = db.relationship(Model, backref='backref')
enum_field = db.Column(db.Enum('model1_v1', 'model1_v2'), nullable=True)
choice_field = db.Column(db.String, nullable=True)
sqla_utils_choice = db.Column(ChoiceType([
('choice-1', u'First choice'),
('choice-2', u'Second choice')
]))
sqla_utils_enum = db.Column(ChoiceType(EnumChoices, impl=db.Integer()))
db.create_all() db.create_all()
...@@ -1521,11 +1724,21 @@ def test_form_columns(): ...@@ -1521,11 +1724,21 @@ def test_form_columns():
ok_('int_field' in form1._fields) ok_('int_field' in form1._fields)
ok_('text_field' in form1._fields) ok_('text_field' in form1._fields)
ok_('datetime_field' not in form1._fields) ok_('datetime_field' not in form1._fields)
ok_('excluded_column' not in form2._fields) ok_('excluded_column' not in form2._fields)
# check that relation shows up as a query select
ok_(type(form3.model).__name__ == 'QuerySelectField') ok_(type(form3.model).__name__ == 'QuerySelectField')
# check that select field is rendered if form_choices were specified
ok_(type(form3.choice_field).__name__ == 'Select2Field')
# check that select field is rendered for enum fields
ok_(type(form3.enum_field).__name__ == 'Select2Field')
# check that sqlalchemy_utils field types are handled appropriately
ok_(type(form3.sqla_utils_choice).__name__ == 'Select2Field')
ok_(type(form3.sqla_utils_enum).__name__ == 'Select2Field')
# test form_columns with model objects # test form_columns with model objects
view4 = CustomModelView(Model, db.session, endpoint='view1', view4 = CustomModelView(Model, db.session, endpoint='view1',
form_columns=[Model.int_field]) form_columns=[Model.int_field])
......
...@@ -16,4 +16,7 @@ nose ...@@ -16,4 +16,7 @@ nose
coveralls coveralls
pylint pylint
sqlalchemy-citext sqlalchemy-citext
sqlalchemy_utils
azure-storage-blob azure-storage-blob
arrow<0.14.0
colour
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment