Commit 0fefc1c8 authored by PJ Janse van Rensburg's avatar PJ Janse van Rensburg

In-progress: Add tests for ChoiceType field.

parent effa3e3f
......@@ -15,6 +15,12 @@ from flask_admin.contrib.sqla.fields import InlineModelFormList
from flask_admin.contrib.sqla.filters import BaseSQLAFilter, FilterEqual
from sqlalchemy_utils.types import ChoiceType, EmailType
import enum
class EnumChoices(enum.Enum):
first = 1
second = 2
# Create application
......@@ -46,6 +52,7 @@ class User(db.Model):
type = db.Column(ChoiceType(AVAILABLE_TYPES), nullable=True)
email = db.Column(EmailType, unique=True, nullable=False)
pets = db.relationship('Pet', backref='owner')
enum_choice_field = db.Column(ChoiceType(EnumChoices, impl=db.Integer()), nullable=True)
def __str__(self):
return "{}, {}".format(self.last_name, self.first_name)
......
import warnings
from enum import Enum
from enum import Enum, EnumMeta
from wtforms import fields, validators
from sqlalchemy import Boolean, Column
from sqlalchemy.orm import ColumnProperty
from sqlalchemy_utils import Choice
from flask_admin import form
from flask_admin.model.form import (converts, ModelConverterBase,
......@@ -277,7 +278,10 @@ class AdminModelConverter(ModelConverterBase):
if hasattr(column.type, 'enums'):
available_choices = [(f, f) for f in column.type.enums]
elif hasattr(column.type, 'choices'):
available_choices = column.type.choices
if isinstance(column.type.choices, EnumMeta):
available_choices = [(str(f.value), f.name) for f in column.type.choices]
else:
available_choices = column.type.choices
if available_choices:
field_args['choices'] = available_choices
accepted_values = [key for key, val in available_choices]
......@@ -287,7 +291,7 @@ class AdminModelConverter(ModelConverterBase):
accepted_values.append(None)
field_args['validators'].append(validators.AnyOf(accepted_values))
field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else text_type(v)
field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else (v.code if isinstance(v, Choice) else text_type(v))
return form.Select2Field(**field_args)
......
......@@ -3,17 +3,24 @@ from nose.tools import eq_, ok_, raises, assert_true
from wtforms import fields, validators
from flask_admin import form
from flask_admin.form.fields import Select2Field
from flask_admin._compat import as_unicode
from flask_admin._compat import iteritems
from flask_admin.contrib.sqla import ModelView, filters, tools
from flask_babelex import Babel
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy_utils.types.email import EmailType
from sqlalchemy_utils import EmailType, ChoiceType
from . import setup
from datetime import datetime, time, date
import enum
class EnumChoices(enum.Enum):
first = 1
second = 2
class CustomModelView(ModelView):
......@@ -31,7 +38,8 @@ def create_models(db):
class Model1(db.Model):
def __init__(self, test1=None, test2=None, test3=None, test4=None,
bool_field=False, date_field=None, time_field=None,
datetime_field=None, enum_field=None, email_field=None):
datetime_field=None, enum_field=None, email_field=None,
choice_field=None, enum_choice_field=None):
self.test1 = test1
self.test2 = test2
self.test3 = test3
......@@ -41,6 +49,9 @@ def create_models(db):
self.time_field = time_field
self.datetime_field = datetime_field
self.enum_field = enum_field
self.email_field = email_field
self.choice_field = choice_field
self.enum_choice_field = enum_choice_field
id = db.Column(db.Integer, primary_key=True)
test1 = db.Column(db.String(20))
......@@ -53,6 +64,11 @@ def create_models(db):
time_field = db.Column(db.Time)
datetime_field = db.Column(db.DateTime)
email_field = db.Column(EmailType)
choice_field = db.Column(ChoiceType([
('choice-1', u'First choice'),
('choice-2', u'Second choice')
]))
enum_choice_field = db.Column(ChoiceType(EnumChoices, impl=db.Integer()))
def __unicode__(self):
return self.test1
......@@ -96,7 +112,7 @@ def fill_db(db, Model1, Model2):
model1_obj1 = Model1('test1_val_1', 'test2_val_1', bool_field=True)
model1_obj2 = Model1('test1_val_2', 'test2_val_2', bool_field=False)
model1_obj3 = Model1('test1_val_3', 'test2_val_3')
model1_obj4 = Model1('test1_val_4', 'test2_val_4')
model1_obj4 = Model1('test1_val_4', 'test2_val_4', email_field="test@test.com", choice_field="choice-1")
model2_obj1 = Model2('test2_val_1', model1=model1_obj1, float_field=None)
model2_obj2 = Model2('test2_val_2', model1=model1_obj2, float_field=None)
......@@ -155,6 +171,8 @@ def test_model():
eq_(view._create_form_class.test3.field_class, fields.TextAreaField)
eq_(view._create_form_class.test4.field_class, fields.TextAreaField)
eq_(view._create_form_class.email_field.field_class, fields.StringField)
eq_(view._create_form_class.choice_field.field_class, Select2Field)
eq_(view._create_form_class.enum_choice_field.field_class, Select2Field)
# Make some test clients
client = app.test_client()
......@@ -169,7 +187,9 @@ def test_model():
data=dict(test1='test1large',
test2='test2',
time_field=time(0, 0, 0),
email_field="Test@TEST.com"))
email_field="Test@TEST.com",
choice_field="choice-1",
enum_choice_field=1))
eq_(rv.status_code, 302)
model = db.session.query(Model1).first()
......@@ -178,6 +198,8 @@ def test_model():
eq_(model.test3, u'')
eq_(model.test4, u'')
eq_(model.email_field, u'test@test.com')
eq_(model.choice_field, u'choice-1')
eq_(model.enum_choice_field, EnumChoices(1))
rv = client.get('/admin/model1/')
eq_(rv.status_code, 200)
......@@ -193,7 +215,8 @@ def test_model():
rv = client.post(url,
data=dict(test1='test1small',
test2='test2large',
email_field='Test2@TEST.com'))
email_field='Test2@TEST.com',
choice_field=None))
eq_(rv.status_code, 302)
model = db.session.query(Model1).first()
......@@ -292,7 +315,9 @@ def test_exclude_columns():
eq_(
view._list_columns,
[('test1', 'Test1'), ('test3', 'Test3'), ('bool_field', 'Bool Field'), ('email_field', 'Email Field')]
[('test1', 'Test1'), ('test3', 'Test3'), ('bool_field', 'Bool Field'),
('email_field', 'Email Field'), ('choice_field', 'Choice Field'),
('enum_choice_field', 'Enum Choice Field')]
)
client = app.test_client()
......@@ -1528,7 +1553,6 @@ def test_form_columns():
ok_('int_field' in form1._fields)
ok_('text_field' in form1._fields)
ok_('datetime_field' not in form1._fields)
ok_('excluded_column' not in form2._fields)
ok_(type(form3.model).__name__ == 'QuerySelectField')
......@@ -1539,6 +1563,9 @@ def test_form_columns():
form4 = view4.create_form()
ok_('int_field' in form4._fields)
# test form_columns with special sqlalchemy_utils types
@raises(Exception)
def test_complex_form_columns():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment