Commit
add support for staged Parquet loading
Jorrit Sandbrink committed Jan 25, 2024
1 parent 75be2ce commit 014543a
Showing 9 changed files with 182 additions and 25 deletions.
4 changes: 2 additions & 2 deletions dlt/destinations/impl/synapse/__init__.py
@@ -9,8 +9,8 @@ def capabilities() -> DestinationCapabilitiesContext:

caps.preferred_loader_file_format = "insert_values"
caps.supported_loader_file_formats = ["insert_values"]
caps.preferred_staging_file_format = None
caps.supported_staging_file_formats = []
caps.preferred_staging_file_format = "parquet"
caps.supported_staging_file_formats = ["parquet"]

caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299

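With parquet enabled as a staging format, a Synapse pipeline can route load packages through a filesystem staging destination. A minimal sketch (not part of the diff) follows; it assumes Synapse credentials and an az:// bucket URL for the staging filesystem are already configured via dlt's config/secrets, and the pipeline, dataset, and table names are hypothetical.

```python
import dlt

# Sketch: stage parquet files on Azure Blob Storage, then COPY INTO Synapse.
# Assumes destination and staging credentials live in secrets.toml or env vars.
pipeline = dlt.pipeline(
    pipeline_name="synapse_staged_demo",
    destination="synapse",
    staging="filesystem",
    dataset_name="staging_demo",
)

load_info = pipeline.run(
    [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}],
    table_name="items",
    loader_file_format="parquet",  # now supported as a staging file format
)
print(load_info)
```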
8 changes: 6 additions & 2 deletions dlt/destinations/impl/synapse/configuration.py
@@ -48,17 +48,21 @@ class SynapseClientConfiguration(MsSqlClientConfiguration):
# are tricky in Synapse: they are NOT ENFORCED and can lead to inaccurate
# results if the user does not ensure all column values are unique.
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-table-constraints
create_indexes: Optional[bool] = False
create_indexes: bool = False
"""Whether `primary_key` and `unique` column hints are applied."""

# Concurrency is disabled by overriding the configured number of workers to 1 at runtime.
auto_disable_concurrency: Optional[bool] = True
auto_disable_concurrency: bool = True
"""Whether concurrency is automatically disabled in cases where it might cause issues."""

staging_use_msi: bool = False
"""Whether the managed identity of the Synapse workspace is used to authorize access to the staging Storage Account."""

__config_gen_annotations__: ClassVar[List[str]] = [
"default_table_index_type",
"create_indexes",
"auto_disable_concurrency",
"staging_use_msi",
]

def get_load_workers(self, tables: TSchemaTables, workers: int) -> int:
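Like the other flags above, the new `staging_use_msi` option can be supplied through dlt's standard configuration mechanisms. A hedged sketch using environment variables follows; the exact section path is an assumption based on dlt's usual `DESTINATION__<NAME>__<FIELD>` layout (the test utilities further down set `DESTINATION__STAGING_USE_MSI` directly).

```python
import os

# Assumed env var layout for destination-level options; adjust to your setup.
os.environ["DESTINATION__SYNAPSE__CREATE_INDEXES"] = "false"
os.environ["DESTINATION__SYNAPSE__AUTO_DISABLE_CONCURRENCY"] = "true"
os.environ["DESTINATION__SYNAPSE__STAGING_USE_MSI"] = "true"  # use the workspace managed identity for staging access
```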
5 changes: 4 additions & 1 deletion dlt/destinations/impl/synapse/factory.py
@@ -30,7 +30,8 @@ def __init__(
credentials: t.Union[SynapseCredentials, t.Dict[str, t.Any], str] = None,
default_table_index_type: t.Optional[TTableIndexType] = "heap",
create_indexes: bool = False,
auto_disable_concurrency: t.Optional[bool] = True,
auto_disable_concurrency: bool = True,
staging_use_msi: bool = False,
destination_name: t.Optional[str] = None,
environment: t.Optional[str] = None,
**kwargs: t.Any,
@@ -45,13 +46,15 @@ def __init__(
default_table_index_type: Maps directly to the default_table_index_type attribute of the SynapseClientConfiguration object.
create_indexes: Maps directly to the create_indexes attribute of the SynapseClientConfiguration object.
auto_disable_concurrency: Maps directly to the auto_disable_concurrency attribute of the SynapseClientConfiguration object.
staging_use_msi: Maps directly to the staging_use_msi attribute of the SynapseClientConfiguration object.
**kwargs: Additional arguments passed to the destination config
"""
super().__init__(
credentials=credentials,
default_table_index_type=default_table_index_type,
create_indexes=create_indexes,
auto_disable_concurrency=auto_disable_concurrency,
staging_use_msi=staging_use_msi,
destination_name=destination_name,
environment=environment,
**kwargs,
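Alternatively, the same options can be passed explicitly through the factory. A sketch, assuming credentials are resolved from secrets and the pipeline name is hypothetical:

```python
import dlt
from dlt.destinations import synapse

# Sketch: construct the destination with the new staging_use_msi argument.
dest = synapse(
    create_indexes=False,
    auto_disable_concurrency=True,
    staging_use_msi=True,  # authorize staging Storage Account access via managed identity
)

pipeline = dlt.pipeline(
    pipeline_name="synapse_msi_demo",
    destination=dest,
    staging="filesystem",
    dataset_name="demo",
)
```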
115 changes: 112 additions & 3 deletions dlt/destinations/impl/synapse/synapse.py
@@ -1,17 +1,28 @@
import os
from typing import ClassVar, Sequence, List, Dict, Any, Optional, cast
from copy import deepcopy
from textwrap import dedent
from urllib.parse import urlparse, urlunparse

from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import SupportsStagingDestination, NewLoadJob
from dlt.common.destination.reference import (
SupportsStagingDestination,
NewLoadJob,
CredentialsConfiguration,
)

from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint
from dlt.common.schema.utils import table_schema_has_type
from dlt.common.schema.typing import TTableSchemaColumns, TTableIndexType

from dlt.common.configuration.specs import AzureCredentialsWithoutDefaults

from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.insert_job_client import InsertValuesJobClient
from dlt.destinations.job_client_impl import SqlJobClientBase
from dlt.destinations.job_client_impl import SqlJobClientBase, LoadJob, CopyRemoteFileLoadJob
from dlt.destinations.exceptions import LoadJobTerminalException

from dlt.destinations.impl.mssql.mssql import (
MsSqlTypeMapper,
@@ -35,7 +46,7 @@
}


class SynapseClient(MsSqlClient):
class SynapseClient(MsSqlClient, SupportsStagingDestination):
capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None:
@@ -140,6 +151,21 @@ def get_storage_table_index_type(self, table_name: str) -> TTableIndexType:
table_index_type = sql_client.execute_sql(sql)[0][0]
return cast(TTableIndexType, table_index_type)

def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
job = super().start_file_load(table, file_path, load_id)
if not job:
assert NewReferenceJob.is_reference_job(
file_path
), "Synapse must use staging to load files"
job = SynapseCopyFileLoadJob(
table,
file_path,
self.sql_client,
cast(AzureCredentialsWithoutDefaults, self.config.staging_config.credentials),
self.config.staging_use_msi,
)
return job


class SynapseStagingCopyJob(SqlStagingCopyJob):
@classmethod
@@ -173,3 +199,86 @@ def generate_sql(
)

return sql


class SynapseCopyFileLoadJob(CopyRemoteFileLoadJob):
def __init__(
self,
table: TTableSchema,
file_path: str,
sql_client: SqlClientBase[Any],
staging_credentials: Optional[AzureCredentialsWithoutDefaults] = None,
staging_use_msi: bool = False,
) -> None:
self.staging_use_msi = staging_use_msi
super().__init__(table, file_path, sql_client, staging_credentials)

def execute(self, table: TTableSchema, bucket_path: str) -> None:
# get format
ext = os.path.splitext(bucket_path)[1][1:]
if ext == "parquet":
if table_schema_has_type(table, "time"):
# Synapse interprets Parquet TIME columns as bigint, resulting in
# an incompatibility error.
raise LoadJobTerminalException(
self.file_name(),
"Synapse cannot load TIME columns from Parquet files. Switch to direct INSERT"
" file format or convert `datetime.time` objects in your data to `str` or"
" `datetime.datetime`",
)
file_type = "PARQUET"

# dlt-generated DDL statements will still create the table, but
# enabling AUTO_CREATE_TABLE prevents a MalformedInputException.
auto_create_table = "ON"
else:
raise ValueError(f"Unsupported file type {ext} for Synapse.")

staging_credentials = self._staging_credentials
assert staging_credentials is not None
assert isinstance(staging_credentials, AzureCredentialsWithoutDefaults)
azure_storage_account_name = staging_credentials.azure_storage_account_name
https_path = self._get_https_path(bucket_path, azure_storage_account_name)
table_name = table["name"]

if self.staging_use_msi:
credential = "IDENTITY = 'Managed Identity'"
else:
sas_token = staging_credentials.azure_storage_sas_token
credential = f"IDENTITY = 'Shared Access Signature', SECRET = '{sas_token}'"

# Copy data from staging file into Synapse table.
with self._sql_client.begin_transaction():
dataset_name = self._sql_client.dataset_name
sql = dedent(f"""
COPY INTO [{dataset_name}].[{table_name}]
FROM '{https_path}'
WITH (
FILE_TYPE = '{file_type}',
CREDENTIAL = ({credential}),
AUTO_CREATE_TABLE = '{auto_create_table}'
)
""")
self._sql_client.execute_sql(sql)

def exception(self) -> str:
# this part of the code should never be reached
raise NotImplementedError()

def _get_https_path(self, bucket_path: str, storage_account_name: str) -> str:
"""
Converts a path in the form of az://<container_name>/<path> to
https://<storage_account_name>.blob.core.windows.net/<container_name>/<path>
as required by Synapse.
"""
bucket_url = urlparse(bucket_path)
# "blob" endpoint has better performance than "dfs" endoint
# https://learn.microsoft.com/en-us/sql/t-sql/statements/copy-into-transact-sql?view=azure-sqldw-latest#external-locations
endpoint = "blob"
_path = "/" + bucket_url.netloc + bucket_url.path
https_url = bucket_url._replace(
scheme="https",
netloc=f"{storage_account_name}.{endpoint}.core.windows.net",
path=_path,
)
return urlunparse(https_url)
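As a standalone illustration of what `_get_https_path` produces, here is the same `az://` to `https://` conversion applied to a hypothetical container, path, and storage account name:

```python
from urllib.parse import urlparse, urlunparse

bucket_path = "az://my-container/staging/items/part-000.parquet"  # hypothetical
storage_account_name = "mystorageaccount"                         # hypothetical

url = urlparse(bucket_path)
https_url = url._replace(
    scheme="https",
    netloc=f"{storage_account_name}.blob.core.windows.net",  # "blob" endpoint, as in the code above
    path="/" + url.netloc + url.path,                         # container name moves into the path
)
print(urlunparse(https_url))
# -> https://mystorageaccount.blob.core.windows.net/my-container/staging/items/part-000.parquet
```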
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -96,7 +96,7 @@ cli = ["pipdeptree", "cron-descriptor"]
athena = ["pyathena", "pyarrow", "s3fs", "botocore"]
weaviate = ["weaviate-client"]
mssql = ["pyodbc"]
synapse = ["pyodbc"]
synapse = ["pyodbc", "adlfs"]
qdrant = ["qdrant-client"]

[tool.poetry.scripts]
17 changes: 9 additions & 8 deletions tests/load/pipeline/test_pipelines.py
@@ -788,7 +788,7 @@ def other_data():
column_schemas["col11_precision"]["precision"] = 0

# drop TIME from databases not supporting it via parquet
if destination_config.destination in ["redshift", "athena"]:
if destination_config.destination in ["redshift", "athena", "synapse"]:
data_types.pop("col11")
data_types.pop("col11_null")
data_types.pop("col11_precision")
@@ -827,15 +827,16 @@ def some_source():
assert len(package_info.jobs["completed_jobs"]) == expected_completed_jobs

with pipeline.sql_client() as sql_client:
qual_name = sql_client.make_qualified_table_name
assert [
row[0] for row in sql_client.execute_sql("SELECT * FROM other_data ORDER BY 1")
row[0]
for row in sql_client.execute_sql(f"SELECT * FROM {qual_name('other_data')} ORDER BY 1")
] == [1, 2, 3, 4, 5]
assert [row[0] for row in sql_client.execute_sql("SELECT * FROM some_data ORDER BY 1")] == [
1,
2,
3,
]
db_rows = sql_client.execute_sql("SELECT * FROM data_types")
assert [
row[0]
for row in sql_client.execute_sql(f"SELECT * FROM {qual_name('some_data')} ORDER BY 1")
] == [1, 2, 3]
db_rows = sql_client.execute_sql(f"SELECT * FROM {qual_name('data_types')}")
assert len(db_rows) == 10
db_row = list(db_rows[0])
# "snowflake" and "bigquery" do not parse JSON form parquet string so double parse
35 changes: 29 additions & 6 deletions tests/load/pipeline/test_stage_loading.py
@@ -94,7 +94,13 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None:

# check item of first row in db
with pipeline.sql_client() as sql_client:
rows = sql_client.execute_sql("SELECT url FROM issues WHERE id = 388089021 LIMIT 1")
if destination_config.destination in ["mssql", "synapse"]:
qual_name = sql_client.make_qualified_table_name
rows = sql_client.execute_sql(
f"SELECT TOP 1 url FROM {qual_name('issues')} WHERE id = 388089021"
)
else:
rows = sql_client.execute_sql("SELECT url FROM issues WHERE id = 388089021 LIMIT 1")
assert rows[0][0] == "https://api.github.com/repos/duckdb/duckdb/issues/71"

if destination_config.supports_merge:
@@ -109,10 +115,23 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None:

# check changes were merged in
with pipeline.sql_client() as sql_client:
rows = sql_client.execute_sql("SELECT number FROM issues WHERE id = 1232152492 LIMIT 1")
assert rows[0][0] == 105
rows = sql_client.execute_sql("SELECT number FROM issues WHERE id = 1142699354 LIMIT 1")
assert rows[0][0] == 300
if destination_config.destination in ["mssql", "synapse"]:
qual_name = sql_client.make_qualified_table_name
rows_1 = sql_client.execute_sql(
f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id = 1232152492"
)
rows_2 = sql_client.execute_sql(
f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id = 1142699354"
)
else:
rows_1 = sql_client.execute_sql(
"SELECT number FROM issues WHERE id = 1232152492 LIMIT 1"
)
rows_2 = sql_client.execute_sql(
"SELECT number FROM issues WHERE id = 1142699354 LIMIT 1"
)
assert rows_1[0][0] == 105
assert rows_2[0][0] == 300

# test append
info = pipeline.run(
@@ -161,6 +180,9 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> None:
) and destination_config.file_format in ("parquet", "jsonl"):
# Redshift copy doesn't support TIME column
exclude_types.append("time")
if destination_config.destination == "synapse" and destination_config.file_format == "parquet":
# TIME columns are not supported for staged parquet loads into Synapse
exclude_types.append("time")
if destination_config.destination == "redshift" and destination_config.file_format in (
"parquet",
"jsonl",
@@ -199,7 +221,8 @@ def my_source():
assert_load_info(info)

with pipeline.sql_client() as sql_client:
db_rows = sql_client.execute_sql("SELECT * FROM data_types")
qual_name = sql_client.make_qualified_table_name
db_rows = sql_client.execute_sql(f"SELECT * FROM {qual_name('data_types')}")
assert len(db_rows) == 10
db_row = list(db_rows[0])
# parquet is not really good at inserting json, best we get are strings in JSON columns
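The TIME exclusion above mirrors the `LoadJobTerminalException` raised in `SynapseCopyFileLoadJob`: staged parquet loads cannot carry `datetime.time` columns. A hedged workaround sketch, with hypothetical resource and column names, converts such values to ISO strings before extraction, as the exception message suggests:

```python
import datetime as dt

import dlt

@dlt.resource(name="events")  # hypothetical resource
def events():
    rows = [{"id": 1, "start_time": dt.time(9, 30)}]
    for row in rows:
        # Convert datetime.time to str so the staged parquet file has no TIME column.
        row["start_time"] = row["start_time"].isoformat()  # "09:30:00"
        yield row
```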
17 changes: 17 additions & 0 deletions tests/load/utils.py
@@ -95,6 +95,7 @@ class DestinationTestConfiguration:
bucket_url: Optional[str] = None
stage_name: Optional[str] = None
staging_iam_role: Optional[str] = None
staging_use_msi: bool = False
extra_info: Optional[str] = None
supports_merge: bool = True # TODO: take it from client base class
force_iceberg: bool = False
@@ -118,6 +119,7 @@ def setup(self) -> None:
os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = self.bucket_url or ""
os.environ["DESTINATION__STAGE_NAME"] = self.stage_name or ""
os.environ["DESTINATION__STAGING_IAM_ROLE"] = self.staging_iam_role or ""
os.environ["DESTINATION__STAGING_USE_MSI"] = str(self.staging_use_msi) or ""
os.environ["DESTINATION__FORCE_ICEBERG"] = str(self.force_iceberg) or ""

"""For the filesystem destinations we disable compression to make analyzing the result easier"""
@@ -254,6 +256,21 @@ def destinations_configs(
bucket_url=AZ_BUCKET,
extra_info="az-authorization",
),
DestinationTestConfiguration(
destination="synapse",
staging="filesystem",
file_format="parquet",
bucket_url=AZ_BUCKET,
extra_info="az-authorization",
),
DestinationTestConfiguration(
destination="synapse",
staging="filesystem",
file_format="parquet",
bucket_url=AZ_BUCKET,
staging_use_msi=True,
extra_info="az-managed-identity",
),
]

if all_staging_configs:
