diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py
index 2aadb010e0..4a4029c6f3 100644
--- a/dlt/common/data_writers/writers.py
+++ b/dlt/common/data_writers/writers.py
@@ -166,6 +166,11 @@ def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> N
         super().__init__(f, caps)
         self._chunks_written = 0
         self._headers_lookup: Dict[str, int] = None
+        self.writer_type = caps.insert_values_writer_type
+        if self.writer_type == "default":
+            self.pre, self.post, self.sep = ("(", ")", ",\n")
+        elif self.writer_type == "select_union":
+            self.pre, self.post, self.sep = ("SELECT ", "", " UNION ALL\n")

     def write_header(self, columns_schema: TTableSchemaColumns) -> None:
         assert self._chunks_written == 0
@@ -176,10 +181,9 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
         # do not write INSERT INTO command, this must be added together with table name by the loader
         self._f.write("INSERT INTO {}(")
         self._f.write(",".join(map(self._caps.escape_identifier, headers)))
-        if self._caps.insert_values_writer_type == "default":
-            self._f.write(")\nVALUES\n")
-        elif self._caps.insert_values_writer_type == "select_union":
-            self._f.write(")\n")
+        self._f.write(")\n")
+        if self.writer_type == "default":
+            self._f.write("VALUES\n")

     def write_data(self, rows: Sequence[Any]) -> None:
         super().write_data(rows)
@@ -188,21 +192,15 @@ def write_row(row: StrAny, last_row: bool = False) -> None:
             output = ["NULL"] * len(self._headers_lookup)
             for n, v in row.items():
                 output[self._headers_lookup[n]] = self._caps.escape_literal(v)
-            if self._caps.insert_values_writer_type == "default":
-                self._f.write("(")
-                self._f.write(",".join(output))
-                self._f.write(")")
-                if not last_row:
-                    self._f.write(",\n")
-            elif self._caps.insert_values_writer_type == "select_union":
-                self._f.write("SELECT ")
-                self._f.write(",".join(output))
-                if not last_row:
-                    self._f.write("\nUNION ALL\n")
+            self._f.write(self.pre)
+            self._f.write(",".join(output))
+            self._f.write(self.post)
+            if not last_row:
+                self._f.write(self.sep)

         # if next chunk add separator
         if self._chunks_written > 0:
-            self._f.write(",\n")
+            self._f.write(self.sep)

         # write rows
         for row in rows[:-1]:
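Note (illustration, not part of the patch): the writer change above collapses the two per-row branches into a single pre/post/sep triple chosen once per writer type. A rough, self-contained sketch of the file layout the two writer types produce (plain Python with made-up column names and values, not dlt code):

    rows = [["1", "'a'"], ["2", "'b'"]]  # already-escaped literals

    def render(writer_type: str) -> str:
        # mirrors the pre/post/sep triples introduced above
        header = "INSERT INTO {}(col1,col2)\n"
        if writer_type == "default":
            header += "VALUES\n"
            pre, post, sep = "(", ")", ",\n"
        else:  # "select_union"
            pre, post, sep = "SELECT ", "", " UNION ALL\n"
        body = sep.join(pre + ",".join(r) + post for r in rows)
        return header + body + ";"

    print(render("default"))
    # INSERT INTO {}(col1,col2)
    # VALUES
    # (1,'a'),
    # (2,'b');
    print(render("select_union"))
    # INSERT INTO {}(col1,col2)
    # SELECT 1,'a' UNION ALL
    # SELECT 2,'b';

The "{}" placeholder is left in the file on purpose; per the comment in write_header, the loader formats the qualified table name into it later.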
diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py
index cd1699adea..db043bae25 100644
--- a/dlt/destinations/impl/mssql/sql_client.py
+++ b/dlt/destinations/impl/mssql/sql_client.py
@@ -83,8 +83,17 @@ def commit_transaction(self) -> None:

     @raise_database_error
     def rollback_transaction(self) -> None:
-        self._conn.rollback()
-        self._conn.autocommit = True
+        try:
+            self._conn.rollback()
+        except pyodbc.ProgrammingError as ex:
+            if (
+                ex.args[0] == "42000" and "(111214)" in ex.args[1]
+            ):  # "no corresponding transaction found"
+                pass  # there was nothing to roll back, so we silently ignore the error
+            else:
+                raise
+        finally:
+            self._conn.autocommit = True

     @property
     def native_connection(self) -> pyodbc.Connection:
diff --git a/dlt/destinations/impl/synapse/__init__.py b/dlt/destinations/impl/synapse/__init__.py
index 53dbabc090..f6ad7369c1 100644
--- a/dlt/destinations/impl/synapse/__init__.py
+++ b/dlt/destinations/impl/synapse/__init__.py
@@ -41,6 +41,12 @@ def capabilities() -> DestinationCapabilitiesContext:
     caps.supports_transactions = True
     caps.supports_ddl_transactions = False

+    # Synapse throws "Some part of your SQL statement is nested too deeply. Rewrite the query or break it up into smaller queries."
+    # if the number of records in a single INSERT exceeds a certain threshold. The exact threshold does not appear to be deterministic:
+    # in tests, a query with 12230 records ran successfully on one run but failed on a subsequent run, even though the query was exactly the same.
+    # 10,000 records is a "safe" amount that has always worked in testing.
+    caps.max_rows_per_insert = 10000
+
     # datetimeoffset can store 7 digits for fractional seconds
     # https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16
     caps.timestamp_precision = 7
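Note (illustration, not part of the patch): rollback_transaction() now tolerates pyodbc's "no corresponding transaction found" error (SQLSTATE 42000, native error 111214) and always restores autocommit, and the Synapse capabilities gain max_rows_per_insert. A back-of-the-envelope sketch of what the 10,000-row cap means for a load file (plain Python, made-up row count; the actual splitting in insert_job_client.py keeps a row of headroom per chunk, so the real statement count can differ by one):

    import math

    max_rows_per_insert = 10000
    total_rows = 123_456
    statements = math.ceil(total_rows / max_rows_per_insert)
    print(statements)  # 13 smaller INSERT statements instead of one deeply nested query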
diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py
index 457e128ba0..60d792c854 100644
--- a/dlt/destinations/impl/synapse/synapse.py
+++ b/dlt/destinations/impl/synapse/synapse.py
@@ -10,19 +10,20 @@
 from dlt.common.destination.reference import (
     SupportsStagingDestination,
     NewLoadJob,
-    CredentialsConfiguration,
 )
 from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint
-from dlt.common.schema.utils import table_schema_has_type, get_inherited_table_hint
-from dlt.common.schema.typing import TTableSchemaColumns
+from dlt.common.schema.utils import (
+    table_schema_has_type,
+    get_inherited_table_hint,
+    is_complete_column,
+)
 from dlt.common.configuration.specs import AzureCredentialsWithoutDefaults

 from dlt.destinations.job_impl import NewReferenceJob
 from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
 from dlt.destinations.sql_client import SqlClientBase
-from dlt.destinations.insert_job_client import InsertValuesJobClient
 from dlt.destinations.job_client_impl import SqlJobClientBase, LoadJob, CopyRemoteFileLoadJob
 from dlt.destinations.exceptions import LoadJobTerminalException
@@ -127,7 +128,7 @@ def _create_replace_followup_jobs(
         self, table_chain: Sequence[TTableSchema]
     ) -> List[NewLoadJob]:
         if self.config.replace_strategy == "staging-optimized":
-            return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)]
+            return [SynapseStagingCopyJob.from_table_chain(table_chain, self)]  # type: ignore[arg-type]
         return super()._create_replace_followup_jobs(table_chain)

     def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema:
@@ -178,10 +179,11 @@ class SynapseStagingCopyJob(SqlStagingCopyJob):
     def generate_sql(
         cls,
         table_chain: Sequence[TTableSchema],
-        sql_client: SqlClientBase[Any],
+        job_client: SynapseClient,  # type: ignore[override]
         params: Optional[SqlJobParams] = None,
     ) -> List[str]:
         sql: List[str] = []
+        sql_client = job_client.sql_client
         for table in table_chain:
             with sql_client.with_staging_dataset(staging=True):
                 staging_table_name = sql_client.make_qualified_table_name(table["name"])
@@ -194,11 +196,8 @@ def generate_sql(
                 f" {staging_table_name};"
             )
             # recreate staging table
-            job_client = current.pipeline().destination_client()  # type: ignore[operator]
             with job_client.with_staging_dataset():
-                # get table columns from schema
-                columns = [c for c in job_client.schema.get_table_columns(table["name"]).values()]
-                # generate CREATE TABLE statement
+                columns = list(filter(is_complete_column, table["columns"].values()))
                 create_table_stmt = job_client._get_table_update_sql(
                     table["name"], columns, generate_alter=False
                 )
diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py
index 776176078e..6a7c56832c 100644
--- a/dlt/destinations/insert_job_client.py
+++ b/dlt/destinations/insert_job_client.py
@@ -36,10 +36,14 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st
         # the procedure below will split the inserts into max_query_length // 2 packs
         with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f:
             header = f.readline()
-            if self._sql_client.capabilities.insert_values_writer_type == "default":
+            writer_type = self._sql_client.capabilities.insert_values_writer_type
+            if writer_type == "default":
+                sep = ","
                 # properly formatted file has a values marker at the beginning
                 values_mark = f.readline()
                 assert values_mark == "VALUES\n"
+            elif writer_type == "select_union":
+                sep = " UNION ALL"

             max_rows = self._sql_client.capabilities.max_rows_per_insert

@@ -57,9 +61,7 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st
                 # if there was anything left, until_nl contains the last line
                 is_eof = len(until_nl) == 0 or until_nl[-1] == ";"
                 if not is_eof:
-                    # print(f'replace the "," with " {until_nl} {len(insert_sql)}')
-                    until_nl = until_nl[:-1] + ";"
-
+                    until_nl = until_nl[: -len(sep)] + ";"  # replace the separator with ";"
                 if max_rows is not None:
                     # mssql has a limit of 1000 rows per INSERT, so we need to split into separate statements
                     values_rows = content.splitlines(keepends=True)
@@ -69,25 +71,22 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st
                     for chunk in chunks(values_rows, max_rows - 1):
                         processed += len(chunk)
                         insert_sql.append(header.format(qualified_table_name))
-                        if self._sql_client.capabilities.insert_values_writer_type == "default":
+                        if writer_type == "default":
                             insert_sql.append(values_mark)
                         if processed == len_rows:
                             # On the last chunk we need to add the extra row read
                             insert_sql.append("".join(chunk) + until_nl)
                         else:
                             # Replace the , with ;
-                            insert_sql.append("".join(chunk).strip()[:-1] + ";\n")
+                            insert_sql.append("".join(chunk).strip()[: -len(sep)] + ";\n")
                 else:
                     # otherwise write all content in a single INSERT INTO
-                    if self._sql_client.capabilities.insert_values_writer_type == "default":
+                    if writer_type == "default":
                         insert_sql.extend(
-                            [header.format(qualified_table_name), values_mark, content]
+                            [header.format(qualified_table_name), values_mark, content + until_nl]
                         )
-                    elif self._sql_client.capabilities.insert_values_writer_type == "select_union":
-                        insert_sql.extend([header.format(qualified_table_name), content])
-
-                    if until_nl:
-                        insert_sql.append(until_nl)
+                    elif writer_type == "select_union":
+                        insert_sql.extend([header.format(qualified_table_name), content + until_nl])

                 # actually this may be empty if we were able to read a full file into content
                 if not is_eof:
diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py
index 1c79b733e5..e353ec34eb 100644
--- a/tests/load/test_insert_job_client.py
+++ b/tests/load/test_insert_job_client.py
@@ -7,18 +7,14 @@
 from dlt.common.storages import FileStorage
 from dlt.common.utils import uniq_id

-from dlt.destinations.exceptions import (
-    DatabaseTerminalException,
-    DatabaseTransientException,
-    DatabaseUndefinedRelation,
-)
+from dlt.destinations.exceptions import DatabaseTerminalException
 from dlt.destinations.insert_job_client import InsertValuesJobClient

-from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, skipifpypy
+from tests.utils import TEST_STORAGE_ROOT, skipifpypy
 from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage
-from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration
+from tests.load.pipeline.utils import destinations_configs

-DEFAULT_SUBSET = ["duckdb", "redshift", "postgres"]
+DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "mssql", "synapse"]


 @pytest.fixture
@@ -41,20 +37,28 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) -
     user_table_name = prepare_table(client)
     canonical_name = client.sql_client.make_qualified_table_name(user_table_name)
     # create insert
-    insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n"
+    insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\n"
+    writer_type = client.capabilities.insert_values_writer_type
+    if writer_type == "default":
+        insert_sql += "VALUES\n"
+        pre, post, sep = ("(", ")", ",\n")
+    elif writer_type == "select_union":
+        pre, post, sep = ("SELECT ", "", " UNION ALL\n")
     insert_values = (
-        f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd',"
-        f" '{str(pendulum.now())}')"
+        pre
+        + f"'{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd',"
+        f" '{str(pendulum.now())}'"
+        + post
     )
     expect_load_file(client, file_storage, insert_sql + insert_values + ";", user_table_name)
     rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0]
     assert rows_count == 1
     # insert 100 more rows
-    query = insert_sql + (insert_values + ",\n") * 99 + insert_values + ";"
+    query = insert_sql + (insert_values + sep) * 99 + insert_values + ";"
     expect_load_file(client, file_storage, query, user_table_name)
     rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0]
     assert rows_count == 101
-    # insert null value
+    # insert null value (single-record insert has same syntax for both writer types)
     insert_sql_nc = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, text)\nVALUES\n"
     insert_values_nc = (
         f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd',"
@@ -73,21 +77,30 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) -
     ids=lambda x: x.name,
 )
 def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage) -> None:
-    # test expected dbiapi exceptions for supported destinations
-    import duckdb
-    from dlt.destinations.impl.postgres.sql_client import psycopg2
-
-    TNotNullViolation = psycopg2.errors.NotNullViolation
-    TNumericValueOutOfRange = psycopg2.errors.NumericValueOutOfRange
-    TUndefinedColumn = psycopg2.errors.UndefinedColumn
-    TDatatypeMismatch = psycopg2.errors.DatatypeMismatch
-    if client.config.destination_type == "redshift":
-        # redshift does not know or psycopg does not recognize those correctly
-        TNotNullViolation = psycopg2.errors.InternalError_
-    if client.config.destination_type == "duckdb":
+    # test expected dbapi exceptions for supported destinations
+    dtype = client.config.destination_type
+    if dtype in ("postgres", "redshift"):
+        from dlt.destinations.impl.postgres.sql_client import psycopg2
+
+        TNotNullViolation = psycopg2.errors.NotNullViolation
+        TNumericValueOutOfRange = psycopg2.errors.NumericValueOutOfRange
+        TUndefinedColumn = psycopg2.errors.UndefinedColumn
+        TDatatypeMismatch = psycopg2.errors.DatatypeMismatch
+        if dtype == "redshift":
+            # redshift does not know or psycopg does not recognize those correctly
+            TNotNullViolation = psycopg2.errors.InternalError_
+    elif dtype == "duckdb":
+        import duckdb
+
         TUndefinedColumn = duckdb.BinderException
         TNotNullViolation = duckdb.ConstraintException
         TNumericValueOutOfRange = TDatatypeMismatch = duckdb.ConversionException
+    elif dtype in ("mssql", "synapse"):
+        import pyodbc
+
+        TUndefinedColumn = pyodbc.ProgrammingError
+        TNotNullViolation = pyodbc.IntegrityError
+        TNumericValueOutOfRange = TDatatypeMismatch = pyodbc.DataError

     user_table_name = prepare_table(client)
     # insert into unknown column
@@ -107,7 +120,10 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage
     assert type(exv.value.dbapi_exception) is TNotNullViolation
     # insert wrong type
     insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n"
-    insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', TRUE);"
+    insert_values = (
+        f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd',"
+        f" {client.capabilities.escape_literal(True)});"
+    )
     with pytest.raises(DatabaseTerminalException) as exv:
         expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name)
     assert type(exv.value.dbapi_exception) is TDatatypeMismatch
@@ -122,7 +138,7 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage
     )
     with pytest.raises(DatabaseTerminalException) as exv:
         expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name)
-    assert type(exv.value.dbapi_exception) in (TNumericValueOutOfRange,)
+    assert type(exv.value.dbapi_exception) == TNumericValueOutOfRange
     # numeric overflow on NUMERIC
     insert_sql = (
         "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp,"
@@ -145,9 +161,10 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage
     )
     with pytest.raises(DatabaseTerminalException) as exv:
         expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name)
-    assert type(exv.value.dbapi_exception) in (
-        TNumericValueOutOfRange,
-        psycopg2.errors.InternalError_,
+    assert type(exv.value.dbapi_exception) == (
+        psycopg2.errors.InternalError_
+        if dtype == "redshift"
+        else TNumericValueOutOfRange
     )

@@ -159,7 +176,13 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage
 )
 def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -> None:
     mocked_caps = client.sql_client.__class__.capabilities
-    insert_sql = prepare_insert_statement(10)
+    writer_type = client.capabilities.insert_values_writer_type
+    insert_sql = prepare_insert_statement(10, writer_type)
+
+    if writer_type == "default":
+        pre, post, sep = ("(", ")", ",\n")
+    elif writer_type == "select_union":
+        pre, post, sep = ("SELECT ", "", " UNION ALL\n")

     # this guarantees that we execute inserts line by line
     with patch.object(mocked_caps, "max_query_length", 2), patch.object(
@@ -173,14 +196,19 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -
     for idx, call in enumerate(mocked_fragments.call_args_list):
         fragment: List[str] = call.args[0]
         # last elem of fragment is a data list, first element is id, and must end with ;\n
-        assert fragment[-1].startswith(f"'{idx}'")
-        assert fragment[-1].endswith(");")
+        start = pre + "'" + str(idx) + "'"
+        end = post + ";"
+        assert fragment[-1].startswith(start)
+        assert fragment[-1].endswith(end)

     assert_load_with_max_query(client, file_storage, 10, 2)
-    start_idx = insert_sql.find("S\n(")
-    idx = insert_sql.find("),\n", len(insert_sql) // 2)
+    if writer_type == "default":
+        start_idx = insert_sql.find("S\n" + pre)
+    elif writer_type == "select_union":
+        start_idx = insert_sql.find(pre)
+    idx = insert_sql.find(post + sep, len(insert_sql) // 2)

-    # set query length so it reads data until "," (followed by \n)
+    # set query length so it reads data until separator ("," or " UNION ALL") (followed by \n)
     query_length = (idx - start_idx - 1) * 2
     with patch.object(mocked_caps, "max_query_length", query_length), patch.object(
         client.sql_client, "execute_fragments"
@@ -197,11 +225,15 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -
     ) as mocked_fragments:
         user_table_name = prepare_table(client)
         expect_load_file(client, file_storage, insert_sql, user_table_name)
-    # split in 2 on ','
+    # split in 2 on separator ("," or " UNION ALL")
     assert mocked_fragments.call_count == 2

     # so it reads till the last ;
-    query_length = (len(insert_sql) - start_idx - 3) * 2
+    if writer_type == "default":
+        offset = 3
+    elif writer_type == "select_union":
+        offset = 1
+    query_length = (len(insert_sql) - start_idx - offset) * 2
     with patch.object(mocked_caps, "max_query_length", query_length), patch.object(
         client.sql_client, "execute_fragments"
     ) as mocked_fragments:
@@ -221,31 +253,31 @@ def assert_load_with_max_query(
     mocked_caps = client.sql_client.__class__.capabilities
     with patch.object(mocked_caps, "max_query_length", max_query_length):
         user_table_name = prepare_table(client)
-        insert_sql = prepare_insert_statement(insert_lines)
+        insert_sql = prepare_insert_statement(
+            insert_lines, client.capabilities.insert_values_writer_type
+        )
         expect_load_file(client, file_storage, insert_sql, user_table_name)
-        rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {user_table_name}")[0][0]
+        canonical_name = client.sql_client.make_qualified_table_name(user_table_name)
+        rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0]
         assert rows_count == insert_lines
         # get all uniq ids in order
-        with client.sql_client.execute_query(
-            f"SELECT _dlt_id FROM {user_table_name} ORDER BY timestamp ASC;"
-        ) as c:
-            rows = list(c.fetchall())
-            v_ids = list(map(lambda i: i[0], rows))
+        rows = client.sql_client.execute_sql(
+            f"SELECT _dlt_id FROM {canonical_name} ORDER BY timestamp ASC;"
+        )
+        v_ids = list(map(lambda i: i[0], rows))
         assert list(map(str, range(0, insert_lines))) == v_ids
-        client.sql_client.execute_sql(f"DELETE FROM {user_table_name}")
+        client.sql_client.execute_sql(f"DELETE FROM {canonical_name}")


-def prepare_insert_statement(lines: int) -> str:
-    insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n"
-    insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}')"
-    # ids = []
+def prepare_insert_statement(lines: int, writer_type: str = "default") -> str:
+    insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\n"
+    if writer_type == "default":
+        insert_sql += "VALUES\n"
+        pre, post, sep = ("(", ")", ",\n")
+    elif writer_type == "select_union":
+        pre, post, sep = ("SELECT ", "", " UNION ALL\n")
+    insert_values = pre + "'{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'" + post
     for i in range(lines):
-        # id_ = uniq_id()
-        # ids.append(id_)
         insert_sql += insert_values.format(str(i), uniq_id(), str(pendulum.now().add(seconds=i)))
-        if i < 9:
-            insert_sql += ",\n"
-        else:
-            insert_sql += ";"
-        # print(insert_sql)
+        insert_sql += sep if i < 9 else ";"
     return insert_sql
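Note (illustration, not part of the patch): a quick, standalone stand-in for the updated prepare_insert_statement() showing the two statement shapes the tests now expect. Column list and literal values are shortened here; the real helper uses uniq_id(), pendulum timestamps, and keeps the original hard-coded "i < 9" cut-off, which only terminates the statement correctly when exactly 10 lines are requested.

    def tiny_prepare(lines: int, writer_type: str = "default") -> str:
        insert_sql = "INSERT INTO {}(_dlt_id, sender_id)\n"
        if writer_type == "default":
            insert_sql += "VALUES\n"
            pre, post, sep = "(", ")", ",\n"
        else:  # "select_union"
            pre, post, sep = "SELECT ", "", " UNION ALL\n"
        values = pre + "'{}', 'sender'" + post
        for i in range(lines):
            insert_sql += values.format(i)
            insert_sql += sep if i < lines - 1 else ";"
        return insert_sql

    print(tiny_prepare(2))                  # ... VALUES / ('0', 'sender'), / ('1', 'sender');
    print(tiny_prepare(2, "select_union"))  # ... SELECT '0', 'sender' UNION ALL / SELECT '1', 'sender';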