From d4422cff774b1ae3593ea7f42d5f3ad752f2c9c8 Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Mon, 26 Feb 2024 01:00:47 +0100
Subject: [PATCH] fixes parquet tests

---
 tests/load/pipeline/test_merge_disposition.py | 16 +++++++++++++---
 tests/load/pipeline/test_pipelines.py         |  2 +-
 tests/load/pipeline/utils.py                  |  2 +-
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py
index b0bbd6dfb6..19ee9a34c8 100644
--- a/tests/load/pipeline/test_merge_disposition.py
+++ b/tests/load/pipeline/test_merge_disposition.py
@@ -319,11 +319,18 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration)
     )
     assert_load_info(info)
     # make sure it was parquet or sql transforms
+    expected_formats = ["parquet"]
+    if p.staging:
+        # allow references if staging is present
+        expected_formats.append("reference")
     files = p.get_load_package_info(p.list_completed_load_packages()[0]).jobs["completed_jobs"]
-    assert all(f.job_file_info.file_format in ["parquet", "sql"] for f in files)
+    assert all(f.job_file_info.file_format in expected_formats + ["sql"] for f in files)
 
     github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()])
-    assert github_1_counts["issues"] == 100
+    expected_rows = 100
+    if not destination_config.supports_merge:
+        expected_rows *= 2
+    assert github_1_counts["issues"] == expected_rows
 
     # now retry with replace
     github_data = github()
@@ -333,7 +340,10 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration)
     assert_load_info(info)
     # make sure it was parquet or sql inserts
     files = p.get_load_package_info(p.list_completed_load_packages()[1]).jobs["completed_jobs"]
-    assert all(f.job_file_info.file_format in ["parquet"] for f in files)
+    if destination_config.force_iceberg:
+        # iceberg uses sql to copy tables
+        expected_formats.append("sql")
+    assert all(f.job_file_info.file_format in expected_formats for f in files)
 
     github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()])
     assert github_1_counts["issues"] == 100
diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py
index 77e9105258..5fa656ada9 100644
--- a/tests/load/pipeline/test_pipelines.py
+++ b/tests/load/pipeline/test_pipelines.py
@@ -828,7 +828,7 @@ def some_source():
     # all three jobs succeeded
     assert len(package_info.jobs["failed_jobs"]) == 0
     # 3 tables + 1 state + 4 reference jobs if staging
-    expected_completed_jobs = 4 + 4 if destination_config.staging else 4
+    expected_completed_jobs = 4 + 4 if pipeline.staging else 4
     # add sql merge job
     if destination_config.supports_merge:
         expected_completed_jobs += 1
diff --git a/tests/load/pipeline/utils.py b/tests/load/pipeline/utils.py
index 54c6231dcc..7a5ef02ae6 100644
--- a/tests/load/pipeline/utils.py
+++ b/tests/load/pipeline/utils.py
@@ -65,7 +65,7 @@ def _drop_dataset(schema_name: str) -> None:
         for schema_name in p.schema_names:
            _drop_dataset(schema_name)
-    p._wipe_working_folder()
+    # p._wipe_working_folder()
     # deactivate context
     Container()[PipelineContext].deactivate()
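
For readers skimming the patch, the block below condenses how the assertion logic in test_pipeline_load_parquet reads once both hunks are applied. It is a sketch only: p, destination_config, github, assert_load_info, and load_table_counts are the fixtures and helpers already used by the dlt test suite, and the explanatory comments are one reading of why the expectations change.

    # sketch of the post-patch expectations in test_pipeline_load_parquet
    expected_formats = ["parquet"]
    if p.staging:
        # staging destinations add reference jobs that point at staged files
        expected_formats.append("reference")

    # first (merge) load: sql transform jobs are still allowed
    files = p.get_load_package_info(p.list_completed_load_packages()[0]).jobs["completed_jobs"]
    assert all(f.job_file_info.file_format in expected_formats + ["sql"] for f in files)

    expected_rows = 100
    if not destination_config.supports_merge:
        # without merge support the load falls back to append, so rows accumulate
        expected_rows *= 2
    github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()])
    assert github_1_counts["issues"] == expected_rows

    # second (replace) load: iceberg destinations copy staged tables with sql jobs
    if destination_config.force_iceberg:
        expected_formats.append("sql")
    files = p.get_load_package_info(p.list_completed_load_packages()[1]).jobs["completed_jobs"]
    assert all(f.job_file_info.file_format in expected_formats for f in files)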