Commit 2fa0d0aa authored by Priit Laes's avatar Priit Laes

Add an argument to FileUploadField to disallow overwriting existing files.

Resolves one of the issues documented in #890
parent 14dec4f5
...@@ -101,7 +101,8 @@ class FileView(sqla.ModelView): ...@@ -101,7 +101,8 @@ class FileView(sqla.ModelView):
form_args = { form_args = {
'path': { 'path': {
'label': 'File', 'label': 'File',
'base_path': file_path 'base_path': file_path,
'allow_overwrite': False
} }
} }
......
...@@ -122,7 +122,7 @@ class FileUploadField(fields.StringField): ...@@ -122,7 +122,7 @@ class FileUploadField(fields.StringField):
def __init__(self, label=None, validators=None, def __init__(self, label=None, validators=None,
base_path=None, relative_path=None, base_path=None, relative_path=None,
namegen=None, allowed_extensions=None, namegen=None, allowed_extensions=None,
permission=0o666, permission=0o666, allow_overwrite=True,
**kwargs): **kwargs):
""" """
Constructor. Constructor.
...@@ -154,6 +154,11 @@ class FileUploadField(fields.StringField): ...@@ -154,6 +154,11 @@ class FileUploadField(fields.StringField):
:param allowed_extensions: :param allowed_extensions:
List of allowed extensions. If not provided, will allow any file. List of allowed extensions. If not provided, will allow any file.
:param allow_overwrite:
Whether to overwrite existing files in upload directory. Defaults to `True`.
.. versionadded:: 1.1.1
The `allow_overwrite` parameter was added.
""" """
self.base_path = base_path self.base_path = base_path
self.relative_path = relative_path self.relative_path = relative_path
...@@ -161,6 +166,7 @@ class FileUploadField(fields.StringField): ...@@ -161,6 +166,7 @@ class FileUploadField(fields.StringField):
self.namegen = namegen or namegen_filename self.namegen = namegen or namegen_filename
self.allowed_extensions = allowed_extensions self.allowed_extensions = allowed_extensions
self.permission = permission self.permission = permission
self._allow_overwrite = allow_overwrite
self._should_delete = False self._should_delete = False
...@@ -188,6 +194,8 @@ class FileUploadField(fields.StringField): ...@@ -188,6 +194,8 @@ class FileUploadField(fields.StringField):
def pre_validate(self, form): def pre_validate(self, form):
if self._is_uploaded_file(self.data) and not self.is_file_allowed(self.data.filename): if self._is_uploaded_file(self.data) and not self.is_file_allowed(self.data.filename):
raise ValidationError(gettext('Invalid file extension')) raise ValidationError(gettext('Invalid file extension'))
if self._allow_overwrite == False and os.path.exists(self._get_path(self.data.filename)):
raise ValidationError(gettext('File "%s" already exists.' % self.data.filename))
def process(self, formdata, data=unset_value): def process(self, formdata, data=unset_value):
if formdata: if formdata:
...@@ -253,6 +261,9 @@ class FileUploadField(fields.StringField): ...@@ -253,6 +261,9 @@ class FileUploadField(fields.StringField):
if not op.exists(op.dirname(path)): if not op.exists(op.dirname(path)):
os.makedirs(os.path.dirname(path), self.permission | 0o111) os.makedirs(os.path.dirname(path), self.permission | 0o111)
if self._allow_overwrite == False and os.path.exists(path):
raise ValueError(gettext('File "%s" already exists.' % path))
data.save(path) data.save(path)
return filename return filename
......
...@@ -40,6 +40,9 @@ def test_upload_field(): ...@@ -40,6 +40,9 @@ def test_upload_field():
class TestForm(form.BaseForm): class TestForm(form.BaseForm):
upload = form.FileUploadField('Upload', base_path=path) upload = form.FileUploadField('Upload', base_path=path)
class TestNoOverWriteForm(form.BaseForm):
upload = form.FileUploadField('Upload', base_path=path, allow_overwrite=False)
class Dummy(object): class Dummy(object):
pass pass
...@@ -74,6 +77,7 @@ def test_upload_field(): ...@@ -74,6 +77,7 @@ def test_upload_field():
# Check delete # Check delete
with app.test_request_context(method='POST', data={'_upload-delete': 'checked'}): with app.test_request_context(method='POST', data={'_upload-delete': 'checked'}):
my_form = TestForm(helpers.get_form_data()) my_form = TestForm(helpers.get_form_data())
ok_(my_form.validate()) ok_(my_form.validate())
...@@ -83,6 +87,24 @@ def test_upload_field(): ...@@ -83,6 +87,24 @@ def test_upload_field():
ok_(not op.exists(op.join(path, 'test2.txt'))) ok_(not op.exists(op.join(path, 'test2.txt')))
# Check overwrite
_remove_testfiles()
my_form_ow = TestNoOverWriteForm()
with app.test_request_context(method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
my_form_ow = TestNoOverWriteForm(helpers.get_form_data())
ok_(my_form_ow.validate())
my_form_ow.populate_obj(dummy)
eq_(dummy.upload, 'test1.txt')
ok_(op.exists(op.join(path, 'test1.txt')))
with app.test_request_context(method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
my_form_ow = TestNoOverWriteForm(helpers.get_form_data())
ok_(not my_form_ow.validate())
_remove_testfiles()
def test_image_upload_field(): def test_image_upload_field():
app = Flask(__name__) app = Flask(__name__)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment