Commit cadf5bc9 authored by Serge S. Koval's avatar Serge S. Koval

Simple Django-like search for SQLa models.

parent 0a239f91
...@@ -10,14 +10,16 @@ ...@@ -10,14 +10,16 @@
- Built-in filtering support - Built-in filtering support
- Configurable operations (=, >, <, etc) - Configurable operations (=, >, <, etc)
- Callable operations - Callable operations
- Custom paginator class? - Built-in search support
- Paginator class
- Custom CSS/JS in admin interface - Custom CSS/JS in admin interface
- SQLA Model Admin - SQLA Model Admin
- Validation of the joins in the query
- Built-in filtering support - Built-in filtering support
- Built-in search support
- Support for related models
- Many2Many support - Many2Many support
- Verify if it is working properly - Verify if it is working properly
- WYSIWYG editor support - WYSIWYG editor support?
- File admin - File admin
- Header title - Header title
- Mass-delete functionality - Mass-delete functionality
......
...@@ -25,6 +25,9 @@ ...@@ -25,6 +25,9 @@
.. autoattribute:: BaseModelView.list_columns .. autoattribute:: BaseModelView.list_columns
.. autoattribute:: BaseModelView.rename_columns .. autoattribute:: BaseModelView.rename_columns
.. autoattribute:: BaseModelView.sortable_columns .. autoattribute:: BaseModelView.sortable_columns
.. autoattribute:: ModelView.searchable_columns
.. autoattribute:: BaseModelView.form_columns .. autoattribute:: BaseModelView.form_columns
.. autoattribute:: BaseModelView.form_args .. autoattribute:: BaseModelView.form_args
...@@ -58,6 +61,8 @@ ...@@ -58,6 +61,8 @@
.. automethod:: ModelView.get_create_form .. automethod:: ModelView.get_create_form
.. automethod:: ModelView.get_edit_form .. automethod:: ModelView.get_edit_form
.. automethod:: ModelView.init_search
Data Data
---- ----
...@@ -88,4 +93,5 @@ ...@@ -88,4 +93,5 @@
------------ ------------
.. automethod:: ModelView._get_url .. automethod:: ModelView._get_url
.. automethod:: ModelView.scaffold_auto_joins .. automethod:: ModelView.scaffold_auto_joins
\ No newline at end of file .. automethod:: ModelView.is_text_column_type
...@@ -25,6 +25,9 @@ ...@@ -25,6 +25,9 @@
.. autoattribute:: BaseModelView.list_columns .. autoattribute:: BaseModelView.list_columns
.. autoattribute:: BaseModelView.rename_columns .. autoattribute:: BaseModelView.rename_columns
.. autoattribute:: BaseModelView.sortable_columns .. autoattribute:: BaseModelView.sortable_columns
.. autoattribute:: BaseModelView.searchable_columns
.. autoattribute:: BaseModelView.form_columns .. autoattribute:: BaseModelView.form_columns
.. autoattribute:: BaseModelView.form_args .. autoattribute:: BaseModelView.form_args
...@@ -51,6 +54,8 @@ ...@@ -51,6 +54,8 @@
.. automethod:: BaseModelView.get_create_form .. automethod:: BaseModelView.get_create_form
.. automethod:: BaseModelView.get_edit_form .. automethod:: BaseModelView.get_edit_form
.. automethod:: BaseModelView.init_search
Data Data
---- ----
......
...@@ -58,6 +58,8 @@ class PostAdmin(sqlamodel.ModelView): ...@@ -58,6 +58,8 @@ class PostAdmin(sqlamodel.ModelView):
# Rename 'title' columns to 'Post Title' in list view # Rename 'title' columns to 'Post Title' in list view
rename_columns = dict(title='Post Title') rename_columns = dict(title='Post Title')
searchable_columns = ('title', User.username)
# Pass arguments to WTForms. In this case, change label for text field to # Pass arguments to WTForms. In this case, change label for text field to
# be 'Big Text' and add required() validator. # be 'Big Text' and add required() validator.
form_args = dict( form_args = dict(
......
...@@ -2,6 +2,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute ...@@ -2,6 +2,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm import subqueryload from sqlalchemy.orm import subqueryload
from sqlalchemy.sql.expression import desc from sqlalchemy.sql.expression import desc
from sqlalchemy import or_
from wtforms import ValidationError, fields, validators from wtforms import ValidationError, fields, validators
from wtforms.ext.sqlalchemy.orm import model_form, converts, ModelConverter from wtforms.ext.sqlalchemy.orm import model_form, converts, ModelConverter
...@@ -187,6 +188,38 @@ class ModelView(BaseModelView): ...@@ -187,6 +188,38 @@ class ModelView(BaseModelView):
Please refer to the `subqueryload` on list of possible values. Please refer to the `subqueryload` on list of possible values.
""" """
searchable_columns = None
"""
Collection of the searchable columns. Only text-based columns
are searchable (`String`, `Unicode`, `Text`, `UnicodeText`).
Example::
class MyModelView(ModelView):
searchable_columns = ('name', 'email')
You can also pass columns::
class MyModelView(ModelView):
searchable_columns = (User.name, User.email)
Following search rules apply:
- If you enter *ZZZ* in the UI search field, it will generate *ILIKE '%ZZZ%'*
statement against searchable columns.
- If you enter multiple words, each word will be searched separately, but
only rows that contain all words will be displayed. For example, searching
for 'abc def' will find all rows that contain 'abc' and 'def' in one or
more columns.
- If you prefix your search term with ^, it will find all rows
that start with ^. So, if you entered *^ZZZ*, *ILIKE 'ZZZ%'* will be used.
- If you prefix your search term with =, it will do exact match.
For example, if you entered *=ZZZ*, *ILIKE 'ZZZ'* statement will be used.
"""
def __init__(self, model, session, def __init__(self, model, session,
name=None, category=None, endpoint=None, url=None): name=None, category=None, endpoint=None, url=None):
""" """
...@@ -207,6 +240,10 @@ class ModelView(BaseModelView): ...@@ -207,6 +240,10 @@ class ModelView(BaseModelView):
""" """
self.session = session self.session = session
self._search_fields = None
self._search_joins = None
self._search_joins_names = None
super(ModelView, self).__init__(model, name, category, endpoint, url) super(ModelView, self).__init__(model, name, category, endpoint, url)
# Configuration # Configuration
...@@ -217,6 +254,9 @@ class ModelView(BaseModelView): ...@@ -217,6 +254,9 @@ class ModelView(BaseModelView):
# Internal API # Internal API
def _get_model_iterator(self): def _get_model_iterator(self):
"""
Return property iterator for the model
"""
return self.model._sa_class_manager.mapper.iterate_properties return self.model._sa_class_manager.mapper.iterate_properties
# Scaffolding # Scaffolding
...@@ -266,6 +306,57 @@ class ModelView(BaseModelView): ...@@ -266,6 +306,57 @@ class ModelView(BaseModelView):
return columns return columns
def init_search(self):
"""
Initialize search. Returns `True` if search is supported for this
view.
For SQLAlchemy, this will initialize internal fields: list of
column objects used for filtering, etc.
"""
if self.searchable_columns:
self._search_fields = []
self._search_joins = []
self._search_joins_names = set()
for p in self.searchable_columns:
# If item is a stirng, resolve it as an attribute
if isinstance(p, basestring):
attr = getattr(self.model, p, None)
else:
attr = p
# Only column searches are supported
if (not attr or
not hasattr(attr, 'property') or
not hasattr(attr.property, 'columns')):
raise Exception('Invalid searchable column "%s"' % p)
for column in attr.property.columns:
column_type = type(column.type).__name__
if not self.is_text_column_type(column_type):
raise Exception('Can only search on text columns. ' +
'Failed to setup search for "%s"' % p)
self._search_fields.append(column)
# If it belongs to different table - add a join
if column.table != self.model.__table__:
self._search_joins.append(column.table)
self._search_joins_names.add(column.table.name)
return bool(self.searchable_columns)
def is_text_column_type(self, name):
"""
Verify if column type is text-based.
Returns `True` for `String`, `Unicode`, `Text`, `UnicodeText`
"""
return (name == 'String' or name == 'Unicode' or
name == 'Text' or name == 'UnicodeText')
def scaffold_form(self): def scaffold_form(self):
""" """
Create form from the model. Create form from the model.
...@@ -297,7 +388,7 @@ class ModelView(BaseModelView): ...@@ -297,7 +388,7 @@ class ModelView(BaseModelView):
return joined return joined
# Database-related API # Database-related API
def get_list(self, page, sort_column, sort_desc, execute=True): def get_list(self, page, sort_column, sort_desc, search, execute=True):
""" """
Return models from the database. Return models from the database.
...@@ -307,11 +398,42 @@ class ModelView(BaseModelView): ...@@ -307,11 +398,42 @@ class ModelView(BaseModelView):
Sort column name Sort column name
`sort_desc` `sort_desc`
Descending or ascending sort Descending or ascending sort
`search`
Search query
`execute` `execute`
Execute query immediately? Default is `True` Execute query immediately? Default is `True`
""" """
# Will contain names of joined tables to avoid duplicate joins
joins = set()
query = self.session.query(self.model) query = self.session.query(self.model)
# Apply search before counting results
if self._search_supported and search:
# Apply search-related joins
if self._search_joins:
query = query.join(*self._search_joins)
joins |= self._search_joins_names
# Apply terms
terms = search.split(' ')
for term in terms:
if not term:
continue
if term.startswith('^'):
stmt = '%s%%' % term[1:]
elif term.startswith('='):
stmt = term[1:]
else:
stmt = '%%%s%%' % term
filter_stmt = [c.ilike(stmt) for c in self._search_fields]
query = query.filter(or_(*filter_stmt))
# Calculate number of rows
count = query.count() count = query.count()
# Auto join # Auto join
...@@ -329,9 +451,16 @@ class ModelView(BaseModelView): ...@@ -329,9 +451,16 @@ class ModelView(BaseModelView):
# contains dot. # contains dot.
if '.' in sort_field: if '.' in sort_field:
parts = sort_field.split('.', 1) parts = sort_field.split('.', 1)
query = query.join(parts[0])
if parts[0] not in joins:
query = query.join(parts[0])
joins.add(parts[0])
elif isinstance(sort_field, InstrumentedAttribute): elif isinstance(sort_field, InstrumentedAttribute):
query = query.join(sort_field.parententity) table = sort_field.parententity.tables[0]
if table.name not in joins:
query = query.join(table)
joins.add(table.name)
else: else:
sort_field = None sort_field = None
......
...@@ -85,6 +85,18 @@ class BaseModelView(BaseView): ...@@ -85,6 +85,18 @@ class BaseModelView(BaseView):
sortable_columns = ('name', ('user', User.username)) sortable_columns = ('name', ('user', User.username))
""" """
searchable_columns = None
"""
Collection of the searchable columns. It is assumed that only
text-only fields are searchable, but it is up for a model implementation
to make decision.
For example::
class MyModelView(BaseModelView):
searchable_columns = ('name', 'email')
"""
form_columns = None form_columns = None
""" """
Collection of the model field names for the form. If set to `None` will Collection of the model field names for the form. If set to `None` will
...@@ -160,6 +172,8 @@ class BaseModelView(BaseView): ...@@ -160,6 +172,8 @@ class BaseModelView(BaseView):
self._create_form_class = self.get_create_form() self._create_form_class = self.get_create_form()
self._edit_form_class = self.get_edit_form() self._edit_form_class = self.get_edit_form()
self._search_supported = self.init_search()
# Public API # Public API
def scaffold_list_columns(self): def scaffold_list_columns(self):
""" """
...@@ -225,6 +239,13 @@ class BaseModelView(BaseView): ...@@ -225,6 +239,13 @@ class BaseModelView(BaseView):
return result return result
def init_search(self):
"""
Initialize search. If data provider does not support search,
`init_search` will return `False`.
"""
return False
def scaffold_form(self): def scaffold_form(self):
""" """
Create `form.BaseForm` inherited class from the model. Must be implemented in Create `form.BaseForm` inherited class from the model. Must be implemented in
...@@ -284,7 +305,7 @@ class BaseModelView(BaseView): ...@@ -284,7 +305,7 @@ class BaseModelView(BaseView):
return self._list_columns[idx] return self._list_columns[idx]
# Database-related API # Database-related API
def get_list(self, page, sort_field, sort_desc): def get_list(self, page, sort_field, sort_desc, search):
""" """
Return list of models from the data source with applied pagination Return list of models from the data source with applied pagination
and sorting. and sorting.
...@@ -297,6 +318,8 @@ class BaseModelView(BaseView): ...@@ -297,6 +318,8 @@ class BaseModelView(BaseView):
Sort column name or None. Sort column name or None.
`sort_desc` `sort_desc`
If set to True, sorting is in descending order. If set to True, sorting is in descending order.
`search`
Search query
""" """
raise NotImplemented('Please implement get_list method') raise NotImplemented('Please implement get_list method')
...@@ -373,10 +396,11 @@ class BaseModelView(BaseView): ...@@ -373,10 +396,11 @@ class BaseModelView(BaseView):
page = request.args.get('page', 0, type=int) page = request.args.get('page', 0, type=int)
sort = request.args.get('sort', None, type=int) sort = request.args.get('sort', None, type=int)
sort_desc = request.args.get('desc', None, type=int) sort_desc = request.args.get('desc', None, type=int)
search = request.args.get('search', None)
return page, sort, sort_desc return page, sort, sort_desc, search
def _get_url(self, view, page, sort, sort_desc): def _get_url(self, view=None, page=None, sort=None, sort_desc=None, search=None):
""" """
Generate page URL with current page, sort column and Generate page URL with current page, sort column and
other parameters. other parameters.
...@@ -389,8 +413,17 @@ class BaseModelView(BaseView): ...@@ -389,8 +413,17 @@ class BaseModelView(BaseView):
Sort column index Sort column index
`sort_desc` `sort_desc`
Use descending sorting order Use descending sorting order
`search`
Search query
""" """
return url_for(view, page=page, sort=sort, desc=sort_desc) if not search:
search = None
return url_for(view,
page=page,
sort=sort,
desc=sort_desc,
search=search)
# Views # Views
@expose('/') @expose('/')
...@@ -399,7 +432,7 @@ class BaseModelView(BaseView): ...@@ -399,7 +432,7 @@ class BaseModelView(BaseView):
List view List view
""" """
# Grab parameters from URL # Grab parameters from URL
page, sort_idx, sort_desc = self._get_extra_args() page, sort_idx, sort_desc, search = self._get_extra_args()
# Map column index to column name # Map column index to column name
sort_column = self._get_column_by_idx(sort_idx) sort_column = self._get_column_by_idx(sort_idx)
...@@ -407,7 +440,7 @@ class BaseModelView(BaseView): ...@@ -407,7 +440,7 @@ class BaseModelView(BaseView):
sort_column = sort_column[0] sort_column = sort_column[0]
# Get count and data # Get count and data
count, data = self.get_list(page, sort_column, sort_desc) count, data = self.get_list(page, sort_column, sort_desc, search)
# Calculate number of pages # Calculate number of pages
num_pages = count / self.page_size num_pages = count / self.page_size
...@@ -420,7 +453,7 @@ class BaseModelView(BaseView): ...@@ -420,7 +453,7 @@ class BaseModelView(BaseView):
if p == 0: if p == 0:
p = None p = None
return self._get_url('.index_view', p, sort_idx, sort_desc) return self._get_url('.index_view', p, sort_idx, sort_desc, search)
def sort_url(column, invert=False): def sort_url(column, invert=False):
desc = None desc = None
...@@ -428,7 +461,7 @@ class BaseModelView(BaseView): ...@@ -428,7 +461,7 @@ class BaseModelView(BaseView):
if invert and not sort_desc: if invert and not sort_desc:
desc = 1 desc = 1
return self._get_url('.index_view', page, column, desc) return self._get_url('.index_view', page, column, desc, search)
def get_value(obj, field): def get_value(obj, field):
return getattr(obj, field, None) return getattr(obj, field, None)
...@@ -440,7 +473,11 @@ class BaseModelView(BaseView): ...@@ -440,7 +473,11 @@ class BaseModelView(BaseView):
sortable_columns=self._sortable_columns, sortable_columns=self._sortable_columns,
# Stuff # Stuff
get_value=get_value, get_value=get_value,
return_url=self._get_url('.index_view', page, sort_idx, sort_desc), return_url=self._get_url('.index_view',
page,
sort_idx,
sort_desc,
search),
# Pagination # Pagination
pager_url=pager_url, pager_url=pager_url,
num_pages=num_pages, num_pages=num_pages,
...@@ -448,7 +485,15 @@ class BaseModelView(BaseView): ...@@ -448,7 +485,15 @@ class BaseModelView(BaseView):
# Sorting # Sorting
sort_column=sort_idx, sort_column=sort_idx,
sort_desc=sort_desc, sort_desc=sort_desc,
sort_url=sort_url sort_url=sort_url,
# Search
search_supported=self._search_supported,
clear_search_url=self._get_url('.index_view',
None,
sort_idx,
sort_desc,
None),
search=search
) )
@expose('/new/', methods=('GET', 'POST')) @expose('/new/', methods=('GET', 'POST'))
......
...@@ -2,6 +2,24 @@ ...@@ -2,6 +2,24 @@
{% import 'admin/lib.html' as lib %} {% import 'admin/lib.html' as lib %}
{% block body %} {% block body %}
{% if search_supported %}
<form method="GET" action="{{ return_url }}" class="well form-search">
{% if search %}
<a href="{{ clear_search_url }}">
<i class="icon-remove"></i>
</a>
{% endif %}
{% if sort_column is not none %}
<input type="hidden" name="sort" value="{{ sort_column }}"></input>
{% endif %}
{% if sort_desc %}
<input type="hidden" name="desc" value="{{ sort_desc }}"></input>
{% endif %}
<input type="text" name="search" value="{{ search or '' }}" class="span10 search-query"></input>
<button type="submit" class="btn">Search</button>
</form>
{% endif %}
<table class="table table-striped table-bordered model-list"> <table class="table table-striped table-bordered model-list">
<thead> <thead>
<tr> <tr>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment