
Commit

fixes parquet test and handles duckdb TIME problem
rudolfix committed Oct 17, 2023
1 parent 5ac4007 commit 8443ea6
Showing 5 changed files with 65 additions and 8 deletions.
14 changes: 14 additions & 0 deletions dlt/destinations/duckdb/duck.py
@@ -3,6 +3,7 @@

from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.data_types import TDataType
from dlt.common.exceptions import TerminalValueError
from dlt.common.schema import TColumnSchema, TColumnHint, Schema
from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState
from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
@@ -63,6 +64,9 @@ class DuckDbTypeMapper(TypeMapper):
"INTEGER": "bigint",
"BIGINT": "bigint",
"HUGEINT": "bigint",
"TIMESTAMP_S": "timestamp",
"TIMESTAMP_MS": "timestamp",
"TIMESTAMP_NS": "timestamp",
}

def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
@@ -79,6 +83,16 @@ def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
return "BIGINT"
return "HUGEINT"

def to_db_datetime_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
if precision is None or precision == 6:
return super().to_db_datetime_type(precision, table_format)
if precision == 0:
return "TIMESTAMP_S"
if precision == 3:
return "TIMESTAMP_MS"
if precision == 9:
return "TIMESTAMP_NS"
raise TerminalValueError(f"timestamp {precision} cannot be mapped into duckdb TIMESTAMP typ")

def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
# duckdb provides the types with scale and precision
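For context, the new to_db_datetime_type keeps the default mapping for microsecond (or unspecified) precision and only special-cases the resolutions duckdb has dedicated types for; anything else is rejected. A minimal sketch of the expected behavior — the `mapper` name is a hypothetical, already configured DuckDbTypeMapper instance (construction is not shown here):

from dlt.common.exceptions import TerminalValueError

# `mapper` is assumed to be a configured DuckDbTypeMapper instance (hypothetical)
assert mapper.to_db_datetime_type(0) == "TIMESTAMP_S"    # second resolution
assert mapper.to_db_datetime_type(3) == "TIMESTAMP_MS"   # millisecond resolution
assert mapper.to_db_datetime_type(9) == "TIMESTAMP_NS"   # nanosecond resolution
# precision None or 6 falls back to the base TypeMapper "timestamp" mapping
# any other precision is rejected
try:
    mapper.to_db_datetime_type(7)
    raise AssertionError("expected TerminalValueError for unsupported precision")
except TerminalValueError:
    pass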
12 changes: 12 additions & 0 deletions dlt/destinations/type_mapping.py
@@ -24,11 +24,23 @@ def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
# Override in subclass if db supports other integer types (e.g. smallint, integer, tinyint, etc.)
return self.sct_to_unbound_dbt["bigint"]

def to_db_datetime_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
# Override in subclass if db supports other timestamp types (e.g. with different time resolutions)
return self.sct_to_unbound_dbt["timestamp"]

def to_db_time_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
# Override in subclass if db supports other time types (e.g. with different time resolutions)
return self.sct_to_unbound_dbt["time"]

def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) -> str:
precision, scale = column.get("precision"), column.get("scale")
sc_t = column["data_type"]
if sc_t == "bigint":
return self.to_db_integer_type(precision, table_format)
if sc_t == "timestamp":
return self.to_db_datetime_type(precision, table_format)
if sc_t == "time":
return self.to_db_time_type(precision, table_format)
bounded_template = self.sct_to_dbt.get(sc_t)
if not bounded_template:
return self.sct_to_unbound_dbt[sc_t]
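The dispatch in to_db_type now routes "timestamp" and "time" columns through the new precision-aware hooks (as "bigint" already was) before consulting the bounded/unbounded templates. A rough illustration of that flow — the column dicts and the `mapper` name are hand-written examples, with `mapper` standing for any concrete TypeMapper subclass instance:

# hand-written example columns; `mapper` is a hypothetical TypeMapper subclass instance
timestamp_col = {"name": "updated_at", "data_type": "timestamp", "precision": 3}
time_col = {"name": "event_time", "data_type": "time"}
text_col = {"name": "comment", "data_type": "text"}

# "timestamp" and "time" are handled by the new hooks, which subclasses may override
mapper.to_db_type(timestamp_col)  # delegates to mapper.to_db_datetime_type(3, None)
mapper.to_db_type(time_col)       # delegates to mapper.to_db_time_type(None, None)
# every other type still uses sct_to_dbt (if a bounded template exists) or sct_to_unbound_dbt
mapper.to_db_type(text_col)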
4 changes: 2 additions & 2 deletions tests/helpers/dbt_tests/test_runner_dbt_versions.py
@@ -69,10 +69,10 @@ def test_infer_venv_deps() -> None:
# provide version ranges
requirements = _create_dbt_deps(["duckdb"], dbt_version=">3")
# special duckdb dependency
assert requirements[:-1] == ["dbt-core>3", "dbt-duckdb", "duckdb==0.8.1"]
assert requirements[:-1] == ["dbt-core>3", "dbt-duckdb", "duckdb==0.9.1"]
# we do not validate version ranges, pip will do it and fail when creating venv
requirements = _create_dbt_deps(["motherduck"], dbt_version="y")
assert requirements[:-1] == ["dbt-corey", "dbt-duckdb", "duckdb==0.8.1"]
assert requirements[:-1] == ["dbt-corey", "dbt-duckdb", "duckdb==0.9.1"]


def test_default_profile_name() -> None:
41 changes: 36 additions & 5 deletions tests/load/pipeline/test_pipelines.py
@@ -619,7 +619,7 @@ def test_snowflake_delete_file_after_copy(destination_config: DestinationTestConfiguration) -> None:


# do not remove - it allows us to filter tests by destination
@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, file_format="parquet"), ids=lambda x: x.name)
@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, all_staging_configs=True, file_format="parquet"), ids=lambda x: x.name)
def test_parquet_loading(destination_config: DestinationTestConfiguration) -> None:
"""Run pipeline twice with merge write disposition
Resource with primary key falls back to append. Resource without keys falls back to replace.
@@ -641,6 +641,23 @@ def other_data():
if destination_config.destination == "bigquery":
column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text"

# duckdb 0.9.1 does not support TIME with precision other than 6
if destination_config.destination in ["duckdb", "motherduck"]:
column_schemas["col11_precision"]["precision"] = 0

# drop TIME from databases not supporting it via parquet
if destination_config.destination in ["redshift", "athena"]:
data_types.pop("col11")
data_types.pop("col11_null")
data_types.pop("col11_precision")
column_schemas.pop("col11")
column_schemas.pop("col11_null")
column_schemas.pop("col11_precision")

if destination_config.destination == "redshift":
data_types.pop("col7_precision")
column_schemas.pop("col7_precision")

# apply the exact column definitions so we process complex and wei types correctly!
@dlt.resource(table_name="data_types", write_disposition="merge", columns=column_schemas)
def my_resource():
@@ -653,19 +670,33 @@ def some_source():

info = pipeline.run(some_source(), loader_file_format="parquet")
package_info = pipeline.get_load_package_info(info.loads_ids[0])
# print(package_info.asstr(verbosity=2))
assert package_info.state == "loaded"
# all jobs succeeded
assert len(package_info.jobs["failed_jobs"]) == 0
assert len(package_info.jobs["completed_jobs"]) == 5 # 3 tables + 1 state + 1 sql merge job
# 3 tables + 1 state + 4 reference jobs if staging
expected_completed_jobs = (4 + 4) if destination_config.staging else 4
# add sql merge job
if destination_config.supports_merge:
expected_completed_jobs += 1
# add iceberg copy jobs
if destination_config.force_iceberg:
expected_completed_jobs += 4
assert len(package_info.jobs["completed_jobs"]) == expected_completed_jobs

with pipeline.sql_client() as sql_client:
assert [row[0] for row in sql_client.execute_sql("SELECT * FROM other_data")] == [1, 2, 3, 4, 5]
assert [row[0] for row in sql_client.execute_sql("SELECT * FROM some_data")] == [1, 2, 3]
assert [row[0] for row in sql_client.execute_sql("SELECT * FROM other_data ORDER BY 1")] == [1, 2, 3, 4, 5]
assert [row[0] for row in sql_client.execute_sql("SELECT * FROM some_data ORDER BY 1")] == [1, 2, 3]
db_rows = sql_client.execute_sql("SELECT * FROM data_types")
assert len(db_rows) == 10
db_row = list(db_rows[0])
# "snowflake" and "bigquery" do not parse JSON form parquet string so double parse
assert_all_data_types_row(db_row[:-2], parse_complex_strings=destination_config.destination in ["snowflake", "bigquery"])
assert_all_data_types_row(
db_row,
schema=column_schemas,
parse_complex_strings=destination_config.destination in ["snowflake", "bigquery", "redshift"],
timestamp_precision=3 if destination_config.destination == "athena" else 6
)


def simple_nested_pipeline(destination_config: DestinationTestConfiguration, dataset_name: str, full_refresh: bool) -> Tuple[dlt.Pipeline, Callable[[], DltSource]]:
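The expected job count in the parquet test is now derived from the destination configuration instead of being hard-coded to 5. A quick back-of-the-envelope check of that formula against a few configurations in the suite (flags taken from the destination configs in tests/load/utils.py below); the `expected_jobs` helper is illustrative only, and plain duckdb still comes out at the previous hard-coded value:

def expected_jobs(staging: bool, supports_merge: bool, force_iceberg: bool) -> int:
    # 3 table jobs + 1 state job, plus 4 reference jobs when a staging destination is used
    jobs = (4 + 4) if staging else 4
    if supports_merge:
        jobs += 1  # sql merge job
    if force_iceberg:
        jobs += 4  # iceberg copy jobs
    return jobs

assert expected_jobs(staging=False, supports_merge=True, force_iceberg=False) == 5   # e.g. plain duckdb
assert expected_jobs(staging=True, supports_merge=False, force_iceberg=False) == 8   # e.g. athena + filesystem staging
assert expected_jobs(staging=True, supports_merge=False, force_iceberg=True) == 12   # e.g. athena iceberg config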
2 changes: 1 addition & 1 deletion tests/load/utils.py
@@ -112,7 +112,7 @@ def destinations_configs(
destination_configs += [DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS if destination != "athena"]
destination_configs += [DestinationTestConfiguration(destination="duckdb", file_format="parquet")]
# athena needs filesystem staging, which will be automatically set, we have to supply a bucket url though
destination_configs += [DestinationTestConfiguration(destination="athena", supports_merge=False, bucket_url=AWS_BUCKET)]
destination_configs += [DestinationTestConfiguration(destination="athena", staging="filesystem", file_format="parquet", supports_merge=False, bucket_url=AWS_BUCKET)]
destination_configs += [DestinationTestConfiguration(destination="athena", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, force_iceberg=True, supports_merge=False, supports_dbt=False, extra_info="iceberg")]

if default_vector_configs:
