adds csv reader tests (#1197)
* adds csv reader tests

* does not translate new lines in text data writers

* fixes more tests
rudolfix authored Apr 8, 2024
1 parent 06e666e commit a290e89
Showing 16 changed files with 385 additions and 232 deletions.
2 changes: 1 addition & 1 deletion dlt/common/data_writers/buffered.py
@@ -207,7 +207,7 @@ def _flush_items(self, allow_empty_file: bool = False) -> None:
         if self.writer_spec.is_binary_format:
             self._file = self.open(self._file_name, "wb")  # type: ignore
         else:
-            self._file = self.open(self._file_name, "wt", encoding="utf-8")  # type: ignore
+            self._file = self.open(self._file_name, "wt", encoding="utf-8", newline="")  # type: ignore
         self._writer = self.writer_cls(self._file, caps=self._caps)  # type: ignore[assignment]
         self._writer.write_header(self._current_columns)
         # write buffer
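
The `newline=""` argument is what implements the "does not translate new lines" bullet of the commit message: the csv module writes its own `\r\n` row terminators, so the text layer must not translate newlines again. A minimal standalone sketch of the behavior this preserves (not part of the commit; io.StringIO stands in for the opened file):

import csv
import io

# newline="" mimics open(..., "wt", encoding="utf-8", newline=""): the text
# layer passes "\n" through untranslated, leaving line endings to the csv module
buf = io.StringIO(newline="")
writer = csv.writer(buf)
writer.writerow(["a", "line1\nline2"])  # a value with an embedded newline
# the embedded "\n" survives verbatim inside the quoted field and the row ends
# with the csv module's own "\r\n"; with newline translation enabled it could
# come out as "\r\r\n" and corrupt the embedded newline as well
assert buf.getvalue() == 'a,"line1\nline2"\r\n'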
36 changes: 35 additions & 1 deletion dlt/common/data_writers/writers.py
@@ -357,11 +357,16 @@ def writer_spec(cls) -> FileWriterSpec:

 class CsvWriter(DataWriter):
     def __init__(
-        self, f: IO[Any], caps: DestinationCapabilitiesContext = None, delimiter: str = ","
+        self,
+        f: IO[Any],
+        caps: DestinationCapabilitiesContext = None,
+        delimiter: str = ",",
+        bytes_encoding: str = "utf-8",
     ) -> None:
         super().__init__(f, caps)
         self.delimiter = delimiter
         self.writer: csv.DictWriter[str] = None
+        self.bytes_encoding = bytes_encoding

     def write_header(self, columns_schema: TTableSchemaColumns) -> None:
         self._columns_schema = columns_schema
@@ -374,8 +379,37 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
             quoting=csv.QUOTE_NONNUMERIC,
         )
         self.writer.writeheader()
         # find row items that are of the complex type (could be abstracted out for use in other writers?)
         self.complex_indices = [
             i for i, field in columns_schema.items() if field["data_type"] == "complex"
         ]
+        # find row items that are of the binary type (could be abstracted out for use in other writers?)
+        self.bytes_indices = [
+            i for i, field in columns_schema.items() if field["data_type"] == "binary"
+        ]

     def write_data(self, rows: Sequence[Any]) -> None:
+        # convert bytes and json
+        if self.complex_indices or self.bytes_indices:
+            for row in rows:
+                for key in self.complex_indices:
+                    if (value := row.get(key)) is not None:
+                        row[key] = json.dumps(value)
+                for key in self.bytes_indices:
+                    if (value := row.get(key)) is not None:
+                        # assumed to be a bytes value
+                        try:
+                            row[key] = value.decode(self.bytes_encoding)
+                        except UnicodeError:
+                            raise InvalidDataItem(
+                                "csv",
+                                "object",
+                                f"'{key}' contains bytes that cannot be decoded with"
+                                f" {self.bytes_encoding}. Remove binary columns or replace their"
+                                " content with a hex representation: \\x... while keeping data"
+                                " type as binary.",
+                            )
+
         self.writer.writerows(rows)
         # count rows that got written
         self.items_count += sum(len(row) for row in rows)
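
Taken together, write_header records which columns hold complex and binary data, and write_data converts those values before handing rows to csv.DictWriter. A standalone sketch of the resulting file content (column names and sample values are invented):

import csv
import io
import json

rows = [{"id": 1, "payload": {"a": [1, 2]}, "digest": b"\xc3\xa9"}]
complex_indices, bytes_indices = ["payload"], ["digest"]

for row in rows:
    for key in complex_indices:
        if (value := row.get(key)) is not None:
            row[key] = json.dumps(value)  # complex values become JSON strings
    for key in bytes_indices:
        if (value := row.get(key)) is not None:
            # non-UTF-8 bytes would raise UnicodeError here, which the writer
            # surfaces as InvalidDataItem
            row[key] = value.decode("utf-8")

buf = io.StringIO(newline="")
writer = csv.DictWriter(buf, fieldnames=["id", "payload", "digest"], quoting=csv.QUOTE_NONNUMERIC)
writer.writeheader()
writer.writerows(rows)
print(buf.getvalue())  # data row: 1,"{""a"": [1, 2]}","é"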
6 changes: 6 additions & 0 deletions docs/website/docs/dlt-ecosystem/file-formats/csv.md
@@ -33,6 +33,12 @@ info = pipeline.run(some_source(), loader_file_format="csv")
 * dates are represented as ISO 8601

 ## Limitations
+**arrow writer**
+
+* binary columns are supported only if they contain valid UTF-8 characters
+* complex (nested, struct) types are not supported
+
+**csv writer**
 * binary columns are supported only if they contain valid UTF-8 characters (easy to add more encodings)
 * complex columns are serialized with json.dumps
 * **None** values are always quoted
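
When a binary column is not valid UTF-8, the error message above suggests shipping a hex representation instead. A hedged sketch of that workaround using dlt's add_map transform (the resource, field names, and local filesystem path are invented for illustration):

import dlt

@dlt.resource
def events():
    # hypothetical rows with a non-UTF-8 binary column
    yield {"id": 1, "checksum": b"\xff\xfe\x01"}

def checksum_to_hex(row):
    # replace raw bytes with their hex string so the csv writer can encode them
    if isinstance(row.get("checksum"), bytes):
        row["checksum"] = row["checksum"].hex()
    return row

pipeline = dlt.pipeline(
    pipeline_name="csv_demo",
    destination=dlt.destinations.filesystem("file:///tmp/csv_demo"),
)
info = pipeline.run(events().add_map(checksum_to_hex), loader_file_format="csv")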
43 changes: 7 additions & 36 deletions tests/cases.py
@@ -19,9 +19,7 @@
 )
 from dlt.common.schema import TColumnSchema, TTableSchemaColumns

-
-TArrowFormat = Literal["pandas", "table", "record_batch"]
-
+from tests.utils import TArrowFormat, TestDataItemFormat, arrow_item_from_pandas

 # _UUID = "c8209ee7-ee95-4b90-8c9f-f7a0f8b51014"
 JSON_TYPED_DICT: StrAny = {
@@ -281,38 +279,8 @@ def assert_all_data_types_row(
     assert db_mapping == expected_rows


-def arrow_format_from_pandas(
-    df: Any,
-    object_format: TArrowFormat,
-) -> Any:
-    from dlt.common.libs.pyarrow import pyarrow as pa
-
-    if object_format == "pandas":
-        return df
-    elif object_format == "table":
-        return pa.Table.from_pandas(df)
-    elif object_format == "record_batch":
-        return pa.RecordBatch.from_pandas(df)
-    raise ValueError("Unknown item type: " + object_format)
-
-
-def arrow_item_from_table(
-    table: Any,
-    object_format: TArrowFormat,
-) -> Any:
-    from dlt.common.libs.pyarrow import pyarrow as pa
-
-    if object_format == "pandas":
-        return table.to_pandas()
-    elif object_format == "table":
-        return table
-    elif object_format == "record_batch":
-        return table.to_batches()[0]
-    raise ValueError("Unknown item type: " + object_format)
-
-
 def arrow_table_all_data_types(
-    object_format: TArrowFormat,
+    object_format: TestDataItemFormat,
     include_json: bool = True,
     include_time: bool = True,
     include_binary: bool = True,
@@ -379,15 +347,18 @@ def arrow_table_all_data_types(
         .drop(columns=["null"])
         .to_dict("records")
     )
-    return arrow_format_from_pandas(df, object_format), rows, data
+    if object_format == "object":
+        return rows, rows, data
+    else:
+        return arrow_item_from_pandas(df, object_format), rows, data


 def prepare_shuffled_tables() -> Tuple[Any, Any, Any]:
     from dlt.common.libs.pyarrow import remove_columns
     from dlt.common.libs.pyarrow import pyarrow as pa

     table, _, _ = arrow_table_all_data_types(
-        "table",
+        "arrow-table",
         include_json=False,
         include_not_normalized_name=False,
         tz="Europe/Berlin",
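
The two deleted helpers are replaced by a single arrow_item_from_pandas imported from tests.utils, which this page does not show. A plausible reconstruction under the renamed formats (an assumption pieced together from the deleted arrow_format_from_pandas, not the actual tests/utils.py code):

from typing import Any

def arrow_item_from_pandas(df: Any, object_format: str) -> Any:
    from dlt.common.libs.pyarrow import pyarrow as pa

    if object_format == "pandas":
        return df
    elif object_format == "arrow-table":
        return pa.Table.from_pandas(df)
    elif object_format == "arrow-batch":
        return pa.RecordBatch.from_pandas(df)
    raise ValueError("Unknown item type: " + object_format)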
12 changes: 6 additions & 6 deletions tests/extract/test_incremental.py
@@ -213,7 +213,7 @@ def some_data(created_at=dlt.sources.incremental("data.items[0].created_at")):
     assert s["last_value"] == 2


-@pytest.mark.parametrize("item_type", ["arrow", "pandas"])
+@pytest.mark.parametrize("item_type", ["arrow-table", "pandas"])
 def test_nested_cursor_path_arrow_fails(item_type: TestDataItemFormat) -> None:
     data = [{"data": {"items": [{"created_at": 2}]}}]
     source_items = data_to_item_format(item_type, data)
@@ -708,7 +708,7 @@ def some_data(step, last_timestamp=dlt.sources.incremental("ts")):
     p.run(r, destination="duckdb")


-@pytest.mark.parametrize("item_type", set(ALL_TEST_DATA_ITEM_FORMATS) - {"json"})
+@pytest.mark.parametrize("item_type", set(ALL_TEST_DATA_ITEM_FORMATS) - {"object"})
 def test_start_value_set_to_last_value_arrow(item_type: TestDataItemFormat) -> None:
     p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb")
     now = pendulum.now()
@@ -1047,7 +1047,7 @@ def some_data(
     resource.apply_hints(incremental=dlt.sources.incremental("updated_at", initial_value=start_dt))
     # and the data is naive. so it will work as expected with naive datetimes in the result set
     data = list(resource)
-    if item_type == "json":
+    if item_type == "object":
         # we do not convert data in arrow tables
         assert data[0]["updated_at"].tzinfo is None

@@ -1059,7 +1059,7 @@
         )
     )
     data = list(resource)
-    if item_type == "json":
+    if item_type == "object":
         assert data[0]["updated_at"].tzinfo is None

     # now use naive initial value but data is UTC
@@ -1070,7 +1070,7 @@
         )
     )
     # will cause invalid comparison
-    if item_type == "json":
+    if item_type == "object":
         with pytest.raises(InvalidStepFunctionArguments):
             list(resource)
     else:
@@ -1392,7 +1392,7 @@ def descending(
     for chunk in chunks(count(start=48, step=-1), 10):
         data = [{"updated_at": i, "package": package} for i in chunk]
         # print(data)
-        yield data_to_item_format("json", data)
+        yield data_to_item_format("object", data)
         if updated_at.can_close():
             out_of_range.append(package)
             return
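
These renames track the new TestDataItemFormat names: "json" becomes "object" and "arrow" becomes "arrow-table". A hypothetical sketch of data_to_item_format under those names (the real helper lives in tests/utils.py, which this page does not show):

from typing import Any, List

def data_to_item_format(item_format: str, data: List[Any]) -> Any:
    # plain Python objects pass through unchanged
    if item_format == "object":
        return data

    import pandas as pd
    from dlt.common.libs.pyarrow import pyarrow as pa

    df = pd.DataFrame(data)
    if item_format == "pandas":
        return [df]
    elif item_format == "arrow-table":
        return [pa.Table.from_pandas(df)]
    elif item_format == "arrow-batch":
        return [pa.RecordBatch.from_pandas(df)]
    raise ValueError("Unknown item format: " + item_format)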
4 changes: 2 additions & 2 deletions tests/extract/utils.py
@@ -46,7 +46,7 @@ def expect_extracted_file(


 class AssertItems(ItemTransform[TDataItem]):
-    def __init__(self, expected_items: Any, item_type: TestDataItemFormat = "json") -> None:
+    def __init__(self, expected_items: Any, item_type: TestDataItemFormat = "object") -> None:
         self.expected_items = expected_items
         self.item_type = item_type

@@ -56,7 +56,7 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]:


 def data_item_to_list(from_type: TestDataItemFormat, values: List[TDataItem]):
-    if from_type in ["arrow", "arrow-batch"]:
+    if from_type in ["arrow-table", "arrow-batch"]:
         return values[0].to_pylist()
     elif from_type == "pandas":
         return values[0].to_dict("records")
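
Building on the data_to_item_format sketch above, a quick invented round trip shows why the format names must agree across helpers: rows produced in each format normalize back to the same list of dicts.

data = [{"id": 1}, {"id": 2}]
for fmt in ["pandas", "arrow-table", "arrow-batch"]:
    items = data_to_item_format(fmt, data)
    assert data_item_to_list(fmt, items) == data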
128 changes: 0 additions & 128 deletions tests/libs/test_arrow_csv_writer.py

This file was deleted.
