From 470e59efcb917f1fbf69141209462dfda84617f7 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 26 Nov 2024 21:31:08 +0200 Subject: [PATCH] feat(postgres): add support for CSV loading of geometry columns Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/postgres/postgres.py | 10 ---- .../postgres/test_postgres_table_builder.py | 57 ++++++++++++++++--- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 8abfea2acb..2459ee1dbe 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -13,13 +13,11 @@ FollowupJobRequest, LoadJob, ) -from dlt.common.exceptions import TerminalValueError from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TColumnType from dlt.common.schema.utils import is_nullable_column from dlt.common.storages.file_storage import FileStorage from dlt.destinations.impl.postgres.configuration import PostgresClientConfiguration -from dlt.destinations.impl.postgres.postgres_adapter import GEOMETRY_HINT from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.sql_client import SqlClientBase @@ -158,14 +156,6 @@ def __init__( def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if any( - column.get(GEOMETRY_HINT) for column in table["columns"].values() - ) and not file_path.endswith("insert_values"): - # Only insert_values load jobs supported for geom types. - # TODO: This isn't actually true, can make it work with geoarrow! - raise TerminalValueError( - "CSV bulk loading is not supported for tables with geometry columns." - ) job = super().create_load_job(table, file_path, load_id, restore) if not job and file_path.endswith("csv"): job = PostgresCsvCopyJob(file_path) diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index fa7a1ca8ee..e2ed0f0b2e 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -264,6 +264,30 @@ def geodata_3857_wkb_hex(): def geodata_2163_wkb_hex(): yield from generate_sample_geometry_records("wkb_hex") + @dlt.resource(file_format="csv") + def geodata_default_csv_wkt(): + yield from generate_sample_geometry_records("wkt") + + @dlt.resource(file_format="csv") + def geodata_3857_csv_wkt(): + yield from generate_sample_geometry_records("wkt") + + @dlt.resource(file_format="csv") + def geodata_2163_csv_wkt(): + yield from generate_sample_geometry_records("wkt") + + @dlt.resource(file_format="csv") + def geodata_default_csv_wkb_hex(): + yield from generate_sample_geometry_records("wkb_hex") + + @dlt.resource(file_format="csv") + def geodata_3857_csv_wkb_hex(): + yield from generate_sample_geometry_records("wkb_hex") + + @dlt.resource(file_format="csv") + def geodata_2163_csv_wkb_hex(): + yield from generate_sample_geometry_records("wkb_hex") + @dlt.resource def no_geodata(): yield from [{"a": 1}, {"a": 2}] @@ -274,6 +298,12 @@ def no_geodata(): postgres_adapter(geodata_default_wkb_hex, geometry=["geom"]) postgres_adapter(geodata_3857_wkb_hex, geometry=["geom"], srid=3857) postgres_adapter(geodata_2163_wkb_hex, geometry=["geom"], srid=2163) + postgres_adapter(geodata_default_csv_wkt, geometry=["geom"]) + postgres_adapter(geodata_3857_csv_wkt, geometry=["geom"], srid=3857) + postgres_adapter(geodata_2163_csv_wkt, geometry=["geom"], srid=2163) + postgres_adapter(geodata_default_csv_wkb_hex, geometry=["geom"]) + postgres_adapter(geodata_3857_csv_wkb_hex, geometry=["geom"], srid=3857) + postgres_adapter(geodata_2163_csv_wkb_hex, geometry=["geom"], srid=2163) @dlt.source def geodata() -> List[DltResource]: @@ -285,6 +315,12 @@ def geodata() -> List[DltResource]: geodata_3857_wkb_hex, geodata_2163_wkb_hex, no_geodata, + geodata_default_csv_wkt, + geodata_3857_csv_wkt, + geodata_2163_csv_wkt, + geodata_default_csv_wkb_hex, + geodata_3857_csv_wkb_hex, + geodata_2163_csv_wkb_hex, ] pipeline = destination_config.setup_pipeline("test_geometry_types", dev_mode=True) @@ -296,13 +332,14 @@ def geodata() -> List[DltResource]: # Assert that types were read in as PostGIS geometry types with pipeline.sql_client() as c: with c.execute_query(f"""SELECT f_geometry_column - FROM geometry_columns - WHERE f_table_name in ( - 'geodata_default_wkb', 'geodata_3857_wkb', 'geodata_2163_wkb', - 'geodata_default_wkt', 'geodata_3857_wkt', 'geodata_2163_wkt', - 'geodata_default_wkb_hex', 'geodata_3857_wkb_hex', 'geodata_2163_wkb_hex' - ) - AND f_table_schema = '{c.fully_qualified_dataset_name(escape=False)}'""") as cur: +FROM geometry_columns +WHERE f_table_name in + ('geodata_default_wkb', 'geodata_3857_wkb', 'geodata_2163_wkb', 'geodata_default_wkt', 'geodata_3857_wkt', + 'geodata_2163_wkt', 'geodata_default_wkb_hex', 'geodata_3857_wkb_hex', 'geodata_2163_wkb_hex', + 'geodata_default_csv_wkt', 'geodata_3857_csv_wkt', 'geodata_2163_csv_wkt', 'geodata_default_csv_wkb_hex', + 'geodata_3857_csv_wkb_hex', 'geodata_2163_csv_wkb_hex' + ) + AND f_table_schema = '{c.fully_qualified_dataset_name(escape=False)}'""") as cur: records = cur.fetchall() assert records assert {record[0] for record in records} == {"geom"} @@ -315,6 +352,12 @@ def geodata() -> List[DltResource]: "geodata_default_wkb_hex", "geodata_3857_wkb_hex", "geodata_2163_wkb_hex", + "geodata_default_csv_wkt", + "geodata_3857_csv_wkt", + "geodata_2163_csv_wkt", + "geodata_default_csv_wkb_hex", + "geodata_3857_csv_wkb_hex", + "geodata_2163_csv_wkb_hex", ]: srid = 4326 if resource.startswith("geodata_default") else int(resource.split("_")[1])