From c394533a183c17f553860850deb7c50693bd0ad8 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 2 Apr 2024 19:22:16 +0400 Subject: [PATCH 1/9] add chunk separation handling for select_union writer type --- dlt/common/data_writers/writers.py | 5 ++- .../pipeline/test_insert_values_writer.py | 33 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 tests/load/pipeline/test_insert_values_writer.py diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 2aadb010e0..f977b42a48 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -202,7 +202,10 @@ def write_row(row: StrAny, last_row: bool = False) -> None: # if next chunk add separator if self._chunks_written > 0: - self._f.write(",\n") + if self._caps.insert_values_writer_type == "default": + self._f.write(",\n") + elif self._caps.insert_values_writer_type == "select_union": + self._f.write("\nUNION ALL\n") # write rows for row in rows[:-1]: diff --git a/tests/load/pipeline/test_insert_values_writer.py b/tests/load/pipeline/test_insert_values_writer.py new file mode 100644 index 0000000000..18676bb344 --- /dev/null +++ b/tests/load/pipeline/test_insert_values_writer.py @@ -0,0 +1,33 @@ +import os +import pytest + +import dlt + +from tests.pipeline.utils import assert_load_info, load_table_counts +from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb", "synapse"]), + ids=lambda x: x.name, +) +def test_buffering(destination_config: DestinationTestConfiguration) -> None: + @dlt.resource(write_disposition="replace") + def items(): + yield [{"id": i} for i in range(10)] + + # set buffer size less than number of data items + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "5" + + # ensure both writer types are tested + p = destination_config.setup_pipeline("abstract", full_refresh=True) + if destination_config.destination == "duckdb": + assert p.destination.capabilities().insert_values_writer_type == "default" + elif destination_config.destination == "synapse": + assert p.destination.capabilities().insert_values_writer_type == "select_union" + + # run pipeline and assert expectations + info = p.run(items()) + assert_load_info(info) + assert load_table_counts(p, "items")["items"] == 10 From 609dd5647e9ad0c08952ce0adf19be0cb904a8ff Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 5 Apr 2024 15:06:26 +0400 Subject: [PATCH 2/9] include synapse in insert job client tests and make them pass --- dlt/common/data_writers/writers.py | 33 +++---- dlt/destinations/impl/mssql/sql_client.py | 13 ++- dlt/destinations/insert_job_client.py | 17 ++-- tests/load/test_insert_job_client.py | 113 ++++++++++++++-------- 4 files changed, 105 insertions(+), 71 deletions(-) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index f977b42a48..468248c00a 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -166,6 +166,11 @@ def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> N super().__init__(f, caps) self._chunks_written = 0 self._headers_lookup: Dict[str, int] = None + self.writer_type = caps.insert_values_writer_type + if self.writer_type == "default": + self.pre, self.post, self.sep = ("(", ")", ",\n") + elif self.writer_type == "select_union": + self.pre, self.post, self.sep = ("SELECT ", " ", "UNION ALL\n") def write_header(self, columns_schema: TTableSchemaColumns) -> None: assert self._chunks_written == 0 @@ -176,10 +181,9 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: # do not write INSERT INTO command, this must be added together with table name by the loader self._f.write("INSERT INTO {}(") self._f.write(",".join(map(self._caps.escape_identifier, headers))) - if self._caps.insert_values_writer_type == "default": - self._f.write(")\nVALUES\n") - elif self._caps.insert_values_writer_type == "select_union": - self._f.write(")\n") + self._f.write(")\n") + if self.writer_type == "default": + self._f.write("VALUES\n") def write_data(self, rows: Sequence[Any]) -> None: super().write_data(rows) @@ -188,24 +192,15 @@ def write_row(row: StrAny, last_row: bool = False) -> None: output = ["NULL"] * len(self._headers_lookup) for n, v in row.items(): output[self._headers_lookup[n]] = self._caps.escape_literal(v) - if self._caps.insert_values_writer_type == "default": - self._f.write("(") - self._f.write(",".join(output)) - self._f.write(")") - if not last_row: - self._f.write(",\n") - elif self._caps.insert_values_writer_type == "select_union": - self._f.write("SELECT ") - self._f.write(",".join(output)) - if not last_row: - self._f.write("\nUNION ALL\n") + self._f.write(self.pre) + self._f.write(",".join(output)) + self._f.write(self.post) + if not last_row: + self._f.write(self.sep) # if next chunk add separator if self._chunks_written > 0: - if self._caps.insert_values_writer_type == "default": - self._f.write(",\n") - elif self._caps.insert_values_writer_type == "select_union": - self._f.write("\nUNION ALL\n") + self._f.write(self.sep) # write rows for row in rows[:-1]: diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index cd1699adea..db043bae25 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -83,8 +83,17 @@ def commit_transaction(self) -> None: @raise_database_error def rollback_transaction(self) -> None: - self._conn.rollback() - self._conn.autocommit = True + try: + self._conn.rollback() + except pyodbc.ProgrammingError as ex: + if ( + ex.args[0] == "42000" and "(111214)" in ex.args[1] + ): # "no corresponding transaction found" + pass # there was nothing to rollback, we silently ignore the error + else: + raise + finally: + self._conn.autocommit = True @property def native_connection(self) -> pyodbc.Connection: diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 776176078e..c25e8b9384 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -36,7 +36,8 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # the procedure below will split the inserts into max_query_length // 2 packs with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f: header = f.readline() - if self._sql_client.capabilities.insert_values_writer_type == "default": + writer_type = self._sql_client.capabilities.insert_values_writer_type + if writer_type == "default": # properly formatted file has a values marker at the beginning values_mark = f.readline() assert values_mark == "VALUES\n" @@ -57,9 +58,11 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # if there was anything left, until_nl contains the last line is_eof = len(until_nl) == 0 or until_nl[-1] == ";" if not is_eof: - # print(f'replace the "," with " {until_nl} {len(insert_sql)}') - until_nl = until_nl[:-1] + ";" - + if writer_type == "default": + sep = "," + elif writer_type == "select_union": + sep = " UNION ALL" + until_nl = until_nl[: -len(sep)] + ";" # replace the separator with ";" if max_rows is not None: # mssql has a limit of 1000 rows per INSERT, so we need to split into separate statements values_rows = content.splitlines(keepends=True) @@ -69,7 +72,7 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st for chunk in chunks(values_rows, max_rows - 1): processed += len(chunk) insert_sql.append(header.format(qualified_table_name)) - if self._sql_client.capabilities.insert_values_writer_type == "default": + if writer_type == "default": insert_sql.append(values_mark) if processed == len_rows: # On the last chunk we need to add the extra row read @@ -79,11 +82,11 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st insert_sql.append("".join(chunk).strip()[:-1] + ";\n") else: # otherwise write all content in a single INSERT INTO - if self._sql_client.capabilities.insert_values_writer_type == "default": + if writer_type == "default": insert_sql.extend( [header.format(qualified_table_name), values_mark, content] ) - elif self._sql_client.capabilities.insert_values_writer_type == "select_union": + elif writer_type == "select_union": insert_sql.extend([header.format(qualified_table_name), content]) if until_nl: diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 1c79b733e5..75c8440672 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -7,18 +7,14 @@ from dlt.common.storages import FileStorage from dlt.common.utils import uniq_id -from dlt.destinations.exceptions import ( - DatabaseTerminalException, - DatabaseTransientException, - DatabaseUndefinedRelation, -) +from dlt.destinations.exceptions import DatabaseTerminalException from dlt.destinations.insert_job_client import InsertValuesJobClient -from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, skipifpypy +from tests.utils import TEST_STORAGE_ROOT, skipifpypy from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.pipeline.utils import destinations_configs -DEFAULT_SUBSET = ["duckdb", "redshift", "postgres"] +DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "synapse"] @pytest.fixture @@ -41,20 +37,28 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - user_table_name = prepare_table(client) canonical_name = client.sql_client.make_qualified_table_name(user_table_name) # create insert - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" + insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\n" + writer_type = client.capabilities.insert_values_writer_type + if writer_type == "default": + insert_sql += "VALUES\n" + pre, post, sep = ("(", ")", ",\n") + elif writer_type == "select_union": + pre, post, sep = ("SELECT ", " ", "UNION ALL\n") insert_values = ( - f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," - f" '{str(pendulum.now())}')" + pre + + f"'{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" '{str(pendulum.now())}'" + + post ) expect_load_file(client, file_storage, insert_sql + insert_values + ";", user_table_name) rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] assert rows_count == 1 # insert 100 more rows - query = insert_sql + (insert_values + ",\n") * 99 + insert_values + ";" + query = insert_sql + (insert_values + sep) * 99 + insert_values + ";" expect_load_file(client, file_storage, query, user_table_name) rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] assert rows_count == 101 - # insert null value + # insert null value (single-record insert has same syntax for both writer types) insert_sql_nc = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, text)\nVALUES\n" insert_values_nc = ( f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," @@ -73,8 +77,9 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - ids=lambda x: x.name, ) def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage) -> None: - # test expected dbiapi exceptions for supported destinations + # test expected dbapi exceptions for supported destinations import duckdb + import pyodbc from dlt.destinations.impl.postgres.sql_client import psycopg2 TNotNullViolation = psycopg2.errors.NotNullViolation @@ -88,6 +93,10 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage TUndefinedColumn = duckdb.BinderException TNotNullViolation = duckdb.ConstraintException TNumericValueOutOfRange = TDatatypeMismatch = duckdb.ConversionException + if client.config.destination_type == "synapse": + TUndefinedColumn = pyodbc.ProgrammingError + TNotNullViolation = pyodbc.IntegrityError + TNumericValueOutOfRange = TDatatypeMismatch = pyodbc.DataError user_table_name = prepare_table(client) # insert into unknown column @@ -107,7 +116,10 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage assert type(exv.value.dbapi_exception) is TNotNullViolation # insert wrong type insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', TRUE);" + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," + f" {client.capabilities.escape_literal(True)});" + ) with pytest.raises(DatabaseTerminalException) as exv: expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) assert type(exv.value.dbapi_exception) is TDatatypeMismatch @@ -122,7 +134,7 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage ) with pytest.raises(DatabaseTerminalException) as exv: expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) in (TNumericValueOutOfRange,) + assert type(exv.value.dbapi_exception) == TNumericValueOutOfRange # numeric overflow on NUMERIC insert_sql = ( "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp," @@ -159,7 +171,8 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage ) def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -> None: mocked_caps = client.sql_client.__class__.capabilities - insert_sql = prepare_insert_statement(10) + writer_type = client.capabilities.insert_values_writer_type + insert_sql = prepare_insert_statement(10, writer_type) # this guarantees that we execute inserts line by line with patch.object(mocked_caps, "max_query_length", 2), patch.object( @@ -173,14 +186,24 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - for idx, call in enumerate(mocked_fragments.call_args_list): fragment: List[str] = call.args[0] # last elem of fragment is a data list, first element is id, and must end with ;\n - assert fragment[-1].startswith(f"'{idx}'") - assert fragment[-1].endswith(");") + if writer_type == "default": + start = f"'{idx}'" + end = ");" + elif writer_type == "select_union": + start = f"ELECT '{idx}'" + end = ";" + assert fragment[-1].startswith(start) + assert fragment[-1].endswith(end) assert_load_with_max_query(client, file_storage, 10, 2) - start_idx = insert_sql.find("S\n(") - idx = insert_sql.find("),\n", len(insert_sql) // 2) + if writer_type == "default": + start_idx = insert_sql.find("S\n(") + idx = insert_sql.find("),\n", len(insert_sql) // 2) + elif writer_type == "select_union": + start_idx = insert_sql.find("SELECT ") + idx = insert_sql.find(" UNION ALL\n", len(insert_sql) // 2) - # set query length so it reads data until "," (followed by \n) + # set query length so it reads data until separator ("," or " UNION ALL") (followed by \n) query_length = (idx - start_idx - 1) * 2 with patch.object(mocked_caps, "max_query_length", query_length), patch.object( client.sql_client, "execute_fragments" @@ -197,11 +220,15 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - ) as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) - # split in 2 on ',' + # split in 2 on separator ("," or " UNION ALL") assert mocked_fragments.call_count == 2 # so it reads till the last ; - query_length = (len(insert_sql) - start_idx - 3) * 2 + if writer_type == "default": + offset = 3 + elif writer_type == "select_union": + offset = 1 + query_length = (len(insert_sql) - start_idx - offset) * 2 with patch.object(mocked_caps, "max_query_length", query_length), patch.object( client.sql_client, "execute_fragments" ) as mocked_fragments: @@ -221,31 +248,31 @@ def assert_load_with_max_query( mocked_caps = client.sql_client.__class__.capabilities with patch.object(mocked_caps, "max_query_length", max_query_length): user_table_name = prepare_table(client) - insert_sql = prepare_insert_statement(insert_lines) + insert_sql = prepare_insert_statement( + insert_lines, client.capabilities.insert_values_writer_type + ) expect_load_file(client, file_storage, insert_sql, user_table_name) - rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {user_table_name}")[0][0] + canonical_name = client.sql_client.make_qualified_table_name(user_table_name) + rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] assert rows_count == insert_lines # get all uniq ids in order - with client.sql_client.execute_query( - f"SELECT _dlt_id FROM {user_table_name} ORDER BY timestamp ASC;" - ) as c: - rows = list(c.fetchall()) - v_ids = list(map(lambda i: i[0], rows)) + rows = client.sql_client.execute_sql( + f"SELECT _dlt_id FROM {canonical_name} ORDER BY timestamp ASC;" + ) + v_ids = list(map(lambda i: i[0], rows)) assert list(map(str, range(0, insert_lines))) == v_ids - client.sql_client.execute_sql(f"DELETE FROM {user_table_name}") + client.sql_client.execute_sql(f"DELETE FROM {canonical_name}") -def prepare_insert_statement(lines: int) -> str: - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" - insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}')" - # ids = [] +def prepare_insert_statement(lines: int, writer_type: str = "default") -> str: + insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\n" + if writer_type == "default": + insert_sql += "VALUES\n" + pre, post, sep = ("(", ")", ",\n") + elif writer_type == "select_union": + pre, post, sep = ("SELECT ", " ", "UNION ALL\n") + insert_values = pre + "'{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'" + post for i in range(lines): - # id_ = uniq_id() - # ids.append(id_) insert_sql += insert_values.format(str(i), uniq_id(), str(pendulum.now().add(seconds=i))) - if i < 9: - insert_sql += ",\n" - else: - insert_sql += ";" - # print(insert_sql) + insert_sql += sep if i < 9 else ";" return insert_sql From 2fed10a8c5ca4711f8e7656302c697d8967ab818 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 5 Apr 2024 17:13:36 +0400 Subject: [PATCH 3/9] remove obsolete test --- .../pipeline/test_insert_values_writer.py | 33 ------------------- 1 file changed, 33 deletions(-) delete mode 100644 tests/load/pipeline/test_insert_values_writer.py diff --git a/tests/load/pipeline/test_insert_values_writer.py b/tests/load/pipeline/test_insert_values_writer.py deleted file mode 100644 index 18676bb344..0000000000 --- a/tests/load/pipeline/test_insert_values_writer.py +++ /dev/null @@ -1,33 +0,0 @@ -import os -import pytest - -import dlt - -from tests.pipeline.utils import assert_load_info, load_table_counts -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration - - -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["duckdb", "synapse"]), - ids=lambda x: x.name, -) -def test_buffering(destination_config: DestinationTestConfiguration) -> None: - @dlt.resource(write_disposition="replace") - def items(): - yield [{"id": i} for i in range(10)] - - # set buffer size less than number of data items - os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "5" - - # ensure both writer types are tested - p = destination_config.setup_pipeline("abstract", full_refresh=True) - if destination_config.destination == "duckdb": - assert p.destination.capabilities().insert_values_writer_type == "default" - elif destination_config.destination == "synapse": - assert p.destination.capabilities().insert_values_writer_type == "select_union" - - # run pipeline and assert expectations - info = p.run(items()) - assert_load_info(info) - assert load_table_counts(p, "items")["items"] == 10 From 7ebb64db27d44f6a7dfdd80a84497045181972cd Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 5 Apr 2024 17:30:14 +0400 Subject: [PATCH 4/9] set max_rows_per_insert to prevent error on larger queries in synapse --- dlt/destinations/impl/synapse/__init__.py | 6 ++++++ dlt/destinations/insert_job_client.py | 9 ++++----- tests/load/test_insert_job_client.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/dlt/destinations/impl/synapse/__init__.py b/dlt/destinations/impl/synapse/__init__.py index 53dbabc090..f6ad7369c1 100644 --- a/dlt/destinations/impl/synapse/__init__.py +++ b/dlt/destinations/impl/synapse/__init__.py @@ -41,6 +41,12 @@ def capabilities() -> DestinationCapabilitiesContext: caps.supports_transactions = True caps.supports_ddl_transactions = False + # Synapse throws "Some part of your SQL statement is nested too deeply. Rewrite the query or break it up into smaller queries." + # if number of records exceeds a certain number. Which exact number that is seems not deterministic: + # in tests, I've seen a query with 12230 records run succesfully on one run, but fail on a subsequent run, while the query remained exactly the same. + # 10.000 records is a "safe" amount that always seems to work. + caps.max_rows_per_insert = 10000 + # datetimeoffset can store 7 digits for fractional seconds # https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16 caps.timestamp_precision = 7 diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index c25e8b9384..e3ce7265d1 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -38,9 +38,12 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st header = f.readline() writer_type = self._sql_client.capabilities.insert_values_writer_type if writer_type == "default": + sep = "," # properly formatted file has a values marker at the beginning values_mark = f.readline() assert values_mark == "VALUES\n" + elif writer_type == "select_union": + sep = " UNION ALL" max_rows = self._sql_client.capabilities.max_rows_per_insert @@ -58,10 +61,6 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # if there was anything left, until_nl contains the last line is_eof = len(until_nl) == 0 or until_nl[-1] == ";" if not is_eof: - if writer_type == "default": - sep = "," - elif writer_type == "select_union": - sep = " UNION ALL" until_nl = until_nl[: -len(sep)] + ";" # replace the separator with ";" if max_rows is not None: # mssql has a limit of 1000 rows per INSERT, so we need to split into separate statements @@ -79,7 +78,7 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st insert_sql.append("".join(chunk) + until_nl) else: # Replace the , with ; - insert_sql.append("".join(chunk).strip()[:-1] + ";\n") + insert_sql.append("".join(chunk).strip()[: -len(sep)] + ";\n") else: # otherwise write all content in a single INSERT INTO if writer_type == "default": diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 75c8440672..bd20ea9930 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -190,7 +190,7 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - start = f"'{idx}'" end = ");" elif writer_type == "select_union": - start = f"ELECT '{idx}'" + start = f"SELECT '{idx}'" end = ";" assert fragment[-1].startswith(start) assert fragment[-1].endswith(end) From 764c766a1e5fb90fa68fc3cbc3ecee34fc3d599f Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sat, 6 Apr 2024 13:07:50 +0400 Subject: [PATCH 5/9] remove pipeline dependency --- dlt/destinations/impl/synapse/synapse.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 457e128ba0..60d792c854 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -10,19 +10,20 @@ from dlt.common.destination.reference import ( SupportsStagingDestination, NewLoadJob, - CredentialsConfiguration, ) from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint -from dlt.common.schema.utils import table_schema_has_type, get_inherited_table_hint -from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.schema.utils import ( + table_schema_has_type, + get_inherited_table_hint, + is_complete_column, +) from dlt.common.configuration.specs import AzureCredentialsWithoutDefaults from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.job_client_impl import SqlJobClientBase, LoadJob, CopyRemoteFileLoadJob from dlt.destinations.exceptions import LoadJobTerminalException @@ -127,7 +128,7 @@ def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] ) -> List[NewLoadJob]: if self.config.replace_strategy == "staging-optimized": - return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return [SynapseStagingCopyJob.from_table_chain(table_chain, self)] # type: ignore[arg-type] return super()._create_replace_followup_jobs(table_chain) def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: @@ -178,10 +179,11 @@ class SynapseStagingCopyJob(SqlStagingCopyJob): def generate_sql( cls, table_chain: Sequence[TTableSchema], - sql_client: SqlClientBase[Any], + job_client: SynapseClient, # type: ignore[override] params: Optional[SqlJobParams] = None, ) -> List[str]: sql: List[str] = [] + sql_client = job_client.sql_client for table in table_chain: with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) @@ -194,11 +196,8 @@ def generate_sql( f" {staging_table_name};" ) # recreate staging table - job_client = current.pipeline().destination_client() # type: ignore[operator] with job_client.with_staging_dataset(): - # get table columns from schema - columns = [c for c in job_client.schema.get_table_columns(table["name"]).values()] - # generate CREATE TABLE statement + columns = list(filter(is_complete_column, table["columns"].values())) create_table_stmt = job_client._get_table_update_sql( table["name"], columns, generate_alter=False ) From b2e64c987872945526414e858849b2bf6c87a5e1 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sat, 6 Apr 2024 13:27:51 +0400 Subject: [PATCH 6/9] make imports conditional on destination type --- tests/load/test_insert_job_client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index bd20ea9930..2b7bfd1341 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -78,8 +78,6 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - ) def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage) -> None: # test expected dbapi exceptions for supported destinations - import duckdb - import pyodbc from dlt.destinations.impl.postgres.sql_client import psycopg2 TNotNullViolation = psycopg2.errors.NotNullViolation @@ -90,10 +88,14 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage # redshift does not know or psycopg does not recognize those correctly TNotNullViolation = psycopg2.errors.InternalError_ if client.config.destination_type == "duckdb": + import duckdb + TUndefinedColumn = duckdb.BinderException TNotNullViolation = duckdb.ConstraintException TNumericValueOutOfRange = TDatatypeMismatch = duckdb.ConversionException if client.config.destination_type == "synapse": + import pyodbc + TUndefinedColumn = pyodbc.ProgrammingError TNotNullViolation = pyodbc.IntegrityError TNumericValueOutOfRange = TDatatypeMismatch = pyodbc.DataError From 1e9fc8a42552cff22c499097c20b19bbe843c367 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sat, 6 Apr 2024 17:42:51 +0400 Subject: [PATCH 7/9] include synapse in insert job client tests and make them pass --- dlt/common/data_writers/writers.py | 2 +- dlt/destinations/insert_job_client.py | 7 ++----- tests/load/test_insert_job_client.py | 28 +++++++++++++-------------- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 468248c00a..4a4029c6f3 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -170,7 +170,7 @@ def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> N if self.writer_type == "default": self.pre, self.post, self.sep = ("(", ")", ",\n") elif self.writer_type == "select_union": - self.pre, self.post, self.sep = ("SELECT ", " ", "UNION ALL\n") + self.pre, self.post, self.sep = ("SELECT ", "", " UNION ALL\n") def write_header(self, columns_schema: TTableSchemaColumns) -> None: assert self._chunks_written == 0 diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index e3ce7265d1..6a7c56832c 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -83,13 +83,10 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # otherwise write all content in a single INSERT INTO if writer_type == "default": insert_sql.extend( - [header.format(qualified_table_name), values_mark, content] + [header.format(qualified_table_name), values_mark, content + until_nl] ) elif writer_type == "select_union": - insert_sql.extend([header.format(qualified_table_name), content]) - - if until_nl: - insert_sql.append(until_nl) + insert_sql.extend([header.format(qualified_table_name), content + until_nl]) # actually this may be empty if we were able to read a full file into content if not is_eof: diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 2b7bfd1341..100464b5a7 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -14,7 +14,7 @@ from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage from tests.load.pipeline.utils import destinations_configs -DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "synapse"] +DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "mssql", "synapse"] @pytest.fixture @@ -43,7 +43,7 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - insert_sql += "VALUES\n" pre, post, sep = ("(", ")", ",\n") elif writer_type == "select_union": - pre, post, sep = ("SELECT ", " ", "UNION ALL\n") + pre, post, sep = ("SELECT ", "", " UNION ALL\n") insert_values = ( pre + f"'{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," @@ -93,7 +93,7 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage TUndefinedColumn = duckdb.BinderException TNotNullViolation = duckdb.ConstraintException TNumericValueOutOfRange = TDatatypeMismatch = duckdb.ConversionException - if client.config.destination_type == "synapse": + if client.config.destination_type in ("mssql", "synapse"): import pyodbc TUndefinedColumn = pyodbc.ProgrammingError @@ -176,6 +176,11 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - writer_type = client.capabilities.insert_values_writer_type insert_sql = prepare_insert_statement(10, writer_type) + if writer_type == "default": + pre, post, sep = ("(", ")", ",\n") + elif writer_type == "select_union": + pre, post, sep = ("SELECT ", "", " UNION ALL\n") + # this guarantees that we execute inserts line by line with patch.object(mocked_caps, "max_query_length", 2), patch.object( client.sql_client, "execute_fragments" @@ -188,22 +193,17 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - for idx, call in enumerate(mocked_fragments.call_args_list): fragment: List[str] = call.args[0] # last elem of fragment is a data list, first element is id, and must end with ;\n - if writer_type == "default": - start = f"'{idx}'" - end = ");" - elif writer_type == "select_union": - start = f"SELECT '{idx}'" - end = ";" + start = pre + "'" + str(idx) + "'" + end = post + ";" assert fragment[-1].startswith(start) assert fragment[-1].endswith(end) assert_load_with_max_query(client, file_storage, 10, 2) if writer_type == "default": - start_idx = insert_sql.find("S\n(") - idx = insert_sql.find("),\n", len(insert_sql) // 2) + start_idx = insert_sql.find("S\n" + pre) elif writer_type == "select_union": - start_idx = insert_sql.find("SELECT ") - idx = insert_sql.find(" UNION ALL\n", len(insert_sql) // 2) + start_idx = insert_sql.find(pre) + idx = insert_sql.find(post + sep, len(insert_sql) // 2) # set query length so it reads data until separator ("," or " UNION ALL") (followed by \n) query_length = (idx - start_idx - 1) * 2 @@ -272,7 +272,7 @@ def prepare_insert_statement(lines: int, writer_type: str = "default") -> str: insert_sql += "VALUES\n" pre, post, sep = ("(", ")", ",\n") elif writer_type == "select_union": - pre, post, sep = ("SELECT ", " ", "UNION ALL\n") + pre, post, sep = ("SELECT ", "", " UNION ALL\n") insert_values = pre + "'{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'" + post for i in range(lines): insert_sql += insert_values.format(str(i), uniq_id(), str(pendulum.now().add(seconds=i))) From 7636a84842ed568d47afe1ec7dee8de91449d296 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sat, 6 Apr 2024 17:42:51 +0400 Subject: [PATCH 8/9] include mssql in insert job client tests and make them pass --- dlt/common/data_writers/writers.py | 2 +- dlt/destinations/insert_job_client.py | 7 ++----- tests/load/test_insert_job_client.py | 28 +++++++++++++-------------- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 468248c00a..4a4029c6f3 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -170,7 +170,7 @@ def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> N if self.writer_type == "default": self.pre, self.post, self.sep = ("(", ")", ",\n") elif self.writer_type == "select_union": - self.pre, self.post, self.sep = ("SELECT ", " ", "UNION ALL\n") + self.pre, self.post, self.sep = ("SELECT ", "", " UNION ALL\n") def write_header(self, columns_schema: TTableSchemaColumns) -> None: assert self._chunks_written == 0 diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index e3ce7265d1..6a7c56832c 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -83,13 +83,10 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # otherwise write all content in a single INSERT INTO if writer_type == "default": insert_sql.extend( - [header.format(qualified_table_name), values_mark, content] + [header.format(qualified_table_name), values_mark, content + until_nl] ) elif writer_type == "select_union": - insert_sql.extend([header.format(qualified_table_name), content]) - - if until_nl: - insert_sql.append(until_nl) + insert_sql.extend([header.format(qualified_table_name), content + until_nl]) # actually this may be empty if we were able to read a full file into content if not is_eof: diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 2b7bfd1341..100464b5a7 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -14,7 +14,7 @@ from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage from tests.load.pipeline.utils import destinations_configs -DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "synapse"] +DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "mssql", "synapse"] @pytest.fixture @@ -43,7 +43,7 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - insert_sql += "VALUES\n" pre, post, sep = ("(", ")", ",\n") elif writer_type == "select_union": - pre, post, sep = ("SELECT ", " ", "UNION ALL\n") + pre, post, sep = ("SELECT ", "", " UNION ALL\n") insert_values = ( pre + f"'{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," @@ -93,7 +93,7 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage TUndefinedColumn = duckdb.BinderException TNotNullViolation = duckdb.ConstraintException TNumericValueOutOfRange = TDatatypeMismatch = duckdb.ConversionException - if client.config.destination_type == "synapse": + if client.config.destination_type in ("mssql", "synapse"): import pyodbc TUndefinedColumn = pyodbc.ProgrammingError @@ -176,6 +176,11 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - writer_type = client.capabilities.insert_values_writer_type insert_sql = prepare_insert_statement(10, writer_type) + if writer_type == "default": + pre, post, sep = ("(", ")", ",\n") + elif writer_type == "select_union": + pre, post, sep = ("SELECT ", "", " UNION ALL\n") + # this guarantees that we execute inserts line by line with patch.object(mocked_caps, "max_query_length", 2), patch.object( client.sql_client, "execute_fragments" @@ -188,22 +193,17 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - for idx, call in enumerate(mocked_fragments.call_args_list): fragment: List[str] = call.args[0] # last elem of fragment is a data list, first element is id, and must end with ;\n - if writer_type == "default": - start = f"'{idx}'" - end = ");" - elif writer_type == "select_union": - start = f"SELECT '{idx}'" - end = ";" + start = pre + "'" + str(idx) + "'" + end = post + ";" assert fragment[-1].startswith(start) assert fragment[-1].endswith(end) assert_load_with_max_query(client, file_storage, 10, 2) if writer_type == "default": - start_idx = insert_sql.find("S\n(") - idx = insert_sql.find("),\n", len(insert_sql) // 2) + start_idx = insert_sql.find("S\n" + pre) elif writer_type == "select_union": - start_idx = insert_sql.find("SELECT ") - idx = insert_sql.find(" UNION ALL\n", len(insert_sql) // 2) + start_idx = insert_sql.find(pre) + idx = insert_sql.find(post + sep, len(insert_sql) // 2) # set query length so it reads data until separator ("," or " UNION ALL") (followed by \n) query_length = (idx - start_idx - 1) * 2 @@ -272,7 +272,7 @@ def prepare_insert_statement(lines: int, writer_type: str = "default") -> str: insert_sql += "VALUES\n" pre, post, sep = ("(", ")", ",\n") elif writer_type == "select_union": - pre, post, sep = ("SELECT ", " ", "UNION ALL\n") + pre, post, sep = ("SELECT ", "", " UNION ALL\n") insert_values = pre + "'{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'" + post for i in range(lines): insert_sql += insert_values.format(str(i), uniq_id(), str(pendulum.now().add(seconds=i))) From ee1910b72d9dac4b529e5496972da9f01b105acb Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sat, 6 Apr 2024 22:42:35 +0400 Subject: [PATCH 9/9] make psycopg2 import conditional on destination type --- tests/load/test_insert_job_client.py | 29 +++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 100464b5a7..e353ec34eb 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -78,22 +78,24 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - ) def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage) -> None: # test expected dbapi exceptions for supported destinations - from dlt.destinations.impl.postgres.sql_client import psycopg2 + dtype = client.config.destination_type + if dtype in ("postgres", "redshift"): + from dlt.destinations.impl.postgres.sql_client import psycopg2 - TNotNullViolation = psycopg2.errors.NotNullViolation - TNumericValueOutOfRange = psycopg2.errors.NumericValueOutOfRange - TUndefinedColumn = psycopg2.errors.UndefinedColumn - TDatatypeMismatch = psycopg2.errors.DatatypeMismatch - if client.config.destination_type == "redshift": - # redshift does not know or psycopg does not recognize those correctly - TNotNullViolation = psycopg2.errors.InternalError_ - if client.config.destination_type == "duckdb": + TNotNullViolation = psycopg2.errors.NotNullViolation + TNumericValueOutOfRange = psycopg2.errors.NumericValueOutOfRange + TUndefinedColumn = psycopg2.errors.UndefinedColumn + TDatatypeMismatch = psycopg2.errors.DatatypeMismatch + if dtype == "redshift": + # redshift does not know or psycopg does not recognize those correctly + TNotNullViolation = psycopg2.errors.InternalError_ + elif dtype == "duckdb": import duckdb TUndefinedColumn = duckdb.BinderException TNotNullViolation = duckdb.ConstraintException TNumericValueOutOfRange = TDatatypeMismatch = duckdb.ConversionException - if client.config.destination_type in ("mssql", "synapse"): + elif dtype in ("mssql", "synapse"): import pyodbc TUndefinedColumn = pyodbc.ProgrammingError @@ -159,9 +161,10 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage ) with pytest.raises(DatabaseTerminalException) as exv: expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) in ( - TNumericValueOutOfRange, - psycopg2.errors.InternalError_, + assert ( + type(exv.value.dbapi_exception) == psycopg2.errors.InternalError_ + if dtype == "redshift" + else TNumericValueOutOfRange )