diff --git a/flattentool/__init__.py b/flattentool/__init__.py index 004afb8..803ef76 100644 --- a/flattentool/__init__.py +++ b/flattentool/__init__.py @@ -1,4 +1,5 @@ import codecs +import datetime import json import sys from collections import OrderedDict @@ -192,12 +193,14 @@ def __float__(self): return self -def decimal_default(o): +def decimal_datetime_default(o): if isinstance(o, Decimal): if int(o) == o: return int(o) else: return NumberStr(o) + if isinstance(o, datetime.datetime): + return str(o) raise TypeError(repr(o) + " is not JSON serializable") @@ -372,12 +375,18 @@ def unflatten( else: if output_name is None: print( - json.dumps(base, indent=4, default=decimal_default, ensure_ascii=False) + json.dumps( + base, indent=4, default=decimal_datetime_default, ensure_ascii=False + ) ) else: with codecs.open(output_name, "w", encoding="utf-8") as fp: json.dump( - base, fp, indent=4, default=decimal_default, ensure_ascii=False + base, + fp, + indent=4, + default=decimal_datetime_default, + ensure_ascii=False, ) if cell_source_map: with codecs.open(cell_source_map, "w", encoding="utf-8") as fp: @@ -385,7 +394,7 @@ def unflatten( cell_source_map_data, fp, indent=4, - default=decimal_default, + default=decimal_datetime_default, ensure_ascii=False, ) if heading_source_map: @@ -394,6 +403,6 @@ def unflatten( heading_source_map_data, fp, indent=4, - default=decimal_default, + default=decimal_datetime_default, ensure_ascii=False, ) diff --git a/flattentool/tests/test_init.py b/flattentool/tests/test_init.py index 511404d..0f382dc 100644 --- a/flattentool/tests/test_init.py +++ b/flattentool/tests/test_init.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals +import datetime import json from decimal import Decimal import pytest -from flattentool import decimal_default, unflatten +from flattentool import decimal_datetime_default, unflatten def original_cell_and_row_locations(data): @@ -51,9 +52,13 @@ def original_headings(heading_data): return headings -def test_decimal_default(): - assert json.dumps(Decimal("1.2"), default=decimal_default) == "1.2" - assert json.dumps(Decimal("42"), default=decimal_default) == "42" +def test_decimal_datetime_default(): + assert json.dumps(Decimal("1.2"), default=decimal_datetime_default) == "1.2" + assert json.dumps(Decimal("42"), default=decimal_datetime_default) == "42" + assert ( + json.dumps(datetime.datetime(2024, 1, 1), default=decimal_datetime_default) + == '"2024-01-01 00:00:00"' + ) def lines_strip_whitespace(text):