From 75be2ce54ccb486679ca1b177551c3097a2f3908 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 23 Jan 2024 20:26:06 +0100 Subject: [PATCH] rewrite naive code to prevent IndexError --- dlt/destinations/impl/synapse/configuration.py | 14 +++++++++----- dlt/pipeline/pipeline.py | 7 ++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py index b5eec82e9e..119c55ad7a 100644 --- a/dlt/destinations/impl/synapse/configuration.py +++ b/dlt/destinations/impl/synapse/configuration.py @@ -1,8 +1,9 @@ from typing import Final, Any, List, Dict, Optional, ClassVar -from dlt.common.configuration import configspec -from dlt.common.schema.typing import TTableIndexType, TWriteDisposition from dlt.common import logger +from dlt.common.configuration import configspec +from dlt.common.schema.typing import TTableIndexType, TSchemaTables +from dlt.common.schema.utils import get_write_disposition from dlt.destinations.impl.mssql.configuration import ( MsSqlCredentials, @@ -60,13 +61,16 @@ class SynapseClientConfiguration(MsSqlClientConfiguration): "auto_disable_concurrency", ] - def get_load_workers(self, write_disposition: TWriteDisposition, workers: int) -> int: + def get_load_workers(self, tables: TSchemaTables, workers: int) -> int: + """Returns the adjusted number of load workers to prevent concurrency issues.""" + + write_dispositions = [get_write_disposition(tables, table_name) for table_name in tables] + n_replace_dispositions = len([d for d in write_dispositions if d == "replace"]) if ( - write_disposition == "replace" + n_replace_dispositions > 1 and self.replace_strategy == "staging-optimized" and workers > 1 ): - print("auto_disable_concurrency:", self.auto_disable_concurrency) warning_msg_shared = ( 'Data is being loaded into Synapse with write disposition "replace"' ' and replace strategy "staging-optimized", while the number of' diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 44a2cbdfdb..3a0a8f3931 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -45,7 +45,7 @@ TAnySchemaColumns, TSchemaContract, ) -from dlt.common.schema.utils import normalize_schema_name, get_write_disposition +from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages.exceptions import LoadPackageNotFound from dlt.common.typing import DictStrStr, TFun, TSecretValue, is_optional_type from dlt.common.runners import pool_runner as runner @@ -485,10 +485,7 @@ def load( # for synapse we might need to adjust the number of load workers if self.destination.destination_name == "synapse": - write_disposition = get_write_disposition( - self.default_schema.tables, self.default_schema.data_table_names()[0] - ) - workers = client.config.get_load_workers(write_disposition, workers) # type: ignore[attr-defined] + workers = client.config.get_load_workers(self.default_schema.tables, workers) # type: ignore[attr-defined] # create default loader config and the loader load_config = LoaderConfiguration(