diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index f4a0d88017..d034ef239a 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -426,7 +426,8 @@ def build_job_file_name(self, table_name: str, file_id: str, retry_count: int = # FileStorage.validate_file_name_component(file_id) fn = f"{table_name}.{file_id}.{int(retry_count)}" if with_extension: - return fn + f".{self.loader_file_format}" + format_spec = DataWriter.data_format_from_file_format(self.loader_file_format) + return fn + f".{format_spec.file_extension}" return fn @staticmethod diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index a48d931050..02cb712c70 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -145,7 +145,7 @@ def _write_with_dlt_columns( lambda batch: (uniq_id_base64(10) for _ in range(batch.num_rows)) )) items_count = 0 - as_py = load_storage.loader_file_format != "parquet" + as_py = load_storage.loader_file_format != "arrow" with normalize_storage.storage.open_file(extracted_items_file, "rb") as f: for batch in pyarrow.pq_stream_with_new_columns(f, new_columns, batch_size=self.RECORD_BATCH_SIZE): items_count += batch.num_rows @@ -173,7 +173,7 @@ def __call__( ) -> Tuple[List[TSchemaUpdate], int, TRowCount]: import pyarrow as pa - if config.parquet_add_dlt_id or config.parquet_add_dlt_load_id or load_storage.loader_file_format != "parquet": + if config.parquet_add_dlt_id or config.parquet_add_dlt_load_id or load_storage.loader_file_format != "arrow": items_count = self._write_with_dlt_columns( extracted_items_file, normalize_storage, diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 86df88205a..a850a23378 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -84,15 +84,19 @@ def w_normalize_files( load_storages: Dict[TLoaderFileFormat, LoadStorage] = {} def _get_load_storage(file_format: TLoaderFileFormat) -> LoadStorage: - if file_format != "parquet": - file_format = destination_caps.preferred_loader_file_format or destination_caps.preferred_staging_file_format - if storage := load_storages.get(file_format): - return storage # TODO: capabilities.supported_*_formats can be None, it should have defaults supported_formats = list(set(destination_caps.supported_loader_file_formats or []) | set(destination_caps.supported_staging_file_formats or [])) - if file_format == "parquet" and file_format not in supported_formats: - # Use default storage if parquet is not supported to make normalizer fallback to read rows from the file + if file_format == "parquet": + if file_format in supported_formats: + supported_formats.append("arrow") # TODO: Hack to make load storage use the correct writer + file_format = "arrow" + else: + # Use default storage if parquet is not supported to make normalizer fallback to read rows from the file + file_format = destination_caps.preferred_loader_file_format or destination_caps.preferred_staging_file_format + else: file_format = destination_caps.preferred_loader_file_format or destination_caps.preferred_staging_file_format + if storage := load_storages.get(file_format): + return storage storage = load_storages[file_format] = LoadStorage(False, file_format, supported_formats, loader_storage_config) return storage diff --git a/tests/cases.py b/tests/cases.py index 8895b86990..78b581f55f 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -354,7 +354,7 @@ def arrow_table_all_data_types(object_format: 
TArrowFormat, include_json: bool = } if include_json: - data["json"] = [{random.choice(ascii_lowercase): random.randrange(0, 100)} for _ in range(num_rows)] + data["json"] = [{"a": random.randrange(0, 100)} for _ in range(num_rows)] if include_time: data["time"] = pd.date_range("2021-01-01", periods=num_rows, tz="UTC").time diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py index 8db64672b8..97d035dbb3 100644 --- a/tests/pipeline/test_arrow_sources.py +++ b/tests/pipeline/test_arrow_sources.py @@ -14,6 +14,7 @@ from tests.cases import arrow_table_all_data_types, TArrowFormat from tests.utils import preserve_environ from dlt.common.storages import LoadStorage +from dlt.common import json @@ -49,7 +50,8 @@ def some_data(): load_id = pipeline.list_normalized_load_packages()[0] storage = pipeline._get_load_storage() jobs = storage.list_new_jobs(load_id) - with storage.storage.open_file(jobs[0], 'rb') as f: + job = [j for j in jobs if "some_data" in j][0] + with storage.storage.open_file(job, 'rb') as f: normalized_bytes = f.read() # Normalized is linked/copied exactly and should be the same as the extracted file @@ -86,21 +88,44 @@ def some_data(): assert schema_columns['json']['data_type'] == 'complex' -# @pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) -# def test_normalize_unsupported_loader_format(item_type: TArrowFormat): -# item, _ = arrow_table_all_data_types(item_type) -# pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="dummy") +@pytest.mark.parametrize( + ("item_type", "is_list"), [("pandas", False), ("table", False), ("record_batch", False), ("pandas", True), ("table", True), ("record_batch", True)] +) +def test_normalize_jsonl(item_type: TArrowFormat, is_list: bool): + os.environ['DUMMY__LOADER_FILE_FORMAT'] = "jsonl" -# @dlt.resource -# def some_data(): -# yield item + item, records = arrow_table_all_data_types(item_type) -# pipeline.extract(some_data()) -# with pytest.raises(PipelineStepFailed) as py_ex: -# pipeline.normalize() + pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="dummy") -# assert "The destination doesn't support direct loading of arrow tables" in str(py_ex.value) + @dlt.resource + def some_data(): + if is_list: + yield [item] + else: + yield item + + + pipeline.extract(some_data()) + pipeline.normalize() + + load_id = pipeline.list_normalized_load_packages()[0] + storage = pipeline._get_load_storage() + jobs = storage.list_new_jobs(load_id) + job = [j for j in jobs if "some_data" in j][0] + with storage.storage.open_file(job, 'r') as f: + result = [json.loads(line) for line in f] + for row in result: + row['decimal'] = Decimal(row['decimal']) + + for record in records: + record['datetime'] = record['datetime'].replace(tzinfo=None) + + expected = json.loads(json.dumps(records)) + for record in expected: + record['decimal'] = Decimal(record['decimal']) + assert result == expected @pytest.mark.parametrize("item_type", ["table", "record_batch"]) @@ -125,13 +150,13 @@ def map_func(item): assert pa.compute.all(pa.compute.greater(result_tbl['int'], 80)).as_py() -@pytest.mark.parametrize("item_type", ["table"]) +@pytest.mark.parametrize("item_type", ["table", "pandas", "record_batch"]) def test_normalize_with_dlt_columns(item_type: TArrowFormat): - item, _ = arrow_table_all_data_types(item_type, num_rows=1234) + item, records = arrow_table_all_data_types(item_type, num_rows=5432) os.environ['NORMALIZE__PARQUET_ADD_DLT_LOAD_ID'] = "True" os.environ['NORMALIZE__PARQUET_ADD_DLT_ID'] = "True" - 
# Make sure everything works table is larger than buffer size - os.environ['DATA_WRITER__BUFFER_MAX_ITEMS'] = "50" + # Test with buffer smaller than the number of batches to be written + os.environ['DATA_WRITER__BUFFER_MAX_ITEMS'] = "4" @dlt.resource def some_data(): @@ -145,8 +170,24 @@ def some_data(): load_id = pipeline.list_normalized_load_packages()[0] storage = pipeline._get_load_storage() jobs = storage.list_new_jobs(load_id) - with storage.storage.open_file(jobs[0], 'rb') as f: - normalized_bytes = f.read() + job = [j for j in jobs if "some_data" in j][0] + with storage.storage.open_file(job, 'rb') as f: + tbl = pa.parquet.read_table(f) - pq = pa.parquet.ParquetFile(f) - tbl = pq.read() + assert len(tbl) == 5432 + + # Test one column matches source data + assert tbl['string'].to_pylist() == [r['string'] for r in records] + + assert pa.compute.all(pa.compute.equal(tbl['_dlt_load_id'], load_id)).as_py() + + all_ids = tbl['_dlt_id'].to_pylist() + assert len(all_ids[0]) >= 14 + + # All ids are unique + assert len(all_ids) == len(set(all_ids)) + + # _dlt_id and _dlt_load_id are added to pipeline schema + schema = pipeline.default_schema + assert schema.tables['some_data']['columns']['_dlt_id']['data_type'] == 'text' + assert schema.tables['some_data']['columns']['_dlt_load_id']['data_type'] == 'text'
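
Reviewer note: the sketch below condenses the storage-format selection rule that the normalize.py hunk above introduces, for quick reference. The names Caps and pick_storage_format are hypothetical stand-ins for DestinationCapabilitiesContext and the inline _get_load_storage closure; this is a simplified sketch of the rule under those assumptions, not the implementation in the patch.

# Hedged sketch: Caps / pick_storage_format are illustration-only names, not dlt APIs.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Caps:
    # stand-in for the destination capabilities used by _get_load_storage
    preferred_loader_file_format: Optional[str] = None
    preferred_staging_file_format: Optional[str] = None
    supported_loader_file_formats: List[str] = field(default_factory=list)
    supported_staging_file_formats: List[str] = field(default_factory=list)


def pick_storage_format(file_format: str, caps: Caps) -> str:
    # union of load and staging formats, mirroring the supported_formats set built in the patch
    supported = set(caps.supported_loader_file_formats) | set(caps.supported_staging_file_formats)
    if file_format == "parquet" and "parquet" in supported:
        # parquet-extracted arrow items go through the "arrow" writer (the fast path)
        return "arrow"
    # otherwise fall back to the destination's preferred format so the normalizer
    # reads rows back from the file and rewrites them (e.g. jsonl for the dummy destination)
    return caps.preferred_loader_file_format or caps.preferred_staging_file_format


# parquet-capable destination: keep the arrow fast path
assert pick_storage_format("parquet", Caps("parquet", None, ["parquet", "jsonl"], [])) == "arrow"
# jsonl-only destination (as exercised by test_normalize_jsonl): fall back to row-wise jsonl
assert pick_storage_format("parquet", Caps("jsonl", None, ["jsonl"], [])) == "jsonl"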