Commit e8b1b6b5 authored by Serge S. Koval's avatar Serge S. Koval

Merge pull request #1038 from tandreas/csv_export

Implement CSV export for BaseModelView.
parents 9db8c9e5 51113b95
...@@ -298,6 +298,12 @@ To **manage related models inline**:: ...@@ -298,6 +298,12 @@ To **manage related models inline**::
These inline forms can be customised. Have a look at the API documentation for These inline forms can be customised. Have a look at the API documentation for
:meth:`~flask_admin.contrib.sqla.ModelView.inline_models`. :meth:`~flask_admin.contrib.sqla.ModelView.inline_models`.
To **enable csv export** of the model view::
can_export = True
This will add a button to the model view that exports records, truncating at :attr:`~flask_admin.model.BaseModelView.max_export_rows`.
Adding Your Own Views Adding Your Own Views
===================== =====================
......
...@@ -31,6 +31,10 @@ if not PY2: ...@@ -31,6 +31,10 @@ if not PY2:
return str(s) return str(s)
def csv_encode(s):
''' Returns unicode string expected by Python 3's csv module '''
return as_unicode(s)
# Various tools # Various tools
from functools import reduce from functools import reduce
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
...@@ -50,6 +54,10 @@ else: ...@@ -50,6 +54,10 @@ else:
return unicode(s) return unicode(s)
def csv_encode(s):
''' Returns byte string expected by Python 2's csv module '''
return as_unicode(s).encode('utf-8')
# Helpers # Helpers
reduce = __builtins__['reduce'] if isinstance(__builtins__, dict) else __builtins__.reduce reduce = __builtins__['reduce'] if isinstance(__builtins__, dict) else __builtins__.reduce
from urlparse import urljoin, urlparse from urlparse import urljoin, urlparse
......
import warnings import warnings
import re import re
import csv
import time
from werkzeug import secure_filename
from flask import (request, redirect, flash, abort, json, Response, from flask import (request, redirect, flash, abort, json, Response,
get_flashed_messages) get_flashed_messages, stream_with_context)
from jinja2 import contextfunction from jinja2 import contextfunction
from wtforms.fields import HiddenField from wtforms.fields import HiddenField
from wtforms.fields.core import UnboundField from wtforms.fields.core import UnboundField
...@@ -18,12 +22,11 @@ from flask_admin.helpers import (get_form_data, validate_form_on_submit, ...@@ -18,12 +22,11 @@ from flask_admin.helpers import (get_form_data, validate_form_on_submit,
get_redirect_target, flash_errors) get_redirect_target, flash_errors)
from flask_admin.tools import rec_getattr from flask_admin.tools import rec_getattr
from flask_admin._backwards import ObsoleteAttr from flask_admin._backwards import ObsoleteAttr
from flask_admin._compat import iteritems, OrderedDict, as_unicode from flask_admin._compat import iteritems, OrderedDict, as_unicode, csv_encode
from .helpers import prettify_name, get_mdict_item_or_list from .helpers import prettify_name, get_mdict_item_or_list
from .ajax import AjaxModelLoader from .ajax import AjaxModelLoader
from .fields import ListEditableFieldList from .fields import ListEditableFieldList
# Used to generate filter query string name # Used to generate filter query string name
filter_char_re = re.compile('[^a-z0-9 ]') filter_char_re = re.compile('[^a-z0-9 ]')
filter_compact_re = re.compile(' +') filter_compact_re = re.compile(' +')
...@@ -95,6 +98,9 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -95,6 +98,9 @@ class BaseModelView(BaseView, ActionsMixin):
when there are too many columns to display in the list_view. when there are too many columns to display in the list_view.
""" """
can_export = False
"""Is model list export allowed"""
# Templates # Templates
list_template = 'admin/model/list.html' list_template = 'admin/model/list.html'
"""Default list view template""" """Default list view template"""
...@@ -194,14 +200,25 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -194,14 +200,25 @@ class BaseModelView(BaseView, ActionsMixin):
pass pass
""" """
column_formatters_export = None
"""
Dictionary of list view column formatters to be used for export.
Defaults to column_formatters when set to None.
Functions the same way as column_formatters except
that macros are not supported.
"""
column_type_formatters = ObsoleteAttr('column_type_formatters', 'list_type_formatters', None) column_type_formatters = ObsoleteAttr('column_type_formatters', 'list_type_formatters', None)
""" """
Dictionary of value type formatters to be used in the list view. Dictionary of value type formatters to be used in the list view.
By default, two types are formatted: By default, three types are formatted:
1. ``None`` will be displayed as an empty string 1. ``None`` will be displayed as an empty string
2. ``bool`` will be displayed as a checkmark if it is ``True`` 2. ``bool`` will be displayed as a checkmark if it is ``True``
3. ``list`` will be joined using ', '
If you don't like the default behavior and don't want any type formatters If you don't like the default behavior and don't want any type formatters
applied, just override this property with an empty dictionary:: applied, just override this property with an empty dictionary::
...@@ -237,6 +254,18 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -237,6 +254,18 @@ class BaseModelView(BaseView, ActionsMixin):
pass pass
""" """
column_type_formatters_export = None
"""
Dictionary of value type formatters to be used in the export.
By default, two types are formatted:
1. ``None`` will be displayed as an empty string
2. ``list`` will be joined using ', '
Functions the same way as column_type_formatters.
"""
column_labels = ObsoleteAttr('column_labels', 'rename_columns', None) column_labels = ObsoleteAttr('column_labels', 'rename_columns', None)
""" """
Dictionary where key is column name and value is string to display. Dictionary where key is column name and value is string to display.
...@@ -579,6 +608,12 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -579,6 +608,12 @@ class BaseModelView(BaseView, ActionsMixin):
action_disallowed_list = ['delete'] action_disallowed_list = ['delete']
""" """
# Export settings
export_max_rows = None
"""
Maximum number of rows allowed for export.
"""
# Various settings # Various settings
page_size = 20 page_size = 20
""" """
...@@ -732,10 +767,17 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -732,10 +767,17 @@ class BaseModelView(BaseView, ActionsMixin):
else: else:
self.column_choices = self._column_choices_map = dict() self.column_choices = self._column_choices_map = dict()
# Column formatters
if self.column_formatters_export is None:
self.column_formatters_export = self.column_formatters
# Type formatters # Type formatters
if self.column_type_formatters is None: if self.column_type_formatters is None:
self.column_type_formatters = dict(typefmt.BASE_FORMATTERS) self.column_type_formatters = dict(typefmt.BASE_FORMATTERS)
if self.column_type_formatters_export is None:
self.column_type_formatters_export = dict(typefmt.EXPORT_FORMATTERS)
if self.column_descriptions is None: if self.column_descriptions is None:
self.column_descriptions = dict() self.column_descriptions = dict()
...@@ -1214,7 +1256,8 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1214,7 +1256,8 @@ class BaseModelView(BaseView, ActionsMixin):
return None return None
# Database-related API # Database-related API
def get_list(self, page, sort_field, sort_desc, search, filters): def get_list(self, page, sort_field, sort_desc, search, filters,
page_size=None):
""" """
Return a paginated and sorted list of models from the data source. Return a paginated and sorted list of models from the data source.
...@@ -1231,6 +1274,10 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1231,6 +1274,10 @@ class BaseModelView(BaseView, ActionsMixin):
:param filters: :param filters:
List of filter tuples. First value in a tuple is a search List of filter tuples. First value in a tuple is a search
index, second value is a search value. index, second value is a search value.
:param page_size:
Number of results. Defaults to ModelView's page_size. Can be
overriden to change the page_size limit. Removing the page_size
limit requires setting page_size to 0 or False.
""" """
raise NotImplementedError('Please implement get_list method') raise NotImplementedError('Please implement get_list method')
...@@ -1493,19 +1540,23 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1493,19 +1540,23 @@ class BaseModelView(BaseView, ActionsMixin):
""" """
return rec_getattr(model, name) return rec_getattr(model, name)
@contextfunction def _get_list_value(self, context, model, name, column_formatters,
def get_list_value(self, context, model, name): column_type_formatters):
""" """
Returns the value to be displayed in the list view Returns the value to be displayed.
:param context: :param context:
:py:class:`jinja2.runtime.Context` :py:class:`jinja2.runtime.Context` if available
:param model: :param model:
Model instance Model instance
:param name: :param name:
Field name Field name
:param column_formatters:
column_formatters to be used.
:param column_type_formatters:
column_type_formatters to be used.
""" """
column_fmt = self.column_formatters.get(name) column_fmt = column_formatters.get(name)
if column_fmt is not None: if column_fmt is not None:
value = column_fmt(self, context, model, name) value = column_fmt(self, context, model, name)
else: else:
...@@ -1516,7 +1567,7 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1516,7 +1567,7 @@ class BaseModelView(BaseView, ActionsMixin):
return choices_map.get(value) or value return choices_map.get(value) or value
type_fmt = None type_fmt = None
for typeobj, formatter in self.column_type_formatters.items(): for typeobj, formatter in column_type_formatters.items():
if isinstance(value, typeobj): if isinstance(value, typeobj):
type_fmt = formatter type_fmt = formatter
break break
...@@ -1525,6 +1576,44 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1525,6 +1576,44 @@ class BaseModelView(BaseView, ActionsMixin):
return value return value
@contextfunction
def get_list_value(self, context, model, name):
"""
Returns the value to be displayed in the list view
:param context:
:py:class:`jinja2.runtime.Context`
:param model:
Model instance
:param name:
Field name
"""
return self._get_list_value(
context,
model,
name,
self.column_formatters,
self.column_type_formatters,
)
def get_export_value(self, model, name):
"""
Returns the value to be displayed in export.
Allows export to use different (non HTML) formatters.
:param model:
Model instance
:param name:
Field name
"""
return self._get_list_value(
None,
model,
name,
self.column_formatters_export,
self.column_type_formatters_export,
)
# AJAX references # AJAX references
def _process_ajax_references(self): def _process_ajax_references(self):
""" """
...@@ -1823,6 +1912,76 @@ class BaseModelView(BaseView, ActionsMixin): ...@@ -1823,6 +1912,76 @@ class BaseModelView(BaseView, ActionsMixin):
""" """
return self.handle_action() return self.handle_action()
@expose('/export/csv/')
def export_csv(self):
"""
Export a CSV of records.
"""
return_url = get_redirect_target() or self.get_url('.index_view')
if not self.can_export:
flash(gettext('Permission denied.'))
return redirect(return_url)
# Macros in column_formatters are not supported.
# Macros will have a function name 'inner'
# This causes non-macro functions named 'inner' not work.
for col, func in iteritems(self.column_formatters):
if func.__name__ == 'inner':
raise NotImplementedError(
'Macros not implemented. Override with '
'column_formatters_export. Column: %s' % (col,)
)
# Grab parameters from URL
view_args = self._get_list_extra_args()
# Map column index to column name
sort_column = self._get_column_by_idx(view_args.sort)
if sort_column is not None:
sort_column = sort_column[0]
# Get count and data
count, data = self.get_list(0, sort_column, view_args.sort_desc,
view_args.search, view_args.filters,
page_size=self.export_max_rows)
# https://docs.djangoproject.com/en/1.8/howto/outputting-csv/
class Echo(object):
"""
An object that implements just the write method of the file-like
interface.
"""
def write(self, value):
"""
Write the value by returning it, instead of storing
in a buffer.
"""
return value
writer = csv.writer(Echo())
def generate():
# Append the column titles at the beginning
titles = [csv_encode(c[1]) for c in self._list_columns]
yield writer.writerow(titles)
for row in data:
vals = [csv_encode(self.get_export_value(row, c[0]))
for c in self._list_columns]
yield writer.writerow(vals)
filename = '%s_%s.csv' % (self.name,
time.strftime("%Y-%m-%d_%H-%M-%S"))
disposition = 'attachment;filename=%s' % (secure_filename(filename),)
return Response(
stream_with_context(generate()),
headers={'Content-Disposition': disposition},
mimetype='text/csv'
)
@expose('/ajax/lookup/') @expose('/ajax/lookup/')
def ajax_lookup(self): def ajax_lookup(self):
name = request.args.get('name') name = request.args.get('name')
......
...@@ -49,3 +49,8 @@ BASE_FORMATTERS = { ...@@ -49,3 +49,8 @@ BASE_FORMATTERS = {
bool: bool_formatter, bool: bool_formatter,
list: list_formatter, list: list_formatter,
} }
EXPORT_FORMATTERS = {
type(None): empty_formatter,
list: list_formatter,
}
...@@ -26,6 +26,12 @@ ...@@ -26,6 +26,12 @@
</li> </li>
{% endif %} {% endif %}
{% if admin_view.can_export %}
<li>
<a href="{{ get_url('.export_csv', **request.args) }}" title="{{ _gettext('Export') }}">{{ _gettext('Export') }}</a>
</li>
{% endif %}
{% if filters %} {% if filters %}
<li class="dropdown"> <li class="dropdown">
{{ model_layout.filter_options() }} {{ model_layout.filter_options() }}
......
...@@ -26,6 +26,12 @@ ...@@ -26,6 +26,12 @@
</li> </li>
{% endif %} {% endif %}
{% if admin_view.can_export %}
<li>
<a href="{{ get_url('.export_csv', **request.args) }}" title="{{ _gettext('Export') }}">{{ _gettext('Export') }}</a>
</li>
{% endif %}
{% if filters %} {% if filters %}
<li class="dropdown"> <li class="dropdown">
{{ model_layout.filter_options() }} {{ model_layout.filter_options() }}
......
...@@ -12,6 +12,8 @@ from wtforms import fields ...@@ -12,6 +12,8 @@ from wtforms import fields
from flask_admin import Admin, form from flask_admin import Admin, form
from flask_admin._compat import iteritems, itervalues from flask_admin._compat import iteritems, itervalues
from flask_admin.model import base, filters from flask_admin.model import base, filters
from flask_admin.model.template import macro
from itertools import islice
def wtforms2_and_up(func): def wtforms2_and_up(func):
...@@ -46,8 +48,8 @@ class SimpleFilter(filters.BaseFilter): ...@@ -46,8 +48,8 @@ class SimpleFilter(filters.BaseFilter):
class MockModelView(base.BaseModelView): class MockModelView(base.BaseModelView):
def __init__(self, model, name=None, category=None, endpoint=None, url=None, def __init__(self, model, data=None, name=None, category=None,
**kwargs): endpoint=None, url=None, **kwargs):
# Allow to set any attributes from parameters # Allow to set any attributes from parameters
for k, v in iteritems(kwargs): for k, v in iteritems(kwargs):
setattr(self, k, v) setattr(self, k, v)
...@@ -60,9 +62,12 @@ class MockModelView(base.BaseModelView): ...@@ -60,9 +62,12 @@ class MockModelView(base.BaseModelView):
self.search_arguments = [] self.search_arguments = []
self.all_models = {1: Model(1), if data is None:
2: Model(2)} self.all_models = {1: Model(1), 2: Model(2)}
self.last_id = 3 else:
self.all_models = data
self.last_id = len(self.all_models) + 1
# Scaffolding # Scaffolding
def get_pk_value(self, model): def get_pk_value(self, model):
...@@ -89,9 +94,12 @@ class MockModelView(base.BaseModelView): ...@@ -89,9 +94,12 @@ class MockModelView(base.BaseModelView):
return Form return Form
# Data # Data
def get_list(self, page, sort_field, sort_desc, search, filters): def get_list(self, page, sort_field, sort_desc, search, filters,
page_size=None):
self.search_arguments.append((page, sort_field, sort_desc, search, filters)) self.search_arguments.append((page, sort_field, sort_desc, search, filters))
return len(self.all_models), itervalues(self.all_models) count = len(self.all_models)
data = islice(itervalues(self.all_models), 0, page_size)
return count, data
def get_one(self, id): def get_one(self, id):
return self.all_models.get(int(id)) return self.all_models.get(int(id))
...@@ -538,3 +546,120 @@ def check_class_name(): ...@@ -538,3 +546,120 @@ def check_class_name():
view = DummyView(Model) view = DummyView(Model)
eq_(view.name, 'Dummy View') eq_(view.name, 'Dummy View')
def test_export_csv():
app, admin = setup()
client = app.test_client()
# test redirect when csv export is disabled
view = MockModelView(Model, column_list=['col1', 'col2'], endpoint="test")
admin.add_view(view)
rv = client.get('/admin/test/export/csv/')
eq_(rv.status_code, 302)
# basic test of csv export with a few records
view_data = {
1: Model(1, "col1_1", "col2_1"),
2: Model(2, "col1_2", "col2_2"),
3: Model(3, "col1_3", "col2_3"),
}
view = MockModelView(Model, view_data, can_export=True,
column_list=['col1', 'col2'])
admin.add_view(view)
rv = client.get('/admin/model/export/csv/')
data = rv.data.decode('utf-8')
eq_(rv.mimetype, 'text/csv')
eq_(rv.status_code, 200)
ok_("Col1,Col2\r\n"
"col1_1,col2_1\r\n"
"col1_2,col2_2\r\n"
"col1_3,col2_3\r\n" == data)
# test utf8 characters in csv export
view_data[4] = Model(1, u'\u2013ut8_1\u2013', u'\u2013utf8_2\u2013')
view = MockModelView(Model, view_data, can_export=True,
column_list=['col1', 'col2'], endpoint="utf8")
admin.add_view(view)
rv = client.get('/admin/utf8/export/csv/')
data = rv.data.decode('utf-8')
eq_(rv.status_code, 200)
ok_(u'\u2013ut8_1\u2013,\u2013utf8_2\u2013\r\n' in data)
# test row limit
view_data = {
1: Model(1, "col1_1", "col2_1"),
2: Model(2, "col1_2", "col2_2"),
3: Model(3, "col1_3", "col2_3"),
}
view = MockModelView(Model, view_data, can_export=True,
column_list=['col1', 'col2'], export_max_rows=2,
endpoint='row_limit_2')
admin.add_view(view)
rv = client.get('/admin/row_limit_2/export/csv/')
data = rv.data.decode('utf-8')
eq_(rv.status_code, 200)
ok_("Col1,Col2\r\n"
"col1_1,col2_1\r\n"
"col1_2,col2_2\r\n" == data)
# test None type, integer type, column_labels, and column_formatters
view_data = {
1: Model(1, "col1_1", 1),
2: Model(2, "col1_2", 2),
3: Model(3, None, 3),
}
view = MockModelView(
Model, view_data, can_export=True, column_list=['col1', 'col2'],
column_labels={'col1': 'Str Field', 'col2': 'Int Field'},
column_formatters=dict(col2=lambda v, c, m, p: m.col2*2),
endpoint="types_and_formatters"
)
admin.add_view(view)
rv = client.get('/admin/types_and_formatters/export/csv/')
data = rv.data.decode('utf-8')
eq_(rv.status_code, 200)
ok_("Str Field,Int Field\r\n"
"col1_1,2\r\n"
"col1_2,4\r\n"
",6\r\n" == data)
# test column_formatters_export and column_formatters_export
type_formatters = {type(None): lambda view, value: "null"}
view = MockModelView(
Model, view_data, can_export=True, column_list=['col1', 'col2'],
column_formatters_export=dict(col2=lambda v, c, m, p: m.col2*3),
column_formatters=dict(col2=lambda v, c, m, p: m.col2*2), # overridden
column_type_formatters_export=type_formatters,
endpoint="export_types_and_formatters"
)
admin.add_view(view)
rv = client.get('/admin/export_types_and_formatters/export/csv/')
data = rv.data.decode('utf-8')
eq_(rv.status_code, 200)
ok_("Col1,Col2\r\n"
"col1_1,3\r\n"
"col1_2,6\r\n"
"null,9\r\n" == data)
# Macros are not implemented for csv export yet and will throw an error
view = MockModelView(
Model, can_export=True, column_list=['col1', 'col2'],
column_formatters=dict(col1=macro('render_macro')),
endpoint="macro_exception"
)
admin.add_view(view)
rv = client.get('/admin/macro_exception/export/csv/')
data = rv.data.decode('utf-8')
eq_(rv.status_code, 500)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment