Commit 1e6eefc7 authored by Paul Brown's avatar Paul Brown

add csrf token validation to actions

parent 2ce8dbc0
......@@ -3,7 +3,7 @@ from flask import request, redirect
from flask_admin import tools
from flask_admin._compat import text_type
from flask_admin.helpers import get_redirect_target
from flask_admin.helpers import get_redirect_target, flash_errors
def action(name, text, confirmation=None):
......@@ -104,16 +104,22 @@ class ActionsMixin(object):
If not provided, will return user to the return url in the form
or the list view.
"""
action = request.form.get('action')
ids = request.form.getlist('rowid')
form = self.action_form()
handler = self._actions_data.get(action)
if self.validate_form(form):
# using getlist instead of FieldList for backward compatibility
ids = request.form.getlist('rowid')
action = form.action.data
if handler and self.is_action_allowed(action):
response = handler[0](ids)
handler = self._actions_data.get(action)
if response is not None:
return response
if handler and self.is_action_allowed(action):
response = handler[0](ids)
if response is not None:
return response
else:
flash_errors(form, message='Failed to perform action. %(error)s')
if return_view:
url = self.get_url('.' + return_view)
......
......@@ -413,6 +413,19 @@ class BaseFileAdmin(BaseView, ActionsMixin):
return DeleteForm
def get_action_form(self):
"""
Create form class for model action.
Override to implement customized behavior.
"""
class ActionForm(self.form_base_class):
action = fields.HiddenField()
url = fields.HiddenField()
# rowid is retrieved using getlist, for backward compatibility
return ActionForm
def upload_form(self):
"""
Instantiate file upload form and return it.
......@@ -471,6 +484,18 @@ class BaseFileAdmin(BaseView, ActionsMixin):
else:
return delete_form_class()
def action_form(self):
"""
Instantiate action form and return it.
Override to implement custom behavior.
"""
action_form_class = self.get_action_form()
if request.form:
return action_form_class(request.form)
else:
return action_form_class()
def is_file_allowed(self, filename):
"""
Verify if file can be uploaded.
......@@ -812,6 +837,10 @@ class BaseFileAdmin(BaseView, ActionsMixin):
# Actions
actions, actions_confirmation = self.get_actions_list()
if actions:
action_form = self.action_form()
else:
action_form = None
def sort_url(column, invert=False):
desc = None
......@@ -829,6 +858,7 @@ class BaseFileAdmin(BaseView, ActionsMixin):
items=items,
actions=actions,
actions_confirmation=actions_confirmation,
action_form=action_form,
delete_form=delete_form,
sort_column=sort_column,
sort_desc=sort_desc,
......
......@@ -795,6 +795,7 @@ class BaseModelView(BaseView, ActionsMixin):
self._create_form_class = self.get_create_form()
self._edit_form_class = self.get_edit_form()
self._delete_form_class = self.get_delete_form()
self._action_form_class = self.get_action_form()
# List View In-Line Editing
if self.column_editable_list:
......@@ -1254,6 +1255,19 @@ class BaseModelView(BaseView, ActionsMixin):
return DeleteForm
def get_action_form(self):
"""
Create form class for a model action.
Override to implement customized behavior.
"""
class ActionForm(self.form_base_class):
action = HiddenField()
url = HiddenField()
# rowid is retrieved using getlist, for backward compatibility
return ActionForm
def create_form(self, obj=None):
"""
Instantiate model creation form and return it.
......@@ -1295,6 +1309,14 @@ class BaseModelView(BaseView, ActionsMixin):
"""
return self._list_form_class(get_form_data(), obj=obj)
def action_form(self, obj=None):
"""
Instantiate model action form and return it.
Override to implement custom behavior.
"""
return self._action_form_class(get_form_data(), obj=obj)
def validate_form(self, form):
"""
Validate the form on submit.
......@@ -1540,7 +1562,7 @@ class BaseModelView(BaseView, ActionsMixin):
"""
pass
def on_form_prefill (self, form, id):
def on_form_prefill(self, form, id):
"""
Perform additional actions to pre-fill the edit form.
......@@ -1874,6 +1896,10 @@ class BaseModelView(BaseView, ActionsMixin):
# Actions
actions, actions_confirmation = self.get_actions_list()
if actions:
action_form = self.action_form()
else:
action_form = None
clear_search_url = self._get_list_url(view_args.clone(page=0,
sort=view_args.sort,
......@@ -1886,6 +1912,7 @@ class BaseModelView(BaseView, ActionsMixin):
data=data,
list_forms=list_forms,
delete_form=delete_form,
action_form=action_form,
# List
list_columns=self._list_columns,
......
......@@ -14,11 +14,13 @@
{% macro form(actions, url) %}
{% if actions %}
<form id="action_form" action="{{ url }}" method="POST" style="display: none">
{% if csrf_token %}
{% if action_form.csrf_token %}
{{ action_form.csrf_token }}
{% elif csrf_token %}
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>
{% endif %}
<input type="hidden" name="url" value="{{ return_url }}">
<input type="hidden" id="action" name="action" />
{{ action_form.url(value=return_url) }}
{{ action_form.action() }}
</form>
{% endif %}
{% endmacro %}
......
......@@ -14,11 +14,13 @@
{% macro form(actions, url) %}
{% if actions %}
<form id="action_form" action="{{ url }}" method="POST" style="display: none">
{% if csrf_token %}
{% if action_form.csrf_token %}
{{ action_form.csrf_token }}
{% elif csrf_token %}
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>
{% endif %}
<input type="hidden" name="url" value="{{ return_url }}">
<input type="hidden" id="action" name="action" />
{{ action_form.url(value=return_url) }}
{{ action_form.action() }}
</form>
{% endif %}
{% endmacro %}
......
......@@ -379,7 +379,7 @@ def test_csrf():
# Create with CSRF token
rv = client.post('/admin/secure/new/', data=dict(name='test1',
csrf_token=csrf_token))
csrf_token=csrf_token))
eq_(rv.status_code, 302)
###############
......@@ -424,6 +424,23 @@ def test_csrf():
eq_(rv.status_code, 200)
ok_(u'Record was successfully deleted.' in rv.data.decode('utf-8'))
################
# actions
################
rv = client.get('/admin/secure/')
eq_(rv.status_code, 200)
ok_(u'name="csrf_token"' in rv.data.decode('utf-8'))
csrf_token = get_csrf_token(rv.data.decode('utf-8'))
# Delete without CSRF token, test validation errors
rv = client.post('/admin/secure/action/',
data=dict(rowid='1', url='/admin/secure/', action='delete'),
follow_redirects=True)
eq_(rv.status_code, 200)
ok_(u'Record was successfully deleted.' not in rv.data.decode('utf-8'))
ok_(u'Failed to perform action.' in rv.data.decode('utf-8'))
def test_custom_form():
app, admin = setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment