diff --git a/superset/config.py b/superset/config.py index 9c6adf599b07b..4a7434093c9b1 100644 --- a/superset/config.py +++ b/superset/config.py @@ -750,6 +750,9 @@ class D3Format(TypedDict, total=False): COLUMNAR_EXTENSIONS = {"parquet", "zip"} ALLOWED_EXTENSIONS = {*EXCEL_EXTENSIONS, *CSV_EXTENSIONS, *COLUMNAR_EXTENSIONS} +# Optional maximum file size in bytes when uploading a CSV +CSV_UPLOAD_MAX_SIZE = None + # CSV Options: key/value pairs that will be passed as argument to DataFrame.to_csv # method. # note: index option should not be overridden diff --git a/superset/forms.py b/superset/forms.py index f1e220ba952f7..1266870301e93 100644 --- a/superset/forms.py +++ b/superset/forms.py @@ -16,10 +16,12 @@ # under the License. """Contains the logic to create cohesive forms on the explore view""" import json +import os from typing import Any, Optional from flask_appbuilder.fieldwidgets import BS3TextFieldWidget -from wtforms import Field +from flask_babel import gettext as _ +from wtforms import Field, ValidationError class JsonListField(Field): @@ -53,6 +55,27 @@ def process_formdata(self, valuelist: list[str]) -> None: self.data = [] +class FileSizeLimit: # pylint: disable=too-few-public-methods + """Imposes an optional maximum filesize limit for uploaded files""" + + def __init__(self, max_size: Optional[int]): + self.max_size = max_size + + def __call__(self, form: dict[str, Any], field: Any) -> None: + if self.max_size is None: + return + + field.data.flush() + size = os.fstat(field.data.fileno()).st_size + if size > self.max_size: + raise ValidationError( + _( + "File size must be less than or equal to %(max_size)s bytes", + max_size=self.max_size, + ) + ) + + def filter_not_empty_values(values: Optional[list[Any]]) -> Optional[list[Any]]: """Returns a list of non empty values or None""" if not values: diff --git a/superset/views/database/forms.py b/superset/views/database/forms.py index b906e5e70b880..767f4bb4e56cc 100644 --- a/superset/views/database/forms.py +++ b/superset/views/database/forms.py @@ -33,6 +33,7 @@ from superset import app, db, security_manager from superset.forms import ( CommaSeparatedListField, + FileSizeLimit, filter_not_empty_values, JsonListField, ) @@ -109,6 +110,7 @@ class CsvToDatabaseForm(UploadToDatabaseForm): description=_("Select a file to be uploaded to the database"), validators=[ FileRequired(), + FileSizeLimit(config["CSV_UPLOAD_MAX_SIZE"]), FileAllowed( config["ALLOWED_EXTENSIONS"].intersection(config["CSV_EXTENSIONS"]), _( diff --git a/tests/unit_tests/forms_tests.py b/tests/unit_tests/forms_tests.py new file mode 100644 index 0000000000000..0ede23551ff76 --- /dev/null +++ b/tests/unit_tests/forms_tests.py @@ -0,0 +1,54 @@ +import contextlib +import tempfile +from typing import Optional + +import pytest +from flask_wtf.file import FileField +from wtforms import Form, ValidationError + +from superset.forms import FileSizeLimit + + +def _get_test_form(size_limit: Optional[int]) -> Form: + class TestForm(Form): + test = FileField("test", validators=[FileSizeLimit(size_limit)]) + + return TestForm() + + +@contextlib.contextmanager +def _tempfile(contents: bytes): + with tempfile.NamedTemporaryFile() as f: + f.write(contents) + f.flush() + + yield f + + +def test_file_size_limit_pass() -> None: + """Permit files which do not exceed the size limit""" + limit = 100 + form = _get_test_form(limit) + + with _tempfile(b"." * limit) as f: + form.test.data = f + assert form.validate() is True + + +def test_file_size_limit_fail() -> None: + """Reject files which are too large""" + limit = 100 + form = _get_test_form(limit) + + with _tempfile(b"." * (limit + 1)) as f: + form.test.data = f + assert form.validate() is False + + +def test_file_size_limit_ignored_if_none() -> None: + """Permit files when there is no limit""" + form = _get_test_form(None) + + with _tempfile(b"." * 200) as f: + form.test.data = f + assert form.validate() is True