[WIP] Synapse Destination #677

Status: Closed · wants to merge 13 commits
19 changes: 7 additions & 12 deletions .github/workflows/test_destination_synapse.yml
@@ -1,12 +1,17 @@


name: test synapse

on:
  pull_request:
    branches:
      - master
      - devel

  workflow_dispatch:
  push:
    branches:
      - master
      - devel
      - synapse

env:
  DESTINATION__SYNAPSE__CREDENTIALS: ${{ secrets.SYNAPSE_CREDENTIALS }}
@@ -20,16 +25,6 @@ env:

jobs:

  build:
    runs-on: ubuntu-latest

    steps:
      - name: Check source branch name
        run: |
          if [[ "${{ github.head_ref }}" != "synapse" ]]; then
            exit 1
          fi

  run_loader:
    name: Tests Synapse loader
    strategy:
27 changes: 27 additions & 0 deletions dlt/common/data_writers/escape.py
@@ -88,6 +88,33 @@ def escape_mssql_literal(v: Any) -> Any:
        return str(int(v))
    return str(v)

def escape_synapse_literal(v: Any) -> Any:
    # TODO: Add handling for SQL statements inside of INSERT queries (example: DROP SCHEMA Public --)
    if isinstance(v, str):
        return _escape_extended(v, prefix="N'", escape_dict=SYNAPSE_ESCAPE_DICT, escape_re=SYNAPSE_ESCAPE_RE)
    if isinstance(v, (datetime, date, time)):
        return f"'{v.isoformat()}'"
    if isinstance(v, (list, dict)):
        return _escape_extended(json.dumps(v), prefix="N'", escape_dict=SYNAPSE_ESCAPE_DICT, escape_re=SYNAPSE_ESCAPE_RE)
    if isinstance(v, bytes):
        # Updated to hex: Azure Synapse doesn't have XML and base64Binary
        hex_string = v.hex()
        return f"0x{hex_string}"
    if isinstance(v, bool):
        return str(int(v))
    return str(v)

SYNAPSE_ESCAPE_DICT = {
    "'": "''",
    '\n': '\n',
    '\r': '\r',
    '\t': '\t',
    '--': '- -',
    ';': '\\;',
}

SYNAPSE_ESCAPE_RE = _make_sql_escape_re(SYNAPSE_ESCAPE_DICT)


def escape_redshift_identifier(v: str) -> str:
return '"' + v.replace('"', '""').replace("\\", "\\\\") + '"'
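For reference, a rough sketch of how the new literal escaping behaves once this branch is installed (the input values below are illustrative; expected outputs follow from the function above):

```python
from datetime import datetime

from dlt.common.data_writers.escape import escape_synapse_literal

# strings become N'...' literals with quotes doubled and "--" / ";" neutralized
print(escape_synapse_literal("O'Brien; -- drop me"))

# datetimes are rendered as quoted ISO-8601 strings
print(escape_synapse_literal(datetime(2023, 9, 1, 12, 30)))  # '2023-09-01T12:30:00'

# bytes are emitted as a hex literal because Synapse has no XML/base64 helpers
print(escape_synapse_literal(b"\x01\x02"))  # 0x0102

# booleans collapse to 0/1, everything else falls through to str()
print(escape_synapse_literal(True))  # 1
```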
58 changes: 44 additions & 14 deletions dlt/destinations/insert_job_client.py
@@ -10,6 +10,7 @@
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.job_impl import EmptyLoadJob
from dlt.destinations.job_client_impl import SqlJobClientWithStaging
from tests.utils import ACTIVE_DESTINATIONS
Collaborator review comment: you can't import test code into production code...



class InsertValuesLoadJob(LoadJob, FollowupJob):
@@ -58,20 +59,49 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[str]]:
until_nl = until_nl[:-1] + ";"

if max_rows is not None:
    # mssql has a limit of 1000 rows per INSERT, so we need to split into separate statements
    values_rows = content.splitlines(keepends=True)
    len_rows = len(values_rows)
    processed = 0
    # Chunk by max_rows - 1 for simplicity because one more row may be added
    for chunk in chunks(values_rows, max_rows - 1):
        processed += len(chunk)
        insert_sql.extend([header.format(qualified_table_name), values_mark])
        if processed == len_rows:
            # On the last chunk we need to add the extra row read
            insert_sql.append("".join(chunk) + until_nl)
        else:
            # Replace the , with ;
            insert_sql.append("".join(chunk).strip()[:-1] + ";\n")
if "synapse" in ACTIVE_DESTINATIONS:
Collaborator review comment: if you need different behavior here for synapse, you need to extend or replace this load job for synapse, or somehow manage this with the capabilities object. This code depends on a variable that is only available in the test environment and will not work in production.

        # This part rewrites the multiple VALUES rows of an INSERT statement into
        # individual SELECTs combined with UNION ALL
        # https://stackoverflow.com/questions/36141006/how-to-insert-multiple-rows-into-sql-server-parallel-data-warehouse-table

        values_rows = content.splitlines(keepends=True)
        sql_rows = []

        for row in values_rows:
            # Remove potential leading and trailing characters such as brackets, commas, and newlines
            row = row.strip(",\n() ;")

            # Separate out the individual values within the row
            columns = row.split(",")

            # Ensure there are no stray parentheses in columns
            columns = [col.strip("()") for col in columns]

            # Create the SELECT for this particular row, keeping the values as they are
            sql_rows.append(f"SELECT {', '.join(columns)}")

        individual_insert = " UNION ALL ".join(sql_rows)

        # If individual_insert ends with a semicolon, remove it
        if individual_insert.endswith(";"):
            individual_insert = individual_insert[:-1]

        insert_sql.extend([header.format(qualified_table_name), individual_insert + ";"])
    else:
        # mssql has a limit of 1000 rows per INSERT, so we need to split into separate statements
        values_rows = content.splitlines(keepends=True)
        len_rows = len(values_rows)
        processed = 0
        # Chunk by max_rows - 1 for simplicity because one more row may be added
        for chunk in chunks(values_rows, max_rows - 1):
            processed += len(chunk)
            insert_sql.extend([header.format(qualified_table_name), values_mark])
            if processed == len_rows:
                # On the last chunk we need to add the extra row read
                insert_sql.append("".join(chunk) + until_nl)
            else:
                # Replace the , with ;
                insert_sql.append("".join(chunk).strip()[:-1] + ";\n")
else:
    # otherwise write all content in a single INSERT INTO
    insert_sql.extend([header.format(qualified_table_name), values_mark, content])
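To make the review comment concrete, here is a self-contained sketch of the VALUES-to-SELECT ... UNION ALL rewrite the synapse branch performs, with the test-only ACTIVE_DESTINATIONS check removed; the table name and rows are made up for illustration:

```python
from typing import List


def values_to_union_all(header: str, qualified_table_name: str, values_rows: List[str]) -> str:
    # Rewrite INSERT ... VALUES rows into SELECT ... UNION ALL form, which
    # SQL DW / Synapse dedicated pools accept for multi-row inserts.
    sql_rows = []
    for row in values_rows:
        # strip wrapping parentheses, trailing commas/semicolons and newlines
        row = row.strip(",\n() ;")
        # NOTE: naive splitting on "," breaks for string values that contain commas
        columns = [col.strip("()") for col in row.split(",")]
        sql_rows.append(f"SELECT {', '.join(columns)}")
    return header.format(qualified_table_name) + " UNION ALL ".join(sql_rows) + ";"


rows = ["(1, N'alpha'),\n", "(2, N'beta'),\n", "(3, N'gamma');\n"]
print(values_to_union_all("INSERT INTO {}(id, name)\n", "my_dataset.my_table", rows))
# INSERT INTO my_dataset.my_table(id, name)
# SELECT 1,  N'alpha' UNION ALL SELECT 2,  N'beta' UNION ALL SELECT 3,  N'gamma';
```

In line with the review comment, this logic would belong in a Synapse-specific load job rather than behind a test-environment flag.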
12 changes: 9 additions & 3 deletions dlt/destinations/job_client_impl.py
@@ -12,7 +12,7 @@

from dlt.common import json, pendulum, logger
from dlt.common.data_types import TDataType
from dlt.common.schema.typing import COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, TWriteDisposition
from dlt.common.schema.typing import COLUMN_HINTS, TColumnSchemaBase, TTableSchema, TWriteDisposition
from dlt.common.storages import FileStorage
from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables
from dlt.common.destination.reference import StateInfo, StorageSchemaInfo,WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, DestinationClientDwhWithStagingConfiguration, NewLoadJob, WithStagingDataset, TLoadJobState, LoadJob, JobClientBase, FollowupJob, CredentialsConfiguration
@@ -242,13 +242,19 @@ def _null_to_bool(v: str) -> bool:
schema_c: TColumnSchemaBase = {
    "name": c[0],
    "nullable": _null_to_bool(c[2]),
    **self._from_db_type(c[1], numeric_precision, numeric_scale),  # type: ignore[misc]
    "data_type": self._from_db_type(c[1], numeric_precision, numeric_scale),
}
schema_table[c[0]] = schema_c  # type: ignore
return True, schema_table

@classmethod
@abstractmethod
def _from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
def _to_db_type(cls, schema_type: TDataType) -> str:
    pass

@classmethod
@abstractmethod
def _from_db_type(cls, db_type: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
    pass

def get_stored_schema(self) -> StorageSchemaInfo:
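For context on the two abstract hooks introduced above, a hypothetical sketch of what a concrete Synapse client might plug in; the type names and mapping tables here are assumptions for illustration, not taken from this PR:

```python
from typing import Optional

from dlt.common.data_types import TDataType

# hypothetical lookup tables; a real client would cover all dlt data types
_TO_SYNAPSE = {"text": "nvarchar(4000)", "bigint": "bigint", "bool": "bit", "timestamp": "datetimeoffset"}
_FROM_SYNAPSE = {"nvarchar": "text", "bigint": "bigint", "bit": "bool", "datetimeoffset": "timestamp"}


class SynapseTypeMapping:
    @classmethod
    def _to_db_type(cls, schema_type: TDataType) -> str:
        return _TO_SYNAPSE[schema_type]

    @classmethod
    def _from_db_type(cls, db_type: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
        if db_type == "decimal" and (precision, scale) == (38, 0):
            # hypothetical: treat max-precision decimals as wei, mirroring other destinations
            return "wei"
        return _FROM_SYNAPSE.get(db_type, "text")  # type: ignore[return-value]
```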
5 changes: 5 additions & 0 deletions dlt/destinations/synapse/README.md
@@ -0,0 +1,5 @@
# loader account setup

1. Create new database `CREATE DATABASE dlt_data`
2. Create new user, set password `CREATE USER loader WITH PASSWORD 'loader';`
3. Set as database owner (we could set lower permission) `ALTER DATABASE dlt_data OWNER TO loader`
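The statements above use the Postgres-style wording shared by the other destination READMEs; on Synapse the equivalent setup is usually scripted in T-SQL. A rough sketch with pyodbc, assuming a dedicated SQL pool, an admin connection, and the example names above (server, driver and password are placeholders):

```python
import pyodbc

ADMIN_DSN = (
    "DRIVER={ODBC Driver 18 for SQL Server};"
    "SERVER=<your-server>.sql.azuresynapse.net;DATABASE=master;UID=<admin>;PWD=<password>"
)

# CREATE DATABASE / CREATE LOGIN cannot run inside a transaction, hence autocommit
conn = pyodbc.connect(ADMIN_DSN, autocommit=True)
conn.execute("CREATE DATABASE dlt_data")
conn.execute("CREATE LOGIN loader WITH PASSWORD = 'loader'")
conn.close()

# the user and its role membership are created inside the new database
conn = pyodbc.connect(ADMIN_DSN.replace("DATABASE=master", "DATABASE=dlt_data"), autocommit=True)
conn.execute("CREATE USER loader FOR LOGIN loader")
conn.execute("ALTER ROLE db_owner ADD MEMBER loader")  # a narrower role would also work
conn.close()
```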
59 changes: 59 additions & 0 deletions dlt/destinations/synapse/__init__.py
@@ -0,0 +1,59 @@
from typing import Type

from dlt.common.schema.schema import Schema
from dlt.common.configuration import with_config, known_sections
from dlt.common.configuration.accessors import config
from dlt.common.data_writers.escape import escape_postgres_identifier, escape_synapse_literal
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration
from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE
from dlt.common.wei import EVM_DECIMAL_PRECISION

from dlt.destinations.synapse.configuration import SynapseClientConfiguration

# TODO: Organize imports and capabilities

@with_config(spec=SynapseClientConfiguration, sections=(known_sections.DESTINATION, "synapse",))
def _configure(config: SynapseClientConfiguration = config.value) -> SynapseClientConfiguration:
    return config


def capabilities() -> DestinationCapabilitiesContext:
    # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-service-capacity-limits
    caps = DestinationCapabilitiesContext()
    caps.preferred_loader_file_format = "insert_values"
    caps.supported_loader_file_formats = ["insert_values"]
    caps.preferred_staging_file_format = None
    caps.supported_staging_file_formats = []
    # TODO: Add a blob_storage preferred_staging_file_format capability for azure synapse
    # caps.preferred_staging_file_format = "csv"
    # caps.supported_staging_file_formats = ["csv", "parquet"]
    caps.escape_identifier = escape_postgres_identifier
    caps.escape_literal = escape_synapse_literal
    caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE)
    caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0)
    caps.max_identifier_length = 128
    caps.max_column_identifier_length = 128
    caps.max_query_length = 4 * 1024 * 64 * 1024
    caps.is_max_query_length_in_bytes = True
    caps.max_text_data_type_length = 4000
    caps.is_max_text_data_type_length_in_bytes = False
    caps.supports_ddl_transactions = False
    caps.max_rows_per_insert = 1000

    # TODO: Add and test supports_truncate_command capability in azure synapse (TRUNCATE works in some cases for synapse)

    return caps


def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase:
    # import client when creating instance so capabilities and config specs can be accessed without dependencies installed
    from dlt.destinations.synapse.synapse import SynapseClient

    return SynapseClient(schema, _configure(initial_config))  # type: ignore[arg-type]


def spec() -> Type[DestinationClientConfiguration]:
    return SynapseClientConfiguration

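A small sketch of how these capabilities would be consumed once the module is importable as dlt.destinations.synapse; the assertions just restate the values set above:

```python
from dlt.destinations import synapse

caps = synapse.capabilities()

# only INSERT ... VALUES files are produced for this destination, no staging formats yet
assert caps.preferred_loader_file_format == "insert_values"
assert caps.supported_staging_file_formats == []

# the loader re-chunks statements so no single INSERT exceeds 1000 rows
assert caps.max_rows_per_insert == 1000

# identifiers longer than 128 characters would be rejected by Synapse
assert caps.max_identifier_length == 128
```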
88 changes: 88 additions & 0 deletions dlt/destinations/synapse/configuration.py
@@ -0,0 +1,88 @@
from typing import Final, ClassVar, Any, List, Optional
from sqlalchemy.engine import URL

from dlt.common.configuration import configspec
from dlt.common.configuration.specs import ConnectionStringCredentials
from dlt.common.utils import digest128
from dlt.common.typing import TSecretValue
from dlt.common.exceptions import SystemConfigurationException

from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration


@configspec
class SynapseCredentials(ConnectionStringCredentials):
    drivername: Final[str] = "sqlserver"  # type: ignore
    password: TSecretValue
    host: str
    port: int = 1433
    connect_timeout: int = 15
    odbc_driver: str = None

    __config_gen_annotations__: ClassVar[List[str]] = ["port", "connect_timeout"]

    def parse_native_representation(self, native_value: Any) -> None:
        # TODO: Support ODBC connection string or sqlalchemy URL
        super().parse_native_representation(native_value)
        self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout))
        if not self.is_partial():
            self.resolve()

    def on_resolved(self) -> None:
        self.database = self.database.lower()

    def to_url(self) -> URL:
        url = super().to_url()
        url.update_query_pairs([("connect_timeout", str(self.connect_timeout))])
        return url

    def on_partial(self) -> None:
        self.odbc_driver = self._get_odbc_driver()
        if not self.is_partial():
            self.resolve()

    def _get_odbc_driver(self) -> str:
        if self.odbc_driver:
            return self.odbc_driver
        # Pick a default driver if available
        supported_drivers = ['ODBC Driver 18 for SQL Server', 'ODBC Driver 17 for SQL Server']
        import pyodbc
        available_drivers = pyodbc.drivers()
        for driver in supported_drivers:
            if driver in available_drivers:
                return driver
        docs_url = "https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16"
        raise SystemConfigurationException(
            f"No supported ODBC driver found for MS SQL Server. "
            f"See {docs_url} for information on how to install the '{supported_drivers[0]}' on your platform."
        )

    def to_odbc_dsn(self) -> str:
        params = {
            "DRIVER": self.odbc_driver,
            "SERVER": self.host,
            "PORT": self.port,
            "DATABASE": self.database,
            "UID": self.username,
            "PWD": self.password,
            "LongAsMax": "yes",
            "MARS_Connection": "yes"
        }
        if self.query:
            params.update(self.query)
        return ";".join([f"{k}={v}" for k, v in params.items()])



@configspec
class SynapseClientConfiguration(DestinationClientDwhWithStagingConfiguration):
    destination_name: Final[str] = "synapse"  # type: ignore
    credentials: SynapseCredentials

    create_indexes: bool = False

    def fingerprint(self) -> str:
        """Returns a fingerprint of host part of a connection string"""
        if self.credentials and self.credentials.host:
            return digest128(self.credentials.host)
        return ""