Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complete support for csv.Reader class signature. #232

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions rows/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def serialize(cls, value, *args, **kwargs):
return value

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
"""Deserialize a value just after importing it

`cls.deserialize` should always return a value of type `cls.TYPE` or
Expand Down Expand Up @@ -100,7 +100,7 @@ def serialize(cls, value, *args, **kwargs):
return ''

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
if value is not None:
if isinstance(value, six.binary_type):
return value
Expand Down Expand Up @@ -133,7 +133,7 @@ def serialize(cls, value, *args, **kwargs):
return cls.SERIALIZED_VALUES[value]

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
value = super(BoolField, cls).deserialize(value)
if value is None or isinstance(value, cls.TYPE):
return value
Expand Down Expand Up @@ -167,7 +167,7 @@ def serialize(cls, value, *args, **kwargs):
return locale.format('%d', value, grouping=grouping)

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
value = super(IntegerField, cls).deserialize(value)
if value is None or isinstance(value, cls.TYPE):
return value
Expand Down Expand Up @@ -203,7 +203,7 @@ def serialize(cls, value, *args, **kwargs):
return locale.format('%f', value, grouping=grouping)

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
value = super(FloatField, cls).deserialize(value)
if value is None or isinstance(value, cls.TYPE):
return value
Expand Down Expand Up @@ -242,7 +242,7 @@ def serialize(cls, value, *args, **kwargs):
return locale.format(string_format, value, grouping=grouping)

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
value = super(DecimalField, cls).deserialize(value)
if value is None or isinstance(value, cls.TYPE):
return value
Expand Down Expand Up @@ -300,7 +300,7 @@ def serialize(cls, value, *args, **kwargs):
return '{}%'.format(value)

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
if isinstance(value, cls.TYPE):
return value
elif is_null(value):
Expand Down Expand Up @@ -331,7 +331,7 @@ def serialize(cls, value, *args, **kwargs):
return six.text_type(value.strftime(cls.OUTPUT_FORMAT))

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
value = super(DateField, cls).deserialize(value)
if value is None or isinstance(value, cls.TYPE):
return value
Expand Down Expand Up @@ -360,7 +360,7 @@ def serialize(cls, value, *args, **kwargs):
return six.text_type(value.isoformat())

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
value = super(DatetimeField, cls).deserialize(value)
if value is None or isinstance(value, cls.TYPE):
return value
Expand All @@ -383,7 +383,7 @@ class TextField(Field):
TYPE = (six.text_type, )

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
if value is None or isinstance(value, cls.TYPE):
return value
else:
Expand All @@ -407,7 +407,7 @@ def serialize(cls, value, *args, **kwargs):
return six.text_type(value)

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
value = super(EmailField, cls).deserialize(value)
if value is None or not value.strip():
return None
Expand All @@ -432,7 +432,7 @@ def serialize(cls, value, *args, **kwargs):
return json.dumps(value)

@classmethod
def deserialize(cls, value, *args, **kwargs):
def deserialize(cls, value):
if value is None or isinstance(value, cls.TYPE):
return value
else:
Expand Down Expand Up @@ -472,8 +472,7 @@ def unique_values(values):
return result


def detect_types(field_names, field_values, field_types=AVAILABLE_FIELD_TYPES,
*args, **kwargs):
def detect_types(field_names, field_values, field_types=AVAILABLE_FIELD_TYPES):
"""Where the magic happens"""

# TODO: look strategy of csv.Sniffer.has_header
Expand Down Expand Up @@ -512,7 +511,7 @@ def detect_types(field_names, field_values, field_types=AVAILABLE_FIELD_TYPES,
cant_be = set()
for type_ in possible_types:
try:
type_.deserialize(value, *args, **kwargs)
type_.deserialize(value)
except (ValueError, TypeError):
cant_be.add(type_)
for type_to_remove in cant_be:
Expand Down
22 changes: 20 additions & 2 deletions rows/plugins/plugin_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def discover_dialect(sample, encoding):
# Could not detect dialect, fall back to 'excel'
return unicodecsv.excel

def _pop_dialect_params(params, dialect, parameter):
return params.pop(parameter, getattr(dialect, parameter))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any special reason for popping the parameter? Should we only `get` it, to avoid mutating kwargs? The downside of `pop` is that someone trying to do something else with the kwargs entries after get_dialect_parameters would find out that they are not there anymore.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for popping from kwargs is to avoid interfering with later calls that use kwargs in the same function.
Using *args and **kwargs is always a decision about where the code will get messy.


def import_from_csv(filename_or_fobj, encoding='utf-8', dialect=None,
sample_size=8192, *args, **kwargs):
Expand All @@ -58,12 +60,28 @@ def import_from_csv(filename_or_fobj, encoding='utf-8', dialect=None,

filename, fobj = get_filename_and_fobj(filename_or_fobj, mode='rb')


if dialect is None:
cursor = fobj.tell()
dialect = discover_dialect(fobj.read(sample_size), encoding)
fobj.seek(cursor)

reader = unicodecsv.reader(fobj, encoding=encoding, dialect=dialect)
else:
dialect = unicodecsv.get_dialect(dialect)

reader_kwargs = {
'doublequote': _pop_dialect_params(kwargs, dialect, 'doublequote'),
'lineterminator': _pop_dialect_params(kwargs, dialect,
'lineterminator'),
'skipinitialspace': _pop_dialect_params(kwargs, dialect,
'skipinitialspace'),
'escapechar': _pop_dialect_params(kwargs, dialect, 'escapechar'),
'delimiter': _pop_dialect_params(kwargs, dialect, 'delimiter'),
'quotechar': _pop_dialect_params(kwargs, dialect, 'quotechar'),
'quoting': _pop_dialect_params(kwargs, dialect, 'quoting'),
'strict': kwargs.get('strict', False)}

reader = unicodecsv.reader(
fobj, encoding=encoding, dialect=dialect, **reader_kwargs)

meta = {'imported_from': 'csv',
'filename': filename,
Expand Down
15 changes: 15 additions & 0 deletions tests/tests_plugin_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ def test_import_from_csv_force_dialect(self, mocked_create_table):
call_args = mocked_create_table.call_args_list[0]
self.assertEqual(data, list(call_args[0][0]))

@mock.patch('rows.plugins.plugin_csv.create_table')
def test_import_from_csv_implements_full_reader_signature(
        self, mocked_create_table):
    """csv.reader formatting keywords are accepted by import_from_csv."""
    expected_rows, csv_text = make_csv_data(line_delimiter="\n",
                                            field_delimiter="\t",
                                            quote_char='"')
    # Build an in-memory binary file already positioned at its start.
    fobj = BytesIO(csv_text.encode('utf-8'))

    rows.import_from_csv(fobj, lineterminator='\n', delimiter='\t',
                         doublequote=True)

    # The first positional argument given to create_table is the row data.
    first_call = mocked_create_table.call_args_list[0]
    positional_args = first_call[0]
    self.assertEqual(expected_rows, list(positional_args[0]))


def test_detect_dialect_more_data(self):
temp = tempfile.NamedTemporaryFile(delete=False)
filename = '{}.{}'.format(temp.name, self.file_extension)
Expand Down Expand Up @@ -250,3 +264,4 @@ def test_export_to_csv_accepts_dialect(self):
result_1 = rows.export_to_csv(utils.table, dialect=csv.excel_tab)
result_2 = rows.export_to_csv(utils.table, dialect=csv.excel)
self.assertEqual(result_1.replace(b'\t', b','), result_2)