diff --git a/rows/fields.py b/rows/fields.py index bbd1729b..11d6af7e 100644 --- a/rows/fields.py +++ b/rows/fields.py @@ -63,7 +63,7 @@ def serialize(cls, value, *args, **kwargs): return value @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): """Deserialize a value just after importing it `cls.deserialize` should always return a value of type `cls.TYPE` or @@ -100,7 +100,7 @@ def serialize(cls, value, *args, **kwargs): return '' @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): if value is not None: if isinstance(value, six.binary_type): return value @@ -133,7 +133,7 @@ def serialize(cls, value, *args, **kwargs): return cls.SERIALIZED_VALUES[value] @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): value = super(BoolField, cls).deserialize(value) if value is None or isinstance(value, cls.TYPE): return value @@ -167,7 +167,7 @@ def serialize(cls, value, *args, **kwargs): return locale.format('%d', value, grouping=grouping) @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): value = super(IntegerField, cls).deserialize(value) if value is None or isinstance(value, cls.TYPE): return value @@ -203,7 +203,7 @@ def serialize(cls, value, *args, **kwargs): return locale.format('%f', value, grouping=grouping) @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): value = super(FloatField, cls).deserialize(value) if value is None or isinstance(value, cls.TYPE): return value @@ -242,7 +242,7 @@ def serialize(cls, value, *args, **kwargs): return locale.format(string_format, value, grouping=grouping) @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): value = super(DecimalField, cls).deserialize(value) if value is None or isinstance(value, cls.TYPE): return value @@ -300,7 +300,7 @@ def serialize(cls, value, *args, **kwargs): return '{}%'.format(value) @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): if isinstance(value, cls.TYPE): return value elif is_null(value): @@ -331,7 +331,7 @@ def serialize(cls, value, *args, **kwargs): return six.text_type(value.strftime(cls.OUTPUT_FORMAT)) @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): value = super(DateField, cls).deserialize(value) if value is None or isinstance(value, cls.TYPE): return value @@ -360,7 +360,7 @@ def serialize(cls, value, *args, **kwargs): return six.text_type(value.isoformat()) @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): value = super(DatetimeField, cls).deserialize(value) if value is None or isinstance(value, cls.TYPE): return value @@ -383,7 +383,7 @@ class TextField(Field): TYPE = (six.text_type, ) @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): if value is None or isinstance(value, cls.TYPE): return value else: @@ -407,7 +407,7 @@ def serialize(cls, value, *args, **kwargs): return six.text_type(value) @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): value = super(EmailField, cls).deserialize(value) if value is None or not value.strip(): return None @@ -432,7 +432,7 @@ def serialize(cls, value, *args, **kwargs): return json.dumps(value) @classmethod - def deserialize(cls, value, *args, **kwargs): + def deserialize(cls, value): if value is None or isinstance(value, cls.TYPE): return value else: @@ -472,8 +472,7 @@ def unique_values(values): return result -def detect_types(field_names, field_values, field_types=AVAILABLE_FIELD_TYPES, - *args, **kwargs): +def detect_types(field_names, field_values, field_types=AVAILABLE_FIELD_TYPES): """Where the magic happens""" # TODO: look strategy of csv.Sniffer.has_header @@ -512,7 +511,7 @@ def detect_types(field_names, field_values, field_types=AVAILABLE_FIELD_TYPES, cant_be = set() for type_ in possible_types: try: - type_.deserialize(value, *args, **kwargs) + type_.deserialize(value) except (ValueError, TypeError): cant_be.add(type_) for type_to_remove in cant_be: diff --git a/rows/plugins/plugin_csv.py b/rows/plugins/plugin_csv.py index a06b1571..7e386bdf 100644 --- a/rows/plugins/plugin_csv.py +++ b/rows/plugins/plugin_csv.py @@ -47,6 +47,8 @@ def discover_dialect(sample, encoding): # Could not detect dialect, fall back to 'excel' return unicodecsv.excel +def _pop_dialect_params(params, dialect, parameter): + return params.pop(parameter, getattr(dialect, parameter)) def import_from_csv(filename_or_fobj, encoding='utf-8', dialect=None, sample_size=8192, *args, **kwargs): @@ -58,12 +60,28 @@ def import_from_csv(filename_or_fobj, encoding='utf-8', dialect=None, filename, fobj = get_filename_and_fobj(filename_or_fobj, mode='rb') + if dialect is None: cursor = fobj.tell() dialect = discover_dialect(fobj.read(sample_size), encoding) fobj.seek(cursor) - - reader = unicodecsv.reader(fobj, encoding=encoding, dialect=dialect) + else: + dialect = unicodecsv.get_dialect(dialect) + + reader_kwargs = { + 'doublequote': _pop_dialect_params(kwargs, dialect, 'doublequote'), + 'lineterminator': _pop_dialect_params(kwargs, dialect, + 'lineterminator'), + 'skipinitialspace': _pop_dialect_params(kwargs, dialect, + 'skipinitialspace'), + 'escapechar': _pop_dialect_params(kwargs, dialect, 'escapechar'), + 'delimiter': _pop_dialect_params(kwargs, dialect, 'delimiter'), + 'quotechar': _pop_dialect_params(kwargs, dialect, 'quotechar'), + 'quoting': _pop_dialect_params(kwargs, dialect, 'quoting'), + 'strict': kwargs.get('strict', False)} + + reader = unicodecsv.reader( + fobj, encoding=encoding, dialect=dialect, **reader_kwargs) meta = {'imported_from': 'csv', 'filename': filename, diff --git a/tests/tests_plugin_csv.py b/tests/tests_plugin_csv.py index 9430f3be..668f0451 100644 --- a/tests/tests_plugin_csv.py +++ b/tests/tests_plugin_csv.py @@ -112,6 +112,20 @@ def test_import_from_csv_force_dialect(self, mocked_create_table): call_args = mocked_create_table.call_args_list[0] self.assertEqual(data, list(call_args[0][0])) + @mock.patch('rows.plugins.plugin_csv.create_table') + def test_import_from_csv_implements_full_reader_signature(self, mocked_create_table): + data, lines = make_csv_data(quote_char='"', + field_delimiter="\t", + line_delimiter="\n") + fobj = BytesIO() + fobj.write(lines.encode('utf-8')) + fobj.seek(0) + + rows.import_from_csv(fobj, doublequote=True, delimiter='\t', lineterminator='\n') + call_args = mocked_create_table.call_args_list[0] + self.assertEqual(data, list(call_args[0][0])) + + def test_detect_dialect_more_data(self): temp = tempfile.NamedTemporaryFile(delete=False) filename = '{}.{}'.format(temp.name, self.file_extension) @@ -250,3 +264,4 @@ def test_export_to_csv_accepts_dialect(self): result_1 = rows.export_to_csv(utils.table, dialect=csv.excel_tab) result_2 = rows.export_to_csv(utils.table, dialect=csv.excel) self.assertEqual(result_1.replace(b'\t', b','), result_2) +