Commit 0fefc1c8 authored by PJ Janse van Rensburg's avatar PJ Janse van Rensburg

In-progress: Add tests for ChoiceType field.

parent effa3e3f
...@@ -15,6 +15,12 @@ from flask_admin.contrib.sqla.fields import InlineModelFormList ...@@ -15,6 +15,12 @@ from flask_admin.contrib.sqla.fields import InlineModelFormList
from flask_admin.contrib.sqla.filters import BaseSQLAFilter, FilterEqual from flask_admin.contrib.sqla.filters import BaseSQLAFilter, FilterEqual
from sqlalchemy_utils.types import ChoiceType, EmailType from sqlalchemy_utils.types import ChoiceType, EmailType
import enum
class EnumChoices(enum.Enum):
first = 1
second = 2
# Create application # Create application
...@@ -46,6 +52,7 @@ class User(db.Model): ...@@ -46,6 +52,7 @@ class User(db.Model):
type = db.Column(ChoiceType(AVAILABLE_TYPES), nullable=True) type = db.Column(ChoiceType(AVAILABLE_TYPES), nullable=True)
email = db.Column(EmailType, unique=True, nullable=False) email = db.Column(EmailType, unique=True, nullable=False)
pets = db.relationship('Pet', backref='owner') pets = db.relationship('Pet', backref='owner')
enum_choice_field = db.Column(ChoiceType(EnumChoices, impl=db.Integer()), nullable=True)
def __str__(self): def __str__(self):
return "{}, {}".format(self.last_name, self.first_name) return "{}, {}".format(self.last_name, self.first_name)
......
import warnings import warnings
from enum import Enum from enum import Enum, EnumMeta
from wtforms import fields, validators from wtforms import fields, validators
from sqlalchemy import Boolean, Column from sqlalchemy import Boolean, Column
from sqlalchemy.orm import ColumnProperty from sqlalchemy.orm import ColumnProperty
from sqlalchemy_utils import Choice
from flask_admin import form from flask_admin import form
from flask_admin.model.form import (converts, ModelConverterBase, from flask_admin.model.form import (converts, ModelConverterBase,
...@@ -277,6 +278,9 @@ class AdminModelConverter(ModelConverterBase): ...@@ -277,6 +278,9 @@ class AdminModelConverter(ModelConverterBase):
if hasattr(column.type, 'enums'): if hasattr(column.type, 'enums'):
available_choices = [(f, f) for f in column.type.enums] available_choices = [(f, f) for f in column.type.enums]
elif hasattr(column.type, 'choices'): elif hasattr(column.type, 'choices'):
if isinstance(column.type.choices, EnumMeta):
available_choices = [(str(f.value), f.name) for f in column.type.choices]
else:
available_choices = column.type.choices available_choices = column.type.choices
if available_choices: if available_choices:
field_args['choices'] = available_choices field_args['choices'] = available_choices
...@@ -287,7 +291,7 @@ class AdminModelConverter(ModelConverterBase): ...@@ -287,7 +291,7 @@ class AdminModelConverter(ModelConverterBase):
accepted_values.append(None) accepted_values.append(None)
field_args['validators'].append(validators.AnyOf(accepted_values)) field_args['validators'].append(validators.AnyOf(accepted_values))
field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else text_type(v) field_args['coerce'] = lambda v: v.name if isinstance(v, Enum) else (v.code if isinstance(v, Choice) else text_type(v))
return form.Select2Field(**field_args) return form.Select2Field(**field_args)
......
...@@ -3,17 +3,24 @@ from nose.tools import eq_, ok_, raises, assert_true ...@@ -3,17 +3,24 @@ from nose.tools import eq_, ok_, raises, assert_true
from wtforms import fields, validators from wtforms import fields, validators
from flask_admin import form from flask_admin import form
from flask_admin.form.fields import Select2Field
from flask_admin._compat import as_unicode from flask_admin._compat import as_unicode
from flask_admin._compat import iteritems from flask_admin._compat import iteritems
from flask_admin.contrib.sqla import ModelView, filters, tools from flask_admin.contrib.sqla import ModelView, filters, tools
from flask_babelex import Babel from flask_babelex import Babel
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy_utils.types.email import EmailType from sqlalchemy_utils import EmailType, ChoiceType
from . import setup from . import setup
from datetime import datetime, time, date from datetime import datetime, time, date
import enum
class EnumChoices(enum.Enum):
first = 1
second = 2
class CustomModelView(ModelView): class CustomModelView(ModelView):
...@@ -31,7 +38,8 @@ def create_models(db): ...@@ -31,7 +38,8 @@ def create_models(db):
class Model1(db.Model): class Model1(db.Model):
def __init__(self, test1=None, test2=None, test3=None, test4=None, def __init__(self, test1=None, test2=None, test3=None, test4=None,
bool_field=False, date_field=None, time_field=None, bool_field=False, date_field=None, time_field=None,
datetime_field=None, enum_field=None, email_field=None): datetime_field=None, enum_field=None, email_field=None,
choice_field=None, enum_choice_field=None):
self.test1 = test1 self.test1 = test1
self.test2 = test2 self.test2 = test2
self.test3 = test3 self.test3 = test3
...@@ -41,6 +49,9 @@ def create_models(db): ...@@ -41,6 +49,9 @@ def create_models(db):
self.time_field = time_field self.time_field = time_field
self.datetime_field = datetime_field self.datetime_field = datetime_field
self.enum_field = enum_field self.enum_field = enum_field
self.email_field = email_field
self.choice_field = choice_field
self.enum_choice_field = enum_choice_field
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
test1 = db.Column(db.String(20)) test1 = db.Column(db.String(20))
...@@ -53,6 +64,11 @@ def create_models(db): ...@@ -53,6 +64,11 @@ def create_models(db):
time_field = db.Column(db.Time) time_field = db.Column(db.Time)
datetime_field = db.Column(db.DateTime) datetime_field = db.Column(db.DateTime)
email_field = db.Column(EmailType) email_field = db.Column(EmailType)
choice_field = db.Column(ChoiceType([
('choice-1', u'First choice'),
('choice-2', u'Second choice')
]))
enum_choice_field = db.Column(ChoiceType(EnumChoices, impl=db.Integer()))
def __unicode__(self): def __unicode__(self):
return self.test1 return self.test1
...@@ -96,7 +112,7 @@ def fill_db(db, Model1, Model2): ...@@ -96,7 +112,7 @@ def fill_db(db, Model1, Model2):
model1_obj1 = Model1('test1_val_1', 'test2_val_1', bool_field=True) model1_obj1 = Model1('test1_val_1', 'test2_val_1', bool_field=True)
model1_obj2 = Model1('test1_val_2', 'test2_val_2', bool_field=False) model1_obj2 = Model1('test1_val_2', 'test2_val_2', bool_field=False)
model1_obj3 = Model1('test1_val_3', 'test2_val_3') model1_obj3 = Model1('test1_val_3', 'test2_val_3')
model1_obj4 = Model1('test1_val_4', 'test2_val_4') model1_obj4 = Model1('test1_val_4', 'test2_val_4', email_field="test@test.com", choice_field="choice-1")
model2_obj1 = Model2('test2_val_1', model1=model1_obj1, float_field=None) model2_obj1 = Model2('test2_val_1', model1=model1_obj1, float_field=None)
model2_obj2 = Model2('test2_val_2', model1=model1_obj2, float_field=None) model2_obj2 = Model2('test2_val_2', model1=model1_obj2, float_field=None)
...@@ -155,6 +171,8 @@ def test_model(): ...@@ -155,6 +171,8 @@ def test_model():
eq_(view._create_form_class.test3.field_class, fields.TextAreaField) eq_(view._create_form_class.test3.field_class, fields.TextAreaField)
eq_(view._create_form_class.test4.field_class, fields.TextAreaField) eq_(view._create_form_class.test4.field_class, fields.TextAreaField)
eq_(view._create_form_class.email_field.field_class, fields.StringField) eq_(view._create_form_class.email_field.field_class, fields.StringField)
eq_(view._create_form_class.choice_field.field_class, Select2Field)
eq_(view._create_form_class.enum_choice_field.field_class, Select2Field)
# Make some test clients # Make some test clients
client = app.test_client() client = app.test_client()
...@@ -169,7 +187,9 @@ def test_model(): ...@@ -169,7 +187,9 @@ def test_model():
data=dict(test1='test1large', data=dict(test1='test1large',
test2='test2', test2='test2',
time_field=time(0, 0, 0), time_field=time(0, 0, 0),
email_field="Test@TEST.com")) email_field="Test@TEST.com",
choice_field="choice-1",
enum_choice_field=1))
eq_(rv.status_code, 302) eq_(rv.status_code, 302)
model = db.session.query(Model1).first() model = db.session.query(Model1).first()
...@@ -178,6 +198,8 @@ def test_model(): ...@@ -178,6 +198,8 @@ def test_model():
eq_(model.test3, u'') eq_(model.test3, u'')
eq_(model.test4, u'') eq_(model.test4, u'')
eq_(model.email_field, u'test@test.com') eq_(model.email_field, u'test@test.com')
eq_(model.choice_field, u'choice-1')
eq_(model.enum_choice_field, EnumChoices(1))
rv = client.get('/admin/model1/') rv = client.get('/admin/model1/')
eq_(rv.status_code, 200) eq_(rv.status_code, 200)
...@@ -193,7 +215,8 @@ def test_model(): ...@@ -193,7 +215,8 @@ def test_model():
rv = client.post(url, rv = client.post(url,
data=dict(test1='test1small', data=dict(test1='test1small',
test2='test2large', test2='test2large',
email_field='Test2@TEST.com')) email_field='Test2@TEST.com',
choice_field=None))
eq_(rv.status_code, 302) eq_(rv.status_code, 302)
model = db.session.query(Model1).first() model = db.session.query(Model1).first()
...@@ -292,7 +315,9 @@ def test_exclude_columns(): ...@@ -292,7 +315,9 @@ def test_exclude_columns():
eq_( eq_(
view._list_columns, view._list_columns,
[('test1', 'Test1'), ('test3', 'Test3'), ('bool_field', 'Bool Field'), ('email_field', 'Email Field')] [('test1', 'Test1'), ('test3', 'Test3'), ('bool_field', 'Bool Field'),
('email_field', 'Email Field'), ('choice_field', 'Choice Field'),
('enum_choice_field', 'Enum Choice Field')]
) )
client = app.test_client() client = app.test_client()
...@@ -1528,7 +1553,6 @@ def test_form_columns(): ...@@ -1528,7 +1553,6 @@ def test_form_columns():
ok_('int_field' in form1._fields) ok_('int_field' in form1._fields)
ok_('text_field' in form1._fields) ok_('text_field' in form1._fields)
ok_('datetime_field' not in form1._fields) ok_('datetime_field' not in form1._fields)
ok_('excluded_column' not in form2._fields) ok_('excluded_column' not in form2._fields)
ok_(type(form3.model).__name__ == 'QuerySelectField') ok_(type(form3.model).__name__ == 'QuerySelectField')
...@@ -1539,6 +1563,9 @@ def test_form_columns(): ...@@ -1539,6 +1563,9 @@ def test_form_columns():
form4 = view4.create_form() form4 = view4.create_form()
ok_('int_field' in form4._fields) ok_('int_field' in form4._fields)
# test form_columns with special sqlalchemy_utils types
@raises(Exception) @raises(Exception)
def test_complex_form_columns(): def test_complex_form_columns():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment