Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds csv reader tests #1197

Merged
merged 4 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dlt/common/data_writers/buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _flush_items(self, allow_empty_file: bool = False) -> None:
if self.writer_spec.is_binary_format:
self._file = self.open(self._file_name, "wb") # type: ignore
else:
self._file = self.open(self._file_name, "wt", encoding="utf-8") # type: ignore
self._file = self.open(self._file_name, "wt", encoding="utf-8", newline="") # type: ignore
self._writer = self.writer_cls(self._file, caps=self._caps) # type: ignore[assignment]
self._writer.write_header(self._current_columns)
# write buffer
Expand Down
36 changes: 35 additions & 1 deletion dlt/common/data_writers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,16 @@ def writer_spec(cls) -> FileWriterSpec:

class CsvWriter(DataWriter):
def __init__(
    self,
    f: IO[Any],
    caps: DestinationCapabilitiesContext = None,
    delimiter: str = ",",
    bytes_encoding: str = "utf-8",
) -> None:
    """Create a csv writer over the already-open text stream *f*.

    Args:
        f: open text-mode file object the csv rows are written to.
        caps: destination capabilities passed to the base ``DataWriter``.
        delimiter: field separator used by the underlying ``csv.DictWriter``.
        bytes_encoding: encoding used to decode binary column values to text
            before they are written (see ``write_data``).
    """
    super().__init__(f, caps)
    self.delimiter = delimiter
    # the DictWriter is created lazily in write_header, once columns are known
    self.writer: csv.DictWriter[str] = None
    self.bytes_encoding = bytes_encoding

def write_header(self, columns_schema: TTableSchemaColumns) -> None:
self._columns_schema = columns_schema
Expand All @@ -374,8 +379,37 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
quoting=csv.QUOTE_NONNUMERIC,
)
self.writer.writeheader()
# find row items that are of the complex type (could be abstracted out for use in other writers?)
self.complex_indices = [
i for i, field in columns_schema.items() if field["data_type"] == "complex"
]
# find row items that are of the binary type (could be abstracted out for use in other writers?)
self.bytes_indices = [
i for i, field in columns_schema.items() if field["data_type"] == "binary"
]

def write_data(self, rows: Sequence[Any]) -> None:
    """Write a batch of dict rows through the underlying ``csv.DictWriter``.

    Complex (nested) values are serialized to JSON strings and binary values
    are decoded to text in place, using the column indices computed in
    ``write_header``, before the rows are handed to ``writerows``.

    Args:
        rows: sequence of dict rows keyed by column name.

    Raises:
        InvalidDataItem: when a binary value cannot be decoded with
            ``self.bytes_encoding``.
    """
    # convert bytes and json
    if self.complex_indices or self.bytes_indices:
        for row in rows:
            for key in self.complex_indices:
                if (value := row.get(key)) is not None:
                    row[key] = json.dumps(value)
            for key in self.bytes_indices:
                if (value := row.get(key)) is not None:
                    # assumed bytes value
                    try:
                        row[key] = value.decode(self.bytes_encoding)
                    except UnicodeError:
                        raise InvalidDataItem(
                            "csv",
                            "object",
                            f"'{key}' contains bytes that cannot be decoded with"
                            f" {self.bytes_encoding}. Remove binary columns or replace their"
                            " content with a hex representation: \\x... while keeping data"
                            " type as binary.",
                        )

    self.writer.writerows(rows)
    # count rows that got written; `sum(len(row) ...)` would count fields per
    # row, contradicting this intent and the accounting of sibling writers
    self.items_count += len(rows)
Expand Down
6 changes: 6 additions & 0 deletions docs/website/docs/dlt-ecosystem/file-formats/csv.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ info = pipeline.run(some_source(), loader_file_format="csv")
* dates are represented as ISO 8601

## Limitations
**arrow writer**

* binary columns are supported only if they contain valid UTF-8 characters
* complex (nested, struct) types are not supported

**csv writer**
* binary columns are supported only if they contain valid UTF-8 characters (easy to add more encodings)
* complex (nested) columns are serialized to JSON strings with `json.dumps`
* **None** values are always quoted
43 changes: 7 additions & 36 deletions tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
)
from dlt.common.schema import TColumnSchema, TTableSchemaColumns


TArrowFormat = Literal["pandas", "table", "record_batch"]

from tests.utils import TArrowFormat, TestDataItemFormat, arrow_item_from_pandas

# _UUID = "c8209ee7-ee95-4b90-8c9f-f7a0f8b51014"
JSON_TYPED_DICT: StrAny = {
Expand Down Expand Up @@ -281,38 +279,8 @@ def assert_all_data_types_row(
assert db_mapping == expected_rows


def arrow_format_from_pandas(
    df: Any,
    object_format: TArrowFormat,
) -> Any:
    """Convert a pandas DataFrame *df* into the requested arrow item format."""
    from dlt.common.libs.pyarrow import pyarrow as pa

    # dispatch on the requested format; lambdas defer the (possibly costly) conversion
    converters = {
        "pandas": lambda: df,
        "table": lambda: pa.Table.from_pandas(df),
        "record_batch": lambda: pa.RecordBatch.from_pandas(df),
    }
    if object_format not in converters:
        raise ValueError("Unknown item type: " + object_format)
    return converters[object_format]()


def arrow_item_from_table(
    table: Any,
    object_format: TArrowFormat,
) -> Any:
    """Convert an arrow Table into the requested item format."""
    from dlt.common.libs.pyarrow import pyarrow as pa  # noqa: F401

    # dispatch on the requested format; lambdas defer the conversion work
    converters = {
        "pandas": lambda: table.to_pandas(),
        "table": lambda: table,
        "record_batch": lambda: table.to_batches()[0],
    }
    if object_format not in converters:
        raise ValueError("Unknown item type: " + object_format)
    return converters[object_format]()


def arrow_table_all_data_types(
object_format: TArrowFormat,
object_format: TestDataItemFormat,
include_json: bool = True,
include_time: bool = True,
include_binary: bool = True,
Expand Down Expand Up @@ -379,15 +347,18 @@ def arrow_table_all_data_types(
.drop(columns=["null"])
.to_dict("records")
)
return arrow_format_from_pandas(df, object_format), rows, data
if object_format == "object":
return rows, rows, data
else:
return arrow_item_from_pandas(df, object_format), rows, data


def prepare_shuffled_tables() -> Tuple[Any, Any, Any]:
from dlt.common.libs.pyarrow import remove_columns
from dlt.common.libs.pyarrow import pyarrow as pa

table, _, _ = arrow_table_all_data_types(
"table",
"arrow-table",
include_json=False,
include_not_normalized_name=False,
tz="Europe/Berlin",
Expand Down
12 changes: 6 additions & 6 deletions tests/extract/test_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def some_data(created_at=dlt.sources.incremental("data.items[0].created_at")):
assert s["last_value"] == 2


@pytest.mark.parametrize("item_type", ["arrow", "pandas"])
@pytest.mark.parametrize("item_type", ["arrow-table", "pandas"])
def test_nested_cursor_path_arrow_fails(item_type: TestDataItemFormat) -> None:
data = [{"data": {"items": [{"created_at": 2}]}}]
source_items = data_to_item_format(item_type, data)
Expand Down Expand Up @@ -708,7 +708,7 @@ def some_data(step, last_timestamp=dlt.sources.incremental("ts")):
p.run(r, destination="duckdb")


@pytest.mark.parametrize("item_type", set(ALL_TEST_DATA_ITEM_FORMATS) - {"json"})
@pytest.mark.parametrize("item_type", set(ALL_TEST_DATA_ITEM_FORMATS) - {"object"})
def test_start_value_set_to_last_value_arrow(item_type: TestDataItemFormat) -> None:
p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb")
now = pendulum.now()
Expand Down Expand Up @@ -1047,7 +1047,7 @@ def some_data(
resource.apply_hints(incremental=dlt.sources.incremental("updated_at", initial_value=start_dt))
# and the data is naive. so it will work as expected with naive datetimes in the result set
data = list(resource)
if item_type == "json":
if item_type == "object":
# we do not convert data in arrow tables
assert data[0]["updated_at"].tzinfo is None

Expand All @@ -1059,7 +1059,7 @@ def some_data(
)
)
data = list(resource)
if item_type == "json":
if item_type == "object":
assert data[0]["updated_at"].tzinfo is None

# now use naive initial value but data is UTC
Expand All @@ -1070,7 +1070,7 @@ def some_data(
)
)
# will cause invalid comparison
if item_type == "json":
if item_type == "object":
with pytest.raises(InvalidStepFunctionArguments):
list(resource)
else:
Expand Down Expand Up @@ -1392,7 +1392,7 @@ def descending(
for chunk in chunks(count(start=48, step=-1), 10):
data = [{"updated_at": i, "package": package} for i in chunk]
# print(data)
yield data_to_item_format("json", data)
yield data_to_item_format("object", data)
if updated_at.can_close():
out_of_range.append(package)
return
Expand Down
4 changes: 2 additions & 2 deletions tests/extract/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def expect_extracted_file(


class AssertItems(ItemTransform[TDataItem]):
def __init__(self, expected_items: Any, item_type: TestDataItemFormat = "json") -> None:
def __init__(self, expected_items: Any, item_type: TestDataItemFormat = "object") -> None:
self.expected_items = expected_items
self.item_type = item_type

Expand All @@ -56,7 +56,7 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]:


def data_item_to_list(from_type: TestDataItemFormat, values: List[TDataItem]):
if from_type in ["arrow", "arrow-batch"]:
if from_type in ["arrow-table", "arrow-batch"]:
return values[0].to_pylist()
elif from_type == "pandas":
return values[0].to_dict("records")
Expand Down
128 changes: 0 additions & 128 deletions tests/libs/test_arrow_csv_writer.py

This file was deleted.

Loading
Loading