synapse and mssql bugfixes and improvements (#1174)
* add chunk separation handling for select_union writer type

* include synapse in insert job client tests and make them pass

* remove obsolete test

* set max_rows_per_insert to prevent error on larger queries in synapse

* remove pipeline dependency

* make imports conditional on destination type

* include synapse in insert job client tests and make them pass

* include mssql in insert job client tests and make them pass

* make psycopg2 import conditional on destination type

---------

Co-authored-by: Jorrit Sandbrink <[email protected]>
jorritsandbrink and Jorrit Sandbrink authored Apr 7, 2024
1 parent b67eda4 commit d3ecc9e
Showing 6 changed files with 141 additions and 98 deletions.
30 changes: 14 additions & 16 deletions dlt/common/data_writers/writers.py
@@ -166,6 +166,11 @@ def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> None:
super().__init__(f, caps)
self._chunks_written = 0
self._headers_lookup: Dict[str, int] = None
self.writer_type = caps.insert_values_writer_type
if self.writer_type == "default":
self.pre, self.post, self.sep = ("(", ")", ",\n")
elif self.writer_type == "select_union":
self.pre, self.post, self.sep = ("SELECT ", "", " UNION ALL\n")

def write_header(self, columns_schema: TTableSchemaColumns) -> None:
assert self._chunks_written == 0
@@ -176,10 +181,9 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
# do not write INSERT INTO command, this must be added together with table name by the loader
self._f.write("INSERT INTO {}(")
self._f.write(",".join(map(self._caps.escape_identifier, headers)))
if self._caps.insert_values_writer_type == "default":
self._f.write(")\nVALUES\n")
elif self._caps.insert_values_writer_type == "select_union":
self._f.write(")\n")
self._f.write(")\n")
if self.writer_type == "default":
self._f.write("VALUES\n")

def write_data(self, rows: Sequence[Any]) -> None:
super().write_data(rows)
@@ -188,21 +192,15 @@ def write_row(row: StrAny, last_row: bool = False) -> None:
output = ["NULL"] * len(self._headers_lookup)
for n, v in row.items():
output[self._headers_lookup[n]] = self._caps.escape_literal(v)
if self._caps.insert_values_writer_type == "default":
self._f.write("(")
self._f.write(",".join(output))
self._f.write(")")
if not last_row:
self._f.write(",\n")
elif self._caps.insert_values_writer_type == "select_union":
self._f.write("SELECT ")
self._f.write(",".join(output))
if not last_row:
self._f.write("\nUNION ALL\n")
self._f.write(self.pre)
self._f.write(",".join(output))
self._f.write(self.post)
if not last_row:
self._f.write(self.sep)

# if next chunk add separator
if self._chunks_written > 0:
self._f.write(",\n")
self._f.write(self.sep)

# write rows
for row in rows[:-1]:
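To make the writer refactor concrete, here is a minimal illustrative sketch (not repository code) of what the two `insert_values_writer_type` modes emit for a two-row chunk on a table with columns `a` and `b`. Routing both the row separator and the between-chunk separator through `self.sep` is what the "chunk separation handling for select_union" bullet refers to: chunks are now joined with ` UNION ALL` instead of a hard-coded `",\n"`.

```python
# Illustrative only: mimics the InsertValuesWriter output for both writer types.
# The "{}" placeholder in the header is later formatted with the qualified table name by the loader.
def render_insert(writer_type: str, rows: list[tuple[str, str]]) -> str:
    if writer_type == "default":
        pre, post, sep = "(", ")", ",\n"
        header = 'INSERT INTO {}("a","b")\nVALUES\n'
    else:  # "select_union", used for Synapse, which does not accept multi-row VALUES clauses
        pre, post, sep = "SELECT ", "", " UNION ALL\n"
        header = 'INSERT INTO {}("a","b")\n'
    return header + sep.join(f"{pre}{a},{b}{post}" for a, b in rows) + ";"

print(render_insert("default", [("1", "'x'"), ("2", "'y'")]))
# INSERT INTO {}("a","b")
# VALUES
# (1,'x'),
# (2,'y');

print(render_insert("select_union", [("1", "'x'"), ("2", "'y'")]))
# INSERT INTO {}("a","b")
# SELECT 1,'x' UNION ALL
# SELECT 2,'y';
```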
13 changes: 11 additions & 2 deletions dlt/destinations/impl/mssql/sql_client.py
@@ -83,8 +83,17 @@ def commit_transaction(self) -> None:

@raise_database_error
def rollback_transaction(self) -> None:
self._conn.rollback()
self._conn.autocommit = True
try:
self._conn.rollback()
except pyodbc.ProgrammingError as ex:
if (
ex.args[0] == "42000" and "(111214)" in ex.args[1]
): # "no corresponding transaction found"
pass # there was nothing to rollback, we silently ignore the error
else:
raise
finally:
self._conn.autocommit = True

@property
def native_connection(self) -> pyodbc.Connection:
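The rollback change can be read in isolation as a small pyodbc pattern; here is a self-contained sketch (connection setup omitted) that mirrors it: only the "no corresponding transaction found" error (SQLSTATE 42000, native code 111214) is swallowed, any other failure still propagates, and autocommit is restored either way.

```python
import pyodbc  # assumes an ODBC driver for SQL Server is installed

def safe_rollback(conn: pyodbc.Connection) -> None:
    # MSSQL/Synapse raise ProgrammingError 111214 when rollback() is called with no
    # open transaction; that specific case becomes a no-op, everything else re-raises.
    try:
        conn.rollback()
    except pyodbc.ProgrammingError as ex:
        if ex.args[0] == "42000" and "(111214)" in ex.args[1]:
            pass  # nothing to roll back
        else:
            raise
    finally:
        conn.autocommit = True  # always leave the connection in autocommit mode
```

With this in place, a rollback issued when no transaction is open simply re-enables autocommit instead of raising.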
6 changes: 6 additions & 0 deletions dlt/destinations/impl/synapse/__init__.py
@@ -41,6 +41,12 @@ def capabilities() -> DestinationCapabilitiesContext:
caps.supports_transactions = True
caps.supports_ddl_transactions = False

# Synapse throws "Some part of your SQL statement is nested too deeply. Rewrite the query or break it up into smaller queries."
# if the number of records exceeds a certain threshold. The exact limit does not appear to be deterministic:
# in tests, I've seen a query with 12,230 records run successfully on one run, but fail on a subsequent run, while the query remained exactly the same.
# 10,000 records is a "safe" amount that always seems to work.
caps.max_rows_per_insert = 10000

# datetimeoffset can store 7 digits for fractional seconds
# https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16
caps.timestamp_precision = 7
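For illustration (not repository code): this capability is consumed by the insert job client, which splits the value rows of a single file into several INSERT statements of at most `max_rows_per_insert` rows, so a 12,230-row file becomes two statements rather than one query that can hit the nesting limit.

```python
# Simplified sketch; the real loader also handles the final partially-read row separately.
MAX_ROWS_PER_INSERT = 10_000

def statement_sizes(n_rows: int, max_rows: int = MAX_ROWS_PER_INSERT) -> list[int]:
    """Row counts of the individual INSERT statements emitted for one file."""
    return [min(max_rows, n_rows - start) for start in range(0, n_rows, max_rows)]

assert statement_sizes(12_230) == [10_000, 2_230]
assert statement_sizes(10_000) == [10_000]
```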
19 changes: 9 additions & 10 deletions dlt/destinations/impl/synapse/synapse.py
@@ -10,19 +10,20 @@
from dlt.common.destination.reference import (
SupportsStagingDestination,
NewLoadJob,
CredentialsConfiguration,
)

from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint
from dlt.common.schema.utils import table_schema_has_type, get_inherited_table_hint
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.schema.utils import (
table_schema_has_type,
get_inherited_table_hint,
is_complete_column,
)

from dlt.common.configuration.specs import AzureCredentialsWithoutDefaults

from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.insert_job_client import InsertValuesJobClient
from dlt.destinations.job_client_impl import SqlJobClientBase, LoadJob, CopyRemoteFileLoadJob
from dlt.destinations.exceptions import LoadJobTerminalException

@@ -127,7 +128,7 @@ def _create_replace_followup_jobs(
self, table_chain: Sequence[TTableSchema]
) -> List[NewLoadJob]:
if self.config.replace_strategy == "staging-optimized":
return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)]
return [SynapseStagingCopyJob.from_table_chain(table_chain, self)] # type: ignore[arg-type]
return super()._create_replace_followup_jobs(table_chain)

def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema:
@@ -178,10 +179,11 @@ class SynapseStagingCopyJob(SqlStagingCopyJob):
def generate_sql(
cls,
table_chain: Sequence[TTableSchema],
sql_client: SqlClientBase[Any],
job_client: SynapseClient, # type: ignore[override]
params: Optional[SqlJobParams] = None,
) -> List[str]:
sql: List[str] = []
sql_client = job_client.sql_client
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
staging_table_name = sql_client.make_qualified_table_name(table["name"])
@@ -194,11 +196,8 @@ def generate_sql(
f" {staging_table_name};"
)
# recreate staging table
job_client = current.pipeline().destination_client() # type: ignore[operator]
with job_client.with_staging_dataset():
# get table columns from schema
columns = [c for c in job_client.schema.get_table_columns(table["name"]).values()]
# generate CREATE TABLE statement
columns = list(filter(is_complete_column, table["columns"].values()))
create_table_stmt = job_client._get_table_update_sql(
table["name"], columns, generate_alter=False
)
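The staging copy job now receives the job client itself, so it can derive the column list from the table chain instead of calling `current.pipeline()` (the "remove pipeline dependency" bullet). A hedged sketch of that column filter, assuming `is_complete_column` treats a column as complete only when it carries at least a name and a data type; the sample table below is made up:

```python
from dlt.common.schema.utils import is_complete_column

# Hypothetical table-chain entry; "payload" has no data_type and is expected to be dropped.
table = {
    "name": "events",
    "columns": {
        "id": {"name": "id", "data_type": "bigint", "nullable": False},
        "payload": {"name": "payload"},
    },
}

columns = list(filter(is_complete_column, table["columns"].values()))
assert [c["name"] for c in columns] == ["id"]
```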
25 changes: 12 additions & 13 deletions dlt/destinations/insert_job_client.py
@@ -36,10 +36,14 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[str]]:
# the procedure below will split the inserts into max_query_length // 2 packs
with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f:
header = f.readline()
if self._sql_client.capabilities.insert_values_writer_type == "default":
writer_type = self._sql_client.capabilities.insert_values_writer_type
if writer_type == "default":
sep = ","
# properly formatted file has a values marker at the beginning
values_mark = f.readline()
assert values_mark == "VALUES\n"
elif writer_type == "select_union":
sep = " UNION ALL"

max_rows = self._sql_client.capabilities.max_rows_per_insert

@@ -57,9 +61,7 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[str]]:
# if there was anything left, until_nl contains the last line
is_eof = len(until_nl) == 0 or until_nl[-1] == ";"
if not is_eof:
# print(f'replace the "," with " {until_nl} {len(insert_sql)}')
until_nl = until_nl[:-1] + ";"

until_nl = until_nl[: -len(sep)] + ";" # replace the separator with ";"
if max_rows is not None:
# mssql has a limit of 1000 rows per INSERT, so we need to split into separate statements
values_rows = content.splitlines(keepends=True)
@@ -69,25 +71,22 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[str]]:
for chunk in chunks(values_rows, max_rows - 1):
processed += len(chunk)
insert_sql.append(header.format(qualified_table_name))
if self._sql_client.capabilities.insert_values_writer_type == "default":
if writer_type == "default":
insert_sql.append(values_mark)
if processed == len_rows:
# On the last chunk we need to add the extra row read
insert_sql.append("".join(chunk) + until_nl)
else:
# Replace the , with ;
insert_sql.append("".join(chunk).strip()[:-1] + ";\n")
insert_sql.append("".join(chunk).strip()[: -len(sep)] + ";\n")
else:
# otherwise write all content in a single INSERT INTO
if self._sql_client.capabilities.insert_values_writer_type == "default":
if writer_type == "default":
insert_sql.extend(
[header.format(qualified_table_name), values_mark, content]
[header.format(qualified_table_name), values_mark, content + until_nl]
)
elif self._sql_client.capabilities.insert_values_writer_type == "select_union":
insert_sql.extend([header.format(qualified_table_name), content])

if until_nl:
insert_sql.append(until_nl)
elif writer_type == "select_union":
insert_sql.extend([header.format(qualified_table_name), content + until_nl])

# actually this may be empty if we were able to read a full file into content
if not is_eof:
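A tiny illustrative sketch of the separator handling introduced here: the dangling row separator at the end of a partially read fragment is replaced with the statement terminator, and that separator now depends on the writer type.

```python
# Illustrative only; assumes the fragment ends exactly with the row separator.
def terminate_fragment(fragment: str, writer_type: str) -> str:
    sep = "," if writer_type == "default" else " UNION ALL"
    return fragment[: -len(sep)] + ";"

assert terminate_fragment("(2,'y'),", "default") == "(2,'y');"
assert terminate_fragment("SELECT 2,'y' UNION ALL", "select_union") == "SELECT 2,'y';"
```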
