
Use "arrow" format instead of "parquet", tests
steinitzu committed Oct 23, 2023
1 parent 0d758fc commit 935c4c8
Showing 5 changed files with 76 additions and 30 deletions.
3 changes: 2 additions & 1 deletion dlt/common/storages/load_storage.py
@@ -426,7 +426,8 @@ def build_job_file_name(self, table_name: str, file_id: str, retry_count: int =
         # FileStorage.validate_file_name_component(file_id)
         fn = f"{table_name}.{file_id}.{int(retry_count)}"
         if with_extension:
-            return fn + f".{self.loader_file_format}"
+            format_spec = DataWriter.data_format_from_file_format(self.loader_file_format)
+            return fn + f".{format_spec.file_extension}"
         return fn
 
     @staticmethod
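Note on the load_storage.py change: the job file extension now comes from the data writer spec resolved for the configured loader file format (via DataWriter.data_format_from_file_format) rather than from the format name itself, so a logical format such as "arrow" can keep writing files with the extension its writer actually produces. A minimal sketch of the idea, using a hypothetical spec table; the arrow-to-parquet extension mapping is an assumption based on the tests below, which read the normalized job with pyarrow.parquet:

from dataclasses import dataclass


@dataclass
class WriterSpec:
    file_format: str      # logical loader file format, e.g. "arrow"
    file_extension: str   # extension the writer actually puts on disk


# Hypothetical mapping for illustration only; dlt resolves this through
# DataWriter.data_format_from_file_format().
SPECS = {
    "jsonl": WriterSpec("jsonl", "jsonl"),
    "parquet": WriterSpec("parquet", "parquet"),
    "arrow": WriterSpec("arrow", "parquet"),  # assumed: arrow items land in parquet files
}


def build_job_file_name(table_name: str, file_id: str, retry_count: int, loader_file_format: str) -> str:
    # Mirrors the patched logic: take the extension from the writer spec,
    # not from the loader file format name.
    fn = f"{table_name}.{file_id}.{int(retry_count)}"
    return f"{fn}.{SPECS[loader_file_format].file_extension}"


assert build_job_file_name("some_data", "abc123", 0, "arrow").endswith(".parquet")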
4 changes: 2 additions & 2 deletions dlt/normalize/items_normalizers.py
@@ -145,7 +145,7 @@ def _write_with_dlt_columns(
                 lambda batch: (uniq_id_base64(10) for _ in range(batch.num_rows))
             ))
         items_count = 0
-        as_py = load_storage.loader_file_format != "parquet"
+        as_py = load_storage.loader_file_format != "arrow"
         with normalize_storage.storage.open_file(extracted_items_file, "rb") as f:
             for batch in pyarrow.pq_stream_with_new_columns(f, new_columns, batch_size=self.RECORD_BATCH_SIZE):
                 items_count += batch.num_rows
@@ -173,7 +173,7 @@ def __call__(
     ) -> Tuple[List[TSchemaUpdate], int, TRowCount]:
         import pyarrow as pa
 
-        if config.parquet_add_dlt_id or config.parquet_add_dlt_load_id or load_storage.loader_file_format != "parquet":
+        if config.parquet_add_dlt_id or config.parquet_add_dlt_load_id or load_storage.loader_file_format != "arrow":
             items_count = self._write_with_dlt_columns(
                 extracted_items_file,
                 normalize_storage,
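The as_py flag shown above controls whether record batches are turned into plain Python rows before they reach the item writer. A rough standalone sketch of that branching follows; the writer object and its write_rows/write_batch methods are hypothetical stand-ins, not dlt's buffered writer API:

from typing import Iterable

import pyarrow as pa


def write_batches(batches: Iterable[pa.RecordBatch], writer, loader_file_format: str) -> int:
    # For the "arrow" storage format the batches are passed through unchanged;
    # for any other format they are converted to Python rows so a row-oriented
    # writer (e.g. jsonl) can serialize them.
    as_py = loader_file_format != "arrow"
    items_count = 0
    for batch in batches:
        items_count += batch.num_rows
        if as_py:
            writer.write_rows(batch.to_pylist())  # hypothetical row-based writer call
        else:
            writer.write_batch(batch)  # hypothetical arrow-native writer call
    return items_count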
16 changes: 10 additions & 6 deletions dlt/normalize/normalize.py
@@ -84,15 +84,19 @@ def w_normalize_files(
     load_storages: Dict[TLoaderFileFormat, LoadStorage] = {}
 
     def _get_load_storage(file_format: TLoaderFileFormat) -> LoadStorage:
-        if file_format != "parquet":
-            file_format = destination_caps.preferred_loader_file_format or destination_caps.preferred_staging_file_format
-        if storage := load_storages.get(file_format):
-            return storage
         # TODO: capabilities.supported_*_formats can be None, it should have defaults
         supported_formats = list(set(destination_caps.supported_loader_file_formats or []) | set(destination_caps.supported_staging_file_formats or []))
-        if file_format == "parquet" and file_format not in supported_formats:
-            # Use default storage if parquet is not supported to make normalizer fallback to read rows from the file
+        if file_format == "parquet":
+            if file_format in supported_formats:
+                supported_formats.append("arrow")  # TODO: Hack to make load storage use the correct writer
+                file_format = "arrow"
+            else:
+                # Use default storage if parquet is not supported to make normalizer fallback to read rows from the file
+                file_format = destination_caps.preferred_loader_file_format or destination_caps.preferred_staging_file_format
+        else:
             file_format = destination_caps.preferred_loader_file_format or destination_caps.preferred_staging_file_format
+        if storage := load_storages.get(file_format):
+            return storage
         storage = load_storages[file_format] = LoadStorage(False, file_format, supported_formats, loader_storage_config)
         return storage
 
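Restated as a small pure function for readability (illustrative only, not dlt's API): parquet-format extractions are normalized through the "arrow" storage when the destination supports parquet, and everything else falls back to the destination's preferred loader file format, which makes the normalizer read rows out of the file.

from typing import List, Tuple


def resolve_storage_format(
    file_format: str, supported_formats: List[str], preferred_format: str
) -> Tuple[str, List[str]]:
    # Sketch of the branch added above; names and return shape are assumptions.
    supported = list(supported_formats)
    if file_format == "parquet":
        if "parquet" in supported:
            supported.append("arrow")  # let the storage pick the arrow writer
            return "arrow", supported
        # destination cannot load parquet: fall back to row-based normalization
        return preferred_format, supported
    return preferred_format, supported


assert resolve_storage_format("parquet", ["parquet", "jsonl"], "jsonl")[0] == "arrow"
assert resolve_storage_format("parquet", ["jsonl"], "jsonl")[0] == "jsonl"
assert resolve_storage_format("jsonl", ["parquet", "jsonl"], "jsonl")[0] == "jsonl"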
2 changes: 1 addition & 1 deletion tests/cases.py
@@ -354,7 +354,7 @@ def arrow_table_all_data_types(object_format: TArrowFormat, include_json: bool =
     }
 
     if include_json:
-        data["json"] = [{random.choice(ascii_lowercase): random.randrange(0, 100)} for _ in range(num_rows)]
+        data["json"] = [{"a": random.randrange(0, 100)} for _ in range(num_rows)]
 
     if include_time:
         data["time"] = pd.date_range("2021-01-01", periods=num_rows, tz="UTC").time
81 changes: 61 additions & 20 deletions tests/pipeline/test_arrow_sources.py
@@ -14,6 +14,7 @@
 from tests.cases import arrow_table_all_data_types, TArrowFormat
 from tests.utils import preserve_environ
 from dlt.common.storages import LoadStorage
+from dlt.common import json
 
 
 
@@ -49,7 +50,8 @@ def some_data():
     load_id = pipeline.list_normalized_load_packages()[0]
     storage = pipeline._get_load_storage()
     jobs = storage.list_new_jobs(load_id)
-    with storage.storage.open_file(jobs[0], 'rb') as f:
+    job = [j for j in jobs if "some_data" in j][0]
+    with storage.storage.open_file(job, 'rb') as f:
         normalized_bytes = f.read()
 
         # Normalized is linked/copied exactly and should be the same as the extracted file
@@ -86,21 +88,44 @@ def some_data():
     assert schema_columns['json']['data_type'] == 'complex'
 
 
-# @pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"])
-# def test_normalize_unsupported_loader_format(item_type: TArrowFormat):
-#     item, _ = arrow_table_all_data_types(item_type)
+@pytest.mark.parametrize(
+    ("item_type", "is_list"), [("pandas", False), ("table", False), ("record_batch", False), ("pandas", True), ("table", True), ("record_batch", True)]
+)
+def test_normalize_jsonl(item_type: TArrowFormat, is_list: bool):
+    os.environ['DUMMY__LOADER_FILE_FORMAT'] = "jsonl"
 
-#     pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="dummy")
+    item, records = arrow_table_all_data_types(item_type)
 
-#     @dlt.resource
-#     def some_data():
-#         yield item
+    pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="dummy")
 
-#     pipeline.extract(some_data())
-#     with pytest.raises(PipelineStepFailed) as py_ex:
-#         pipeline.normalize()
+    @dlt.resource
+    def some_data():
+        if is_list:
+            yield [item]
+        else:
+            yield item
 
-#     assert "The destination doesn't support direct loading of arrow tables" in str(py_ex.value)
+    pipeline.extract(some_data())
+    pipeline.normalize()
+
+    load_id = pipeline.list_normalized_load_packages()[0]
+    storage = pipeline._get_load_storage()
+    jobs = storage.list_new_jobs(load_id)
+    job = [j for j in jobs if "some_data" in j][0]
+    with storage.storage.open_file(job, 'r') as f:
+        result = [json.loads(line) for line in f]
+    for row in result:
+        row['decimal'] = Decimal(row['decimal'])
+
+    for record in records:
+        record['datetime'] = record['datetime'].replace(tzinfo=None)
+
+    expected = json.loads(json.dumps(records))
+    for record in expected:
+        record['decimal'] = Decimal(record['decimal'])
+    assert result == expected
 
 
 @pytest.mark.parametrize("item_type", ["table", "record_batch"])
@@ -125,13 +150,13 @@ def map_func(item):
     assert pa.compute.all(pa.compute.greater(result_tbl['int'], 80)).as_py()
 
 
-@pytest.mark.parametrize("item_type", ["table"])
+@pytest.mark.parametrize("item_type", ["table", "pandas", "record_batch"])
 def test_normalize_with_dlt_columns(item_type: TArrowFormat):
-    item, _ = arrow_table_all_data_types(item_type, num_rows=1234)
+    item, records = arrow_table_all_data_types(item_type, num_rows=5432)
     os.environ['NORMALIZE__PARQUET_ADD_DLT_LOAD_ID'] = "True"
     os.environ['NORMALIZE__PARQUET_ADD_DLT_ID'] = "True"
-    # Make sure everything works table is larger than buffer size
-    os.environ['DATA_WRITER__BUFFER_MAX_ITEMS'] = "50"
+    # Test with buffer smaller than the number of batches to be written
+    os.environ['DATA_WRITER__BUFFER_MAX_ITEMS'] = "4"
 
     @dlt.resource
     def some_data():
@@ -145,8 +170,24 @@ def some_data():
     load_id = pipeline.list_normalized_load_packages()[0]
     storage = pipeline._get_load_storage()
    jobs = storage.list_new_jobs(load_id)
-    with storage.storage.open_file(jobs[0], 'rb') as f:
-        normalized_bytes = f.read()
-        tbl = pa.parquet.read_table(f)
+    job = [j for j in jobs if "some_data" in j][0]
+    with storage.storage.open_file(job, 'rb') as f:
+        pq = pa.parquet.ParquetFile(f)
+        tbl = pq.read()
+    assert len(tbl) == 5432
 
+    # Test one column matches source data
+    assert tbl['string'].to_pylist() == [r['string'] for r in records]
+
+    assert pa.compute.all(pa.compute.equal(tbl['_dlt_load_id'], load_id)).as_py()
+
+    all_ids = tbl['_dlt_id'].to_pylist()
+    assert len(all_ids[0]) >= 14
+
+    # All ids are unique
+    assert len(all_ids) == len(set(all_ids))
+
+    # _dlt_id and _dlt_load_id are added to pipeline schema
+    schema = pipeline.default_schema
+    assert schema.tables['some_data']['columns']['_dlt_id']['data_type'] == 'text'
+    assert schema.tables['some_data']['columns']['_dlt_load_id']['data_type'] == 'text'
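A side note on the len(all_ids[0]) >= 14 assertion above: the _dlt_id values are produced by uniq_id_base64(10) in the normalizer change, and base64-encoding 10 random bytes without padding yields 14 characters. A sketch of that arithmetic, assuming uniq_id_base64 is URL-safe base64 over random bytes (an assumption about its implementation):

import base64
import os


def uniq_id_base64_sketch(len_: int = 10) -> str:
    # Assumed implementation: URL-safe base64 of `len_` random bytes, padding stripped.
    return base64.urlsafe_b64encode(os.urandom(len_)).decode("ascii").rstrip("=")


# 10 bytes -> ceil(10 / 3) * 4 = 16 base64 chars including "==" padding -> 14 without it,
# which is why the test checks len(all_ids[0]) >= 14.
assert len(uniq_id_base64_sketch(10)) == 14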
