Commit 9cec4976 authored by Serge S. Koval's avatar Serge S. Koval

Separated sqla model conversion from wtforms

parent fe06e21c
Copyright (c) 2012, Serge S. Koval and contributors.
Copyright (c) 2012, Serge S. Koval and contributors. See AUTHORS
for more details.
Some rights reserved.
......
from sqlalchemy.orm.exc import NoResultFound
from wtforms import ValidationError, fields, validators
from wtforms.ext.sqlalchemy.orm import converts, ModelConverter
from wtforms.ext.sqlalchemy.fields import QuerySelectField, QuerySelectMultipleField
from wtforms import fields, validators
from flask.ext.admin import form
from flask.ext.admin.model.form import converts, ModelConverterBase
from .validators import Unique
from .fields import QuerySelectField, QuerySelectMultipleField
class Unique(object):
"""Checks field value unicity against specified table field.
:param get_session:
A function that return a SQAlchemy Session.
:param model:
The model to check unicity against.
:param column:
The unique column.
:param message:
The error message.
"""
field_flags = ('unique', )
def __init__(self, db_session, model, column, message=None):
self.db_session = db_session
self.model = model
self.column = column
self.message = message
def __call__(self, form, field):
try:
obj = (self.db_session.query(self.model)
.filter(self.column == field.data).one())
if not hasattr(form, '_obj') or not form._obj == obj:
if self.message is None:
self.message = field.gettext(u'Already exists.')
raise ValidationError(self.message)
except NoResultFound:
pass
class AdminModelConverter(ModelConverter):
class AdminModelConverter(ModelConverterBase):
"""
SQLAlchemy model to form converter
"""
......@@ -64,7 +31,7 @@ class AdminModelConverter(ModelConverter):
return None
def convert(self, model, mapper, prop, field_args, *args):
def convert(self, model, mapper, prop, field_args):
kwargs = {
'validators': [],
'filters': []
......@@ -73,6 +40,7 @@ class AdminModelConverter(ModelConverter):
if field_args:
kwargs.update(field_args)
# Check if it is relation or property
if hasattr(prop, 'direction'):
remote_model = prop.mapper.class_
local_column = prop.local_remote_pairs[0][0]
......@@ -145,15 +113,59 @@ class AdminModelConverter(ModelConverter):
# Apply label
kwargs['label'] = self._get_label(prop.key, kwargs)
# Check if more than one column mapped to the property
if len(prop.columns) != 1:
raise TypeError('Can not convert multiple-column properties (%s.%s)' % (model, prop.key))
# Figure out default value
default = getattr(column, 'default', None)
if default is not None:
callable_default = getattr(default, 'arg', None)
if callable_default and callable(callable_default):
default = callable_default(None)
if default is not None:
kwargs['default'] = default
# Check nullable
if column.nullable:
kwargs['validators'].append(validators.Optional())
# Override field type if necessary
override = self._get_field_override(prop.key)
if override:
return override(**kwargs)
return super(AdminModelConverter, self).convert(model,
mapper,
prop,
kwargs)
# Run converter
converter = self.get_converter(column)
if converter is None:
return None
return converter(model=model, mapper=mapper, prop=prop,
column=column, field_args=kwargs)
@classmethod
def _string_common(cls, column, field_args, **extra):
if column.type.length:
field_args['validators'].append(validators.Length(max=column.type.length))
@converts('String', 'Unicode')
def conv_String(self, field_args, **extra):
self._string_common(field_args=field_args, **extra)
return fields.TextField(**field_args)
@converts('Text', 'UnicodeText',
'types.LargeBinary', 'types.Binary')
def conv_Text(self, field_args, **extra):
self._string_common(field_args=field_args, **extra)
return fields.TextAreaField(**field_args)
@converts('Boolean')
def conv_Boolean(self, field_args, **extra):
return fields.BooleanField(**field_args)
@converts('Date')
def convert_date(self, field_args, **extra):
......@@ -168,3 +180,68 @@ class AdminModelConverter(ModelConverter):
@converts('Time')
def convert_time(self, field_args, **extra):
return form.TimeField(**field_args)
@converts('Integer', 'SmallInteger')
def handle_integer_types(self, column, field_args, **extra):
unsigned = getattr(column.type, 'unsigned', False)
if unsigned:
field_args['validators'].append(validators.NumberRange(min=0))
return fields.IntegerField(**field_args)
@converts('Numeric', 'Float')
def handle_decimal_types(self, column, field_args, **extra):
places = getattr(column.type, 'scale', 2)
if places is not None:
field_args['places'] = places
return fields.DecimalField(**field_args)
@converts('databases.mysql.MSYear')
def conv_MSYear(self, field_args, **extra):
field_args['validators'].append(validators.NumberRange(min=1901, max=2155))
return fields.TextField(**field_args)
@converts('databases.postgres.PGInet', 'dialects.postgresql.base.INET')
def conv_PGInet(self, field_args, **extra):
field_args.setdefault('label', u'IP Address')
field_args['validators'].append(validators.IPAddress())
return fields.TextField(**field_args)
@converts('dialects.postgresql.base.MACADDR')
def conv_PGMacaddr(self, field_args, **extra):
field_args.setdefault('label', u'MAC Address')
field_args['validators'].append(validators.MacAddress())
return fields.TextField(**field_args)
@converts('dialects.postgresql.base.UUID')
def conv_PGUuid(self, field_args, **extra):
field_args.setdefault('label', u'UUID')
field_args['validators'].append(validators.UUID())
return fields.TextField(**field_args)
def model_fields(model, converter, only=None, exclude=None, field_args=None):
"""
Generate a dictionary of fields for a given SQLAlchemy model.
See `model_form` docstring for description of parameters.
"""
# TODO: Support new 0.8 API
if not hasattr(model, '_sa_class_manager'):
raise TypeError('model must be a sqlalchemy mapped model')
mapper = model._sa_class_manager.mapper
field_args = field_args or {}
properties = ((p.key, p) for p in mapper.iterate_properties)
if only:
properties = (x for x in properties if x[0] in only)
elif exclude:
properties = (x for x in properties if x[0] not in exclude)
field_dict = {}
for name, prop in properties:
field = converter.convert(model, mapper, prop, field_args.get(name))
if field is not None:
field_dict[name] = field
return field_dict
......@@ -2,13 +2,12 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm import subqueryload
from sqlalchemy.sql.expression import desc
from sqlalchemy import or_, Column
from wtforms.ext.sqlalchemy.orm import model_form
from flask import flash
from flask.ext.admin.babel import gettext, ngettext, lazy_gettext
from flask.ext.admin.form import BaseForm
from flask.ext.admin.model import BaseModelView
from flask.ext.admin.model.form import model_form
from flask.ext.admin.actions import action
from flask.ext.admin.contrib.sqlamodel import form, filters, tools
......@@ -114,6 +113,18 @@ class ModelView(BaseModelView):
Override this attribute to use non-default converter.
"""
fast_mass_delete = False
"""
If set to `False` and user deletes more than one model using actions,
all models will be read from the database and then deleted one by one
giving SQLAlchemy chance to manually cleanup any dependencies (many-to-many
relationships, etc).
If set to True, will run DELETE statement which is somewhat faster, but
might leave corrupted data if you forget to configure DELETE CASCADE
for your model.
"""
def __init__(self, model, session,
name=None, category=None, endpoint=None, url=None):
"""
......@@ -360,12 +371,13 @@ class ModelView(BaseModelView):
"""
Create form from the model.
"""
return model_form(self.model,
BaseForm,
form_fields = form.model_fields(
self.model,
form.AdminModelConverter(self),
only=self.form_columns,
exclude=self.excluded_form_columns,
field_args=self.form_args,
converter=form.AdminModelConverter(self))
field_args=self.form_args)
return model_form(self.model, form_fields)
def scaffold_auto_joins(self):
"""
......@@ -568,8 +580,14 @@ class ModelView(BaseModelView):
query = self.session.query(self.model).filter(model_pk.in_(ids))
# TODO: Load up ORM and delete models one by one?
count = query.delete(synchronize_session=False)
if self.fast_mass_delete:
count = query.delete(synchronize_session=False)
else:
count = 0
for m in query.all():
self.session.delete(m)
count += 1
self.session.commit()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment