Unverified Commit 3e558465 authored by Petrus Janse van Rensburg's avatar Petrus Janse van Rensburg Committed by GitHub

Merge pull request #1742 from flask-admin/sqlalchemy-utils-types

Sqlalchemy utils types
parents d29796b6 ce6bdc16
......@@ -5,6 +5,7 @@ Next release
-----
* Fix display of inline x-editable boolean fields on list view
* Add support for several SQLAlchemy-Utils data types
1.5.3
-----
......
This diff is collapsed.
Flask
Flask-Admin
Flask-BabelEx
Flask-SQLAlchemy
tablib
sqlalchemy_utils
arrow
colour
......@@ -2,6 +2,7 @@ from flask_admin.babel import lazy_gettext
from flask_admin.model import filters
from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_
import enum
class BaseSQLAFilter(filters.BaseFilter):
......@@ -339,6 +340,102 @@ class EnumFilterNotInList(FilterNotInList):
return values
class ChoiceTypeEqualFilter(FilterEqual):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeEqualFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_type = None
# loop through choice 'values' to try and find an exact match
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if choice.name == user_query:
choice_type = choice.value
break
else:
for type, value in column.type.choices:
if value == user_query:
choice_type = type
break
if choice_type:
return query.filter(column == choice_type)
else:
return query.filter(column.in_([]))
class ChoiceTypeNotEqualFilter(FilterNotEqual):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeNotEqualFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_type = None
# loop through choice 'values' to try and find an exact match
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if choice.name == user_query:
choice_type = choice.value
break
else:
for type, value in column.type.choices:
if value == user_query:
choice_type = type
break
if choice_type:
# != can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(column != choice_type, column == None)) # noqa: E711
else:
return query
class ChoiceTypeLikeFilter(FilterLike):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeLikeFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_types = []
if user_query:
# loop through choice 'values' looking for matches
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if user_query.lower() in choice.name.lower():
choice_types.append(choice.value)
else:
for type, value in column.type.choices:
if user_query.lower() in value.lower():
choice_types.append(type)
if choice_types:
return query.filter(column.in_(choice_types))
else:
return query
class ChoiceTypeNotLikeFilter(FilterNotLike):
def __init__(self, column, name, options=None, **kwargs):
super(ChoiceTypeNotLikeFilter, self).__init__(column, name, options, **kwargs)
def apply(self, query, user_query, alias=None):
column = self.get_column(alias)
choice_types = []
if user_query:
# loop through choice 'values' looking for matches
if isinstance(column.type.choices, enum.EnumMeta):
for choice in column.type.choices:
if user_query.lower() in choice.name.lower():
choice_types.append(choice.value)
else:
for type, value in column.type.choices:
if user_query.lower() in value.lower():
choice_types.append(type)
if choice_types:
# != can exclude NULL values, so "or_ == None" needed to be added
return query.filter(or_(column.notin_(choice_types), column == None)) # noqa: E711
else:
return query
class UuidFilterEqual(FilterEqual, filters.BaseUuidFilter):
pass
......@@ -359,6 +456,7 @@ class UuidFilterNotInList(filters.BaseUuidListFilter, FilterNotInList):
class FilterConverter(filters.BaseFilterConverter):
strings = (FilterLike, FilterNotLike, FilterEqual, FilterNotEqual,
FilterEmpty, FilterInList, FilterNotInList)
string_key_filters = (FilterEqual, FilterNotEqual, FilterEmpty, FilterInList, FilterNotInList)
int_filters = (IntEqualFilter, IntNotEqualFilter, IntGreaterFilter,
IntSmallerFilter, FilterEmpty, IntInListFilter,
IntNotInListFilter)
......@@ -378,8 +476,11 @@ class FilterConverter(filters.BaseFilterConverter):
time_filters = (TimeEqualFilter, TimeNotEqualFilter, TimeGreaterFilter,
TimeSmallerFilter, TimeBetweenFilter, TimeNotBetweenFilter,
FilterEmpty)
choice_type_filters = (ChoiceTypeEqualFilter, ChoiceTypeNotEqualFilter,
ChoiceTypeLikeFilter, ChoiceTypeNotLikeFilter, FilterEmpty)
uuid_filters = (UuidFilterEqual, UuidFilterNotEqual, FilterEmpty,
UuidFilterInList, UuidFilterNotInList)
arrow_type_filters = (DateTimeGreaterFilter, DateTimeSmallerFilter, FilterEmpty)
def convert(self, type_name, column, name, **kwargs):
filter_name = type_name.lower()
......@@ -391,10 +492,15 @@ class FilterConverter(filters.BaseFilterConverter):
@filters.convert('string', 'char', 'unicode', 'varchar', 'tinytext',
'text', 'mediumtext', 'longtext', 'unicodetext',
'nchar', 'nvarchar', 'ntext', 'citext')
'nchar', 'nvarchar', 'ntext', 'citext', 'emailtype',
'URLType', 'IPAddressType')
def conv_string(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.strings]
@filters.convert('UUIDType', 'ColorType', 'TimezoneType', 'CurrencyType')
def conv_string_keys(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.string_key_filters]
@filters.convert('boolean', 'tinyint')
def conv_bool(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.bool_filters]
......@@ -420,6 +526,14 @@ class FilterConverter(filters.BaseFilterConverter):
def conv_time(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.time_filters]
@filters.convert('ChoiceType')
def conv_sqla_utils_choice(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.choice_type_filters]
@filters.convert('ArrowType')
def conv_sqla_utils_arrow(self, column, name, **kwargs):
return [f(column, name, **kwargs) for f in self.arrow_type_filters]
@filters.convert('enum')
def conv_enum(self, column, name, options=None, **kwargs):
if not options:
......
import warnings
from enum import Enum
from enum import Enum, EnumMeta
from wtforms import fields, validators
from sqlalchemy import Boolean, Column
......@@ -13,7 +13,7 @@ from flask_admin.model.helpers import prettify_name
from flask_admin._backwards import get_property
from flask_admin._compat import iteritems, text_type
from .validators import Unique
from .validators import Unique, valid_currency, valid_color, TimeZoneValidator
from .fields import (QuerySelectField, QuerySelectMultipleField,
InlineModelFormList, InlineHstoreList, HstoreForm)
from flask_admin.model.fields import InlineFormField
......@@ -243,9 +243,8 @@ class AdminModelConverter(ModelConverterBase):
if override:
return override(**kwargs)
# Check choices
# Check if a list of 'form_choices' are specified
form_choices = getattr(self.view, 'form_choices', None)
if mapper.class_ == self.view.model and form_choices:
choices = form_choices.get(prop.key)
if choices:
......@@ -263,7 +262,6 @@ class AdminModelConverter(ModelConverterBase):
return converter(model=model, mapper=mapper, prop=prop,
column=column, field_args=kwargs)
return None
@classmethod
......@@ -273,27 +271,52 @@ class AdminModelConverter(ModelConverterBase):
@converts('String') # includes VARCHAR, CHAR, and Unicode
def conv_String(self, column, field_args, **extra):
if hasattr(column.type, 'enums'):
accepted_values = list(column.type.enums)
if column.nullable:
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters
field_args['choices'] = [(f, f) for f in column.type.enums]
self._string_common(column=column, field_args=field_args, **extra)
return fields.StringField(**field_args)
if column.nullable:
field_args['allow_blank'] = column.nullable
accepted_values.append(None)
@converts('sqlalchemy.sql.sqltypes.Enum')
def convert_enum(self, column, field_args, **extra):
available_choices = [(f, f) for f in column.type.enums]
accepted_values = [key for key, val in available_choices]
field_args['validators'].append(validators.AnyOf(accepted_values))
field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else text_type(v)
if column.nullable:
field_args['allow_blank'] = column.nullable
accepted_values.append(None)
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters
return form.Select2Field(**field_args)
field_args['choices'] = available_choices
field_args['validators'].append(validators.AnyOf(accepted_values))
field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else text_type(v)
return form.Select2Field(**field_args)
@converts('sqlalchemy_utils.types.choice.ChoiceType')
def convert_choice_type(self, column, field_args, **extra):
available_choices = []
# choices can either be specified as an enum, or as a list of tuples
if isinstance(column.type.choices, EnumMeta):
available_choices = [(f.value, f.name) for f in column.type.choices]
else:
available_choices = column.type.choices
accepted_values = [key for key, val in available_choices]
if column.nullable:
field_args['allow_blank'] = column.nullable
accepted_values.append(None)
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters
self._string_common(column=column, field_args=field_args, **extra)
return fields.StringField(**field_args)
field_args['choices'] = available_choices
field_args['validators'].append(validators.AnyOf(accepted_values))
field_args['coerce'] = choice_type_coerce_factory(column.type)
return form.Select2Field(**field_args)
@converts('Text', 'LargeBinary', 'Binary', 'CIText') # includes UnicodeText
def conv_Text(self, field_args, **extra):
......@@ -317,6 +340,44 @@ class AdminModelConverter(ModelConverterBase):
def convert_time(self, field_args, **extra):
return form.TimeField(**field_args)
@converts('sqlalchemy_utils.types.arrow.ArrowType')
def convert_arrow_time(self, field_args, **extra):
return form.DateTimeField(**field_args)
@converts('sqlalchemy_utils.types.email.EmailType')
def convert_email(self, field_args, **extra):
field_args['validators'].append(validators.Email())
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.url.URLType')
def convert_url(self, field_args, **extra):
field_args['validators'].append(validators.URL())
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.ip_address.IPAddressType')
def convert_ip_address(self, field_args, **extra):
field_args['validators'].append(validators.IPAddress())
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.color.ColorType')
def convert_color(self, field_args, **extra):
field_args['validators'].append(valid_color)
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.currency.CurrencyType')
def convert_currency(self, field_args, **extra):
field_args['validators'].append(valid_currency)
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy_utils.types.timezone.TimezoneType')
def convert_timezone(self, column, field_args, **extra):
field_args['validators'].append(TimeZoneValidator(coerce_function=column.type._coerce))
return fields.StringField(**field_args)
@converts('Integer') # includes BigInteger and SmallInteger
def handle_integer_types(self, column, field_args, **extra):
unsigned = getattr(column.type, 'unsigned', False)
......@@ -342,10 +403,12 @@ class AdminModelConverter(ModelConverterBase):
field_args['validators'].append(validators.MacAddress())
return fields.StringField(**field_args)
@converts('sqlalchemy.dialects.postgresql.base.UUID')
@converts('sqlalchemy.dialects.postgresql.base.UUID',
'sqlalchemy_utils.types.uuid.UUIDType')
def conv_PGUuid(self, field_args, **extra):
field_args.setdefault('label', u'UUID')
field_args['validators'].append(validators.UUID())
field_args['filters'] = [avoid_empty_strings] # don't accept empty strings, or whitespace
return fields.StringField(**field_args)
@converts('sqlalchemy.dialects.postgresql.base.ARRAY',
......@@ -363,6 +426,41 @@ class AdminModelConverter(ModelConverterBase):
return form.JSONField(**field_args)
def avoid_empty_strings(value):
"""
Return None if the incoming value is an empty string or whitespace.
"""
if value:
try:
value = value.strip()
except AttributeError:
# values are not always strings
pass
return value if value else None
def choice_type_coerce_factory(type_):
"""
Return a function to coerce a ChoiceType column, for use by Select2Field.
:param type_: ChoiceType object
"""
from sqlalchemy_utils import Choice
choices = type_.choices
if isinstance(choices, type) and issubclass(choices, Enum):
key, choice_cls = 'value', choices
else:
key, choice_cls = 'code', Choice
def choice_coerce(value):
if value is None:
return None
if isinstance(value, choice_cls):
return getattr(value, key)
return type_.python_type(value)
return choice_coerce
def _resolve_prop(prop):
"""
Resolve proxied property
......
from sqlalchemy.ext.associationproxy import _AssociationList
from flask_admin.model.typefmt import BASE_FORMATTERS, list_formatter
from flask_admin.model.typefmt import BASE_FORMATTERS, EXPORT_FORMATTERS, \
list_formatter
from sqlalchemy.orm.collections import InstrumentedList
def choice_formatter(view, choice):
"""
Return label of selected choice
see https://sqlalchemy-utils.readthedocs.io/
:param choice:
sqlalchemy_utils Choice, which has a `code` and a `value`
"""
return choice.value
def arrow_formatter(view, arrow_time):
"""
Return human-friendly string of the time relative to now.
see https://arrow.readthedocs.io/
:param arrow_time:
Arrow object for handling datetimes
"""
return arrow_time.humanize()
def arrow_export_formatter(view, arrow_time):
"""
Return string representation of Arrow object
see https://arrow.readthedocs.io/
:param arrow_time:
Arrow object for handling datetimes
"""
return arrow_time.format()
DEFAULT_FORMATTERS = BASE_FORMATTERS.copy()
EXPORT_FORMATTERS = EXPORT_FORMATTERS.copy()
DEFAULT_FORMATTERS.update({
InstrumentedList: list_formatter,
_AssociationList: list_formatter
_AssociationList: list_formatter,
})
try:
from sqlalchemy_utils import Choice
DEFAULT_FORMATTERS[Choice] = choice_formatter
except ImportError:
pass
try:
from arrow import Arrow
DEFAULT_FORMATTERS[Arrow] = arrow_formatter
EXPORT_FORMATTERS[Arrow] = arrow_export_formatter
except ImportError:
pass
......@@ -66,3 +66,34 @@ class ItemsRequired(InputRequired):
message = self.message
raise ValidationError(message)
def valid_currency(form, field):
from sqlalchemy_utils import Currency
try:
Currency(field.data)
except (TypeError, ValueError):
raise ValidationError(field.gettext(u'Not a valid ISO currency code (e.g. USD, EUR, CNY).'))
def valid_color(form, field):
from colour import Color
try:
Color(field.data)
except (ValueError):
raise ValidationError(field.gettext(u'Not a valid color (e.g. "red", "#f00", "#ff0000").'))
class TimeZoneValidator(object):
"""
Tries to coerce a TimZone object from input data
"""
def __init__(self, coerce_function):
self.coerce_function = coerce_function
def __call__(self, form, field):
try:
self.coerce_function(str(field.data))
except Exception:
msg = u'Not a valid timezone (e.g. "America/New_York", "Africa/Johannesburg", "Asia/Singapore").'
raise ValidationError(field.gettext(msg))
This diff is collapsed.
......@@ -16,4 +16,7 @@ nose
coveralls
pylint
sqlalchemy-citext
sqlalchemy_utils
azure-storage-blob
arrow<0.14.0
colour
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment