Unverified Commit 3e558465 authored by Petrus Janse van Rensburg's avatar Petrus Janse van Rensburg Committed by GitHub

Merge pull request #1742 from flask-admin/sqlalchemy-utils-types

Sqlalchemy utils types
parents d29796b6 ce6bdc16
...@@ -5,6 +5,7 @@ Next release ...@@ -5,6 +5,7 @@ Next release
----- -----
* Fix display of inline x-editable boolean fields on list view * Fix display of inline x-editable boolean fields on list view
* Add support for several SQLAlchemy-Utils data types
1.5.3 1.5.3
----- -----
......
This diff is collapsed.
Flask Flask
Flask-Admin Flask-Admin
Flask-BabelEx
Flask-SQLAlchemy Flask-SQLAlchemy
tablib tablib
sqlalchemy_utils
arrow
colour
...@@ -2,6 +2,7 @@ from flask_admin.babel import lazy_gettext ...@@ -2,6 +2,7 @@ from flask_admin.babel import lazy_gettext
from flask_admin.model import filters from flask_admin.model import filters
from flask_admin.contrib.sqla import tools from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_ from sqlalchemy.sql import not_, or_
import enum
class BaseSQLAFilter(filters.BaseFilter): class BaseSQLAFilter(filters.BaseFilter):
...@@ -339,6 +340,102 @@ class EnumFilterNotInList(FilterNotInList): ...@@ -339,6 +340,102 @@ class EnumFilterNotInList(FilterNotInList):
return values return values
class ChoiceTypeEqualFilter(FilterEqual):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeEqualFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_type = None
# loop through choice 'values' to try and find an exact match
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if choice.name == user_query:
choice_type = choice.value
break
else:
for type, value in column.type.choices:
if value == user_query:
choice_type = type
break
if choice_type:
return query.filter(column == choice_type)
else:
return query.filter(column.in_([]))
class ChoiceTypeNotEqualFilter(FilterNotEqual):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeNotEqualFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_type = None
# loop through choice 'values' to try and find an exact match
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if choice.name == user_query:
choice_type = choice.value
break
else:
for type, value in column.type.choices:
if value == user_query:
choice_type = type
break
if choice_type:
# != can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(column != choice_type, column == None)) # noqa: E711
else:
return query
class ChoiceTypeLikeFilter(FilterLike):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeLikeFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_types = []
if user_query:
# loop through choice 'values' looking for matches
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if user_query.lower() in choice.name.lower():
choice_types.append(choice.value)
else:
for type, value in column.type.choices:
if user_query.lower() in value.lower():
choice_types.append(type)
if choice_types:
return query.filter(column.in_(choice_types))
else:
return query
class ChoiceTypeNotLikeFilter(FilterNotLike):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeNotLikeFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_types = []
if user_query:
# loop through choice 'values' looking for matches
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if user_query.lower() in choice.name.lower():
choice_types.append(choice.value)
else:
for type, value in column.type.choices:
if user_query.lower() in value.lower():
choice_types.append(type)
if choice_types:
# != can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(column.notin_(choice_types), column == None)) # noqa: E711
else:
return query
class UuidFilterEqual(FilterEqual, filters.BaseUuidFilter): class UuidFilterEqual(FilterEqual, filters.BaseUuidFilter):
pass pass
...@@ -359,6 +456,7 @@ class UuidFilterNotInList(filters.BaseUuidListFilter, FilterNotInList): ...@@ -359,6 +456,7 @@ class UuidFilterNotInList(filters.BaseUuidListFilter, FilterNotInList):
class FilterConverter(filters.BaseFilterConverter): class FilterConverter(filters.BaseFilterConverter):
strings = (FilterLike, FilterNotLike, FilterEqual, FilterNotEqual, strings = (FilterLike, FilterNotLike, FilterEqual, FilterNotEqual,
FilterEmpty, FilterInList, FilterNotInList) FilterEmpty, FilterInList, FilterNotInList)
string_key_filters = (FilterEqual, FilterNotEqual, FilterEmpty, FilterInList, FilterNotInList)
int_filters = (IntEqualFilter, IntNotEqualFilter, IntGreaterFilter, int_filters = (IntEqualFilter, IntNotEqualFilter, IntGreaterFilter,
IntSmallerFilter, FilterEmpty, IntInListFilter, IntSmallerFilter, FilterEmpty, IntInListFilter,
IntNotInListFilter) IntNotInListFilter)
...@@ -378,8 +476,11 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -378,8 +476,11 @@ class FilterConverter(filters.BaseFilterConverter):
time_filters = (TimeEqualFilter, TimeNotEqualFilter, TimeGreaterFilter, time_filters = (TimeEqualFilter, TimeNotEqualFilter, TimeGreaterFilter,
TimeSmallerFilter, TimeBetweenFilter, TimeNotBetweenFilter, TimeSmallerFilter, TimeBetweenFilter, TimeNotBetweenFilter,
FilterEmpty) FilterEmpty)
choice_type_filters = (ChoiceTypeEqualFilter, ChoiceTypeNotEqualFilter,
ChoiceTypeLikeFilter, ChoiceTypeNotLikeFilter, FilterEmpty)
uuid_filters = (UuidFilterEqual, UuidFilterNotEqual, FilterEmpty, uuid_filters = (UuidFilterEqual, UuidFilterNotEqual, FilterEmpty,
UuidFilterInList, UuidFilterNotInList) UuidFilterInList, UuidFilterNotInList)
arrow_type_filters = (DateTimeGreaterFilter, DateTimeSmallerFilter, FilterEmpty)
def convert(self, type_name, column, name, **kwargs): def convert(self, type_name, column, name, **kwargs):
filter_name = type_name.lower() filter_name = type_name.lower()
...@@ -391,10 +492,15 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -391,10 +492,15 @@ class FilterConverter(filters.BaseFilterConverter):
@filters.convert('string', 'char', 'unicode', 'varchar', 'tinytext', @filters.convert('string', 'char', 'unicode', 'varchar', 'tinytext',
'text', 'mediumtext', 'longtext', 'unicodetext', 'text', 'mediumtext', 'longtext', 'unicodetext',
'nchar', 'nvarchar', 'ntext', 'citext') 'nchar', 'nvarchar', 'ntext', 'citext', 'emailtype',
'URLType', 'IPAddressType')
def conv_string(self, column, name, **kwargs): def conv_string(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.strings] return [f(column, name, **kwargs) for f in self.strings]
@filters.convert('UUIDType', 'ColorType', 'TimezoneType', 'CurrencyType')
def conv_string_keys(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.string_key_filters]
@filters.convert('boolean', 'tinyint') @filters.convert('boolean', 'tinyint')
def conv_bool(self, column, name, **kwargs): def conv_bool(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.bool_filters] return [f(column, name, **kwargs) for f in self.bool_filters]
...@@ -420,6 +526,14 @@ class FilterConverter(filters.BaseFilterConverter): ...@@ -420,6 +526,14 @@ class FilterConverter(filters.BaseFilterConverter):
def conv_time(self, column, name, **kwargs): def conv_time(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.time_filters] return [f(column, name, **kwargs) for f in self.time_filters]
@filters.convert('ChoiceType')
def conv_sqla_utils_choice(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.choice_type_filters]
@filters.convert('ArrowType')
def conv_sqla_utils_arrow(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.arrow_type_filters]
@filters.convert('enum') @filters.convert('enum')
def conv_enum(self, column, name, options=None, **kwargs): def conv_enum(self, column, name, options=None, **kwargs):
if not options: if not options:
......
import warnings import warnings
from enum import Enum from enum import Enum, EnumMeta
from wtforms import fields, validators from wtforms import fields, validators
from sqlalchemy import Boolean, Column from sqlalchemy import Boolean, Column
...@@ -13,7 +13,7 @@ from flask_admin.model.helpers import prettify_name ...@@ -13,7 +13,7 @@ from flask_admin.model.helpers import prettify_name
from flask_admin._backwards import get_property from flask_admin._backwards import get_property
from flask_admin._compat import iteritems, text_type from flask_admin._compat import iteritems, text_type
from .validators import Unique from .validators import Unique, valid_currency, valid_color, TimeZoneValidator
from .fields import (QuerySelectField, QuerySelectMultipleField, from .fields import (QuerySelectField, QuerySelectMultipleField,
InlineModelFormList, InlineHstoreList, HstoreForm) InlineModelFormList, InlineHstoreList, HstoreForm)
from flask_admin.model.fields import InlineFormField from flask_admin.model.fields import InlineFormField
...@@ -243,9 +243,8 @@ class AdminModelConverter(ModelConverterBase): ...@@ -243,9 +243,8 @@ class AdminModelConverter(ModelConverterBase):
if override: if override:
return override(**kwargs) return override(**kwargs)
# Check choices # Check if a list of 'form_choices' are specified
form_choices = getattr(self.view, 'form_choices', None) form_choices = getattr(self.view, 'form_choices', None)
if mapper.class_ == self.view.model and form_choices: if mapper.class_ == self.view.model and form_choices:
choices = form_choices.get(prop.key) choices = form_choices.get(prop.key)
if choices: if choices:
...@@ -263,7 +262,6 @@ class AdminModelConverter(ModelConverterBase): ...@@ -263,7 +262,6 @@ class AdminModelConverter(ModelConverterBase):
return converter(model=model, mapper=mapper, prop=prop, return converter(model=model, mapper=mapper, prop=prop,
column=column, field_args=kwargs) column=column, field_args=kwargs)
return None return None
@classmethod @classmethod
...@@ -273,27 +271,52 @@ class AdminModelConverter(ModelConverterBase): ...@@ -273,27 +271,52 @@ class AdminModelConverter(ModelConverterBase):
@converts('String') # includes VARCHAR, CHAR, and Unicode @converts('String') # includes VARCHAR, CHAR, and Unicode
def conv_String(self, column, field_args, **extra): def conv_String(self, column, field_args, **extra):
if hasattr(column.type, 'enums'): if column.nullable:
accepted_values = list(column.type.enums) filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters
field_args['choices'] = [(f, f) for f in column.type.enums] self._string_common(column=column, field_args=field_args, **extra)
return fields.StringField(**field_args)
if column.nullable: @converts('sqlalchemy.sql.sqltypes.Enum')
field_args['allow_blank'] = column.nullable def convert_enum(self, column, field_args, **extra):
accepted_values.append(None) available_choices = [(f, f) for f in column.type.enums]
accepted_values = [key for key, val in available_choices]
field_args['validators'].append(validators.AnyOf(accepted_values)) if column.nullable:
field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else text_type(v) field_args['allow_blank'] = column.nullable
accepted_values.append(None)
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters
return form.Select2Field(**field_args) field_args['choices'] = available_choices
field_args['validators'].append(validators.AnyOf(accepted_values))
field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else text_type(v)
return form.Select2Field(**field_args)
@converts('sqlalchemy_utils.types.choice.ChoiceType')
def convert_choice_type(self, column, field_args, **extra):
available_choices = []
# choices can either be specified as an enum, or as a list of tuples
if isinstance(column.type.choices, EnumMeta):
available_choices = [(f.value, f.name) for f in column.type.choices]
else:
available_choices = column.type.choices
accepted_values = [key for key, val in available_choices]
if column.nullable: if column.nullable:
field_args['allow_blank'] = column.nullable
accepted_values.append(None)
filters = field_args.get('filters', []) filters = field_args.get('filters', [])
filters.append(lambda x: x or None) filters.append(lambda x: x or None)
field_args['filters'] = filters field_args['filters'] = filters
self._string_common(column=column, field_args=field_args, **extra) field_args['choices'] = available_choices
return fields.StringField(**field_args) field_args['validators'].append(validators.AnyOf(accepted_values))
field_args['coerce'] = choice_type_coerce_factory(column.type)
return form.Select2Field(**field_args)
@converts('Text', 'LargeBinary', 'Binary', 'CIText') # includes UnicodeText @converts('Text', 'LargeBinary', 'Binary', 'CIText') # includes UnicodeText
def conv_Text(self, field_args, **extra): def conv_Text(self, field_args, **extra):
...@@ -317,6 +340,44 @@ class AdminModelConverter(ModelConverterBase): ...@@ -317,6 +340,44 @@ class AdminModelConverter(ModelConverterBase):
def convert_time(self, field_args, **extra): def convert_time(self, field_args, **extra):
return form.TimeField(**field_args) return form.TimeField(**field_args)
@converts('sqlalchemy_utils.types.arrow.ArrowType')
def convert_arrow_time(self, field_args, **extra):
return form.DateTimeField(**field_args)
@converts('sqlalchemy_utils.types.email.EmailType')
def convert_email(self, field_args, **extra):
field_args['validators'].append(validators.Email())
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.url.URLType')
def convert_url(self, field_args, **extra):
field_args['validators'].append(validators.URL())
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.ip_address.IPAddressType')
def convert_ip_address(self, field_args, **extra):
field_args['validators'].append(validators.IPAddress())
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.color.ColorType')
def convert_color(self, field_args, **extra):
field_args['validators'].append(valid_color)
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.currency.CurrencyType')
def convert_currency(self, field_args, **extra):
field_args['validators'].append(valid_currency)
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.timezone.TimezoneType')
def convert_timezone(self, column, field_args, **extra):
field_args['validators'].append(TimeZoneValidator(coerce_function=column.type._coerce))
return fields.StringField(**field_args)
@converts('Integer') # includes BigInteger and SmallInteger @converts('Integer') # includes BigInteger and SmallInteger
def handle_integer_types(self, column, field_args, **extra): def handle_integer_types(self, column, field_args, **extra):
unsigned = getattr(column.type, 'unsigned', False) unsigned = getattr(column.type, 'unsigned', False)
...@@ -342,10 +403,12 @@ class AdminModelConverter(ModelConverterBase): ...@@ -342,10 +403,12 @@ class AdminModelConverter(ModelConverterBase):
field_args['validators'].append(validators.MacAddress()) field_args['validators'].append(validators.MacAddress())
return fields.StringField(**field_args) return fields.StringField(**field_args)
@converts('sqlalchemy.dialects.postgresql.base.UUID') @converts('sqlalchemy.dialects.postgresql.base.UUID',
'sqlalchemy_utils.types.uuid.UUIDType')
def conv_PGUuid(self, field_args, **extra): def conv_PGUuid(self, field_args, **extra):
field_args.setdefault('label', u'UUID') field_args.setdefault('label', u'UUID')
field_args['validators'].append(validators.UUID()) field_args['validators'].append(validators.UUID())
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args) return fields.StringField(**field_args)
@converts('sqlalchemy.dialects.postgresql.base.ARRAY', @converts('sqlalchemy.dialects.postgresql.base.ARRAY',
...@@ -363,6 +426,41 @@ class AdminModelConverter(ModelConverterBase): ...@@ -363,6 +426,41 @@ class AdminModelConverter(ModelConverterBase):
return form.JSONField(**field_args) return form.JSONField(**field_args)
def avoid_empty_strings(value):
"""
Return None if the incoming value is an empty string or whitespace.
"""
if value:
try:
value = value.strip()
except AttributeError:
# values are not always strings
pass
return value if value else None
def choice_type_coerce_factory(type_):
"""
Return a function to coerce a ChoiceType column, for use by Select2Field.
:param type_: ChoiceType object
"""
from sqlalchemy_utils import Choice
choices = type_.choices
if isinstance(choices, type) and issubclass(choices, Enum):
key, choice_cls = 'value', choices
else:
key, choice_cls = 'code', Choice
def choice_coerce(value):
if value is None:
return None
if isinstance(value, choice_cls):
return getattr(value, key)
return type_.python_type(value)
return choice_coerce
def _resolve_prop(prop): def _resolve_prop(prop):
""" """
Resolve proxied property Resolve proxied property
......
from sqlalchemy.ext.associationproxy import _AssociationList from sqlalchemy.ext.associationproxy import _AssociationList
from flask_admin.model.typefmt import BASE_FORMATTERS, list_formatter from flask_admin.model.typefmt import BASE_FORMATTERS, EXPORT_FORMATTERS, \
list_formatter
from sqlalchemy.orm.collections import InstrumentedList from sqlalchemy.orm.collections import InstrumentedList
def choice_formatter(view, choice):
"""
Return label of selected choice
see https://sqlalchemy-utils.readthedocs.io/
:param choice:
sqlalchemy_utils Choice, which has a `code` and a `value`
"""
return choice.value
def arrow_formatter(view, arrow_time):
"""
Return human-friendly string of the time relative to now.
see https://arrow.readthedocs.io/
:param arrow_time:
Arrow object for handling datetimes
"""
return arrow_time.humanize()
def arrow_export_formatter(view, arrow_time):
"""
Return string representation of Arrow object
see https://arrow.readthedocs.io/
:param arrow_time:
Arrow object for handling datetimes
"""
return arrow_time.format()
DEFAULT_FORMATTERS = BASE_FORMATTERS.copy() DEFAULT_FORMATTERS = BASE_FORMATTERS.copy()
EXPORT_FORMATTERS = EXPORT_FORMATTERS.copy()
DEFAULT_FORMATTERS.update({ DEFAULT_FORMATTERS.update({
InstrumentedList: list_formatter, InstrumentedList: list_formatter,
_AssociationList: list_formatter _AssociationList: list_formatter,
}) })
try:
from sqlalchemy_utils import Choice
DEFAULT_FORMATTERS[Choice] = choice_formatter
except ImportError:
pass
try:
from arrow import Arrow
DEFAULT_FORMATTERS[Arrow] = arrow_formatter
EXPORT_FORMATTERS[Arrow] = arrow_export_formatter
except ImportError:
pass
...@@ -66,3 +66,34 @@ class ItemsRequired(InputRequired): ...@@ -66,3 +66,34 @@ class ItemsRequired(InputRequired):
message = self.message message = self.message
raise ValidationError(message) raise ValidationError(message)
def valid_currency(form, field):
from sqlalchemy_utils import Currency
try:
Currency(field.data)
except (TypeError, ValueError):
raise ValidationError(field.gettext(u'Not a valid ISO currency code (e.g. USD, EUR, CNY).'))
def valid_color(form, field):
from colour import Color
try:
Color(field.data)
except (ValueError):
raise ValidationError(field.gettext(u'Not a valid color (e.g. "red", "#f00", "#ff0000").'))
class TimeZoneValidator(object):
"""
Tries to coerce a TimZone object from input data
"""
def __init__(self, coerce_function):
self.coerce_function = coerce_function
def __call__(self, form, field):
try:
self.coerce_function(str(field.data))
except Exception:
msg = u'Not a valid timezone (e.g. "America/New_York", "Africa/Johannesburg", "Asia/Singapore").'
raise ValidationError(field.gettext(msg))
This diff is collapsed.
...@@ -16,4 +16,7 @@ nose ...@@ -16,4 +16,7 @@ nose
coveralls coveralls
pylint pylint
sqlalchemy-citext sqlalchemy-citext
sqlalchemy_utils
azure-storage-blob azure-storage-blob
arrow<0.14.0
colour
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment