fixes escape identifiers to column escape
rudolfix committed Jun 13, 2024
1 parent 036e3dd commit 8546763
Showing 12 changed files with 104 additions and 56 deletions.
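
The common thread in the destination changes below: column names are now escaped through the sql client's `escape_column_name` instead of the generic `capabilities.escape_identifier`, so quoting can be combined with destination-specific case folding. A minimal, hypothetical sketch of that distinction (toy classes, not dlt's actual capabilities or `SqlClientBase` API):

```python
# Toy sketch only: illustrates why column escaping is routed through the sql client.
class ToyCapabilities:
    def escape_identifier(self, name: str) -> str:
        # generic escaping: quote the identifier but keep the caller's casing
        return '"' + name.replace('"', '""') + '"'


class ToySqlClient:
    def __init__(self, capabilities: ToyCapabilities, casefold=str) -> None:
        self.capabilities = capabilities
        # e.g. str.upper for an uppercase-folding destination, str for a case-sensitive one
        self._casefold = casefold

    def escape_column_name(self, name: str) -> str:
        # column-level escaping: apply the destination's case folding, then quote
        return self.capabilities.escape_identifier(self._casefold(name))


caps = ToyCapabilities()
client = ToySqlClient(caps, casefold=str.upper)
print(caps.escape_identifier("CamelInfo"))     # -> "CamelInfo" (casing untouched)
print(client.escape_column_name("CamelInfo"))  # -> "CAMELINFO" (folded by the client)
```
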
4 changes: 2 additions & 2 deletions dlt/destinations/impl/clickhouse/clickhouse.py
@@ -326,7 +326,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non
)

return (
f"{self.capabilities.escape_identifier(c['name'])} {type_with_nullability_modifier} {hints_str}"
f"{self.sql_client.escape_column_name(c['name'])} {type_with_nullability_modifier} {hints_str}"
.strip()
)

@@ -356,7 +356,7 @@ def _get_table_update_sql(
sql[0] = f"{sql[0]}\nENGINE = {TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR.get(table_type)}"

if primary_key_list := [
self.capabilities.escape_identifier(c["name"])
self.sql_client.escape_column_name(c["name"])
for c in new_columns
if c.get("primary_key")
]:
6 changes: 3 additions & 3 deletions dlt/destinations/impl/dremio/dremio.py
@@ -177,15 +177,15 @@ def _get_table_update_sql(

if not generate_alter:
partition_list = [
self.capabilities.escape_identifier(c["name"])
self.sql_client.escape_column_name(c["name"])
for c in new_columns
if c.get("partition")
]
if partition_list:
sql[0] += "\nPARTITION BY (" + ",".join(partition_list) + ")"

sort_list = [
self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("sort")
self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("sort")
]
if sort_list:
sql[0] += "\nLOCALSORT BY (" + ",".join(sort_list) + ")"
@@ -198,7 +198,7 @@ def _from_db_type(
return self.type_mapper.from_db_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
name = self.capabilities.escape_identifier(c["name"])
name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}"
)
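
The Dremio diff applies the same substitution to the columns gathered for PARTITION BY and LOCALSORT BY. A rough sketch of how such column hints might be folded into DDL clauses (hypothetical helper, not the actual Dremio client code):

```python
# Hypothetical helper: collect partition/sort column hints into Dremio-style clauses.
from typing import Callable, Dict, List


def hint_clauses(new_columns: List[Dict], escape_column_name: Callable[[str], str]) -> str:
    partition_list = [escape_column_name(c["name"]) for c in new_columns if c.get("partition")]
    sort_list = [escape_column_name(c["name"]) for c in new_columns if c.get("sort")]
    clauses = []
    if partition_list:
        clauses.append("PARTITION BY (" + ",".join(partition_list) + ")")
    if sort_list:
        clauses.append("LOCALSORT BY (" + ",".join(sort_list) + ")")
    return "\n".join(clauses)


columns = [{"name": "event_date", "partition": True}, {"name": "UserId", "sort": True}]
print(hint_clauses(columns, lambda n: f'"{n}"'))
```
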
2 changes: 1 addition & 1 deletion dlt/destinations/impl/mssql/mssql.py
@@ -185,7 +185,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non
for h in self.active_hints.keys()
if c.get(h, False) is True
)
column_name = self.capabilities.escape_identifier(c["name"])
column_name = self.sql_client.escape_column_name(c["name"])
return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}"

def _create_replace_followup_jobs(
6 changes: 1 addition & 5 deletions dlt/destinations/impl/redshift/redshift.py
@@ -30,9 +30,7 @@
from dlt.destinations.sql_jobs import SqlMergeJob
from dlt.destinations.exceptions import DatabaseTerminalException, LoadJobTerminalException
from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob, LoadJob
from dlt.destinations.impl.postgres.configuration import PostgresCredentials
from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient
from dlt.destinations.impl.redshift import capabilities
from dlt.destinations.impl.redshift.configuration import RedshiftClientConfiguration
from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_client import SqlClientBase
@@ -148,7 +146,6 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None:
"CREDENTIALS"
f" 'aws_access_key_id={aws_access_key};aws_secret_access_key={aws_secret_key}'"
)
table_name = table["name"]

# get format
ext = os.path.splitext(bucket_path)[1][1:]
@@ -188,10 +185,9 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None:
raise ValueError(f"Unsupported file type {ext} for Redshift.")

with self._sql_client.begin_transaction():
dataset_name = self._sql_client.dataset_name
# TODO: if we ever support csv here remember to add column names to COPY
self._sql_client.execute_sql(f"""
COPY {dataset_name}.{table_name}
COPY {self._sql_client.make_qualified_table_name(table['name'])}
FROM '{bucket_path}'
{file_type}
{dateformat}
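
The Redshift COPY job now builds its target with `make_qualified_table_name` instead of interpolating `dataset_name` and `table_name` by hand, so the table reference gets the same escaping as the rest of the generated SQL. A toy sketch of the difference (not the real Redshift or Psycopg2 client):

```python
# Toy sketch: why a qualified-name helper beats manual string concatenation.
class ToySqlClient:
    def __init__(self, dataset_name: str) -> None:
        self.dataset_name = dataset_name

    def escape_identifier(self, name: str) -> str:
        return '"' + name.replace('"', '""') + '"'

    def make_qualified_table_name(self, table_name: str) -> str:
        # both parts escaped, so mixed case or special characters cannot break the SQL
        return f"{self.escape_identifier(self.dataset_name)}.{self.escape_identifier(table_name)}"


client = ToySqlClient("state_test_1a2b")
table = "Event-Log"  # a name that needs quoting

# naive concatenation produces an invalid reference
print(f"COPY {client.dataset_name}.{table} FROM 's3://bucket/file'")
# the qualified helper keeps it valid
print(f"COPY {client.make_qualified_table_name(table)} FROM 's3://bucket/file'")
```
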
2 changes: 1 addition & 1 deletion dlt/destinations/sql_jobs.py
@@ -117,7 +117,7 @@ def _generate_insert_sql(
table_name = sql_client.make_qualified_table_name(table["name"])
columns = ", ".join(
map(
sql_client.capabilities.escape_identifier,
sql_client.escape_column_name,
get_columns_names_with_prop(table, "name"),
)
)
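
In `_generate_insert_sql` the column list is now produced by mapping `sql_client.escape_column_name` over the column names. A hedged sketch of that pattern in isolation (hypothetical function, mirroring the shape of the code above rather than the actual merge job):

```python
# Hypothetical sketch of building an INSERT column list with per-column escaping.
from typing import Callable, Sequence


def generate_insert_sql(
    qualified_table: str,
    column_names: Sequence[str],
    escape_column_name: Callable[[str], str],
    staging_table: str,
) -> str:
    columns = ", ".join(map(escape_column_name, column_names))
    return f"INSERT INTO {qualified_table} ({columns})\nSELECT {columns} FROM {staging_table};"


print(
    generate_insert_sql(
        '"dataset"."events"',
        ["id", "CamelInfo", "GEN_ERIC"],
        lambda n: f'"{n}"',
        '"dataset_staging"."events"',
    )
)
```
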
4 changes: 3 additions & 1 deletion dlt/pipeline/pipeline.py
@@ -1612,7 +1612,9 @@ def _bump_version_and_extract_state(
schema, reuse_exiting_package=True
)
data, doc = state_resource(state, load_id)
extract_.original_data = data
# keep the original data to be used in the metrics
if extract_.original_data is None:
extract_.original_data = data
# append pipeline state to package state
load_package_state_update = load_package_state_update or {}
load_package_state_update["pipeline_state"] = doc
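
`_bump_version_and_extract_state` now records `original_data` only when nothing has been captured yet, so metrics keep reporting what the user passed in rather than the internally generated state resource. The guard in isolation (toy object, not the actual `Extract` class):

```python
# Toy illustration of the "set once" guard used for metrics bookkeeping.
class ToyExtract:
    def __init__(self) -> None:
        self.original_data = None  # what the user originally passed in

    def record_original_data(self, data) -> None:
        # keep the first value only; later internal extracts (e.g. pipeline state)
        # must not overwrite what the metrics should report
        if self.original_data is None:
            self.original_data = data


extract_ = ToyExtract()
extract_.record_original_data(["a", "b", "c"])   # user data
extract_.record_original_data({"state": "..."})  # internal state resource, ignored
assert extract_.original_data == ["a", "b", "c"]
```
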
12 changes: 12 additions & 0 deletions tests/common/storages/utils.py
@@ -21,9 +21,12 @@
)
from dlt.common.storages import DataItemStorage, FileStorage
from dlt.common.storages.fsspec_filesystem import FileItem, FileItemDict
from dlt.common.storages.schema_storage import SchemaStorage
from dlt.common.typing import StrAny, TDataItems
from dlt.common.utils import uniq_id

from tests.common.utils import load_yml_case

TEST_SAMPLE_FILES = "tests/common/storages/samples"
MINIMALLY_EXPECTED_RELATIVE_PATHS = {
"csv/freshman_kgs.csv",
@@ -199,3 +202,12 @@ def assert_package_info(
# get dict
package_info.asdict()
return package_info


def prepare_eth_import_folder(storage: SchemaStorage) -> Schema:
eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9")
# remove processing hints before installing as import schema
# ethereum schema is a "dirty" schema with processing hints
eth = Schema.from_dict(eth_V9, remove_processing_hints=True)
storage._export_schema(eth, storage.config.import_schema_path)
return eth
35 changes: 31 additions & 4 deletions tests/load/pipeline/test_pipelines.py
@@ -13,7 +13,8 @@
from dlt.common.destination.reference import WithStagingDataset
from dlt.common.schema.exceptions import CannotCoerceColumnException
from dlt.common.schema.schema import Schema
from dlt.common.schema.typing import VERSION_TABLE_NAME
from dlt.common.schema.typing import PIPELINE_STATE_TABLE_NAME, VERSION_TABLE_NAME
from dlt.common.schema.utils import pipeline_state_table
from dlt.common.typing import TDataItem
from dlt.common.utils import uniq_id

@@ -137,10 +138,27 @@ def data_fun() -> Iterator[Any]:
destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True),
ids=lambda x: x.name,
)
def test_default_schema_name(destination_config: DestinationTestConfiguration) -> None:
@pytest.mark.parametrize("use_single_dataset", [True, False])
@pytest.mark.parametrize(
"naming_convention",
[
"duck_case",
"snake_case",
"sql_cs_v1",
],
)
def test_default_schema_name(
destination_config: DestinationTestConfiguration,
use_single_dataset: bool,
naming_convention: str,
) -> None:
os.environ["SCHEMA__NAMING"] = naming_convention
destination_config.setup()
dataset_name = "dataset_" + uniq_id()
data = ["a", "b", "c"]
data = [
{"id": idx, "CamelInfo": uniq_id(), "GEN_ERIC": alpha}
for idx, alpha in [(0, "A"), (0, "B"), (0, "C")]
]

p = dlt.pipeline(
"test_default_schema_name",
@@ -149,16 +167,25 @@ def test_default_schema_name(destination_config: DestinationTestConfiguration) -
staging=destination_config.staging,
dataset_name=dataset_name,
)
p.config.use_single_dataset = use_single_dataset
p.extract(data, table_name="test", schema=Schema("default"))
p.normalize()
info = p.load()
print(info)

# try to restore pipeline
r_p = dlt.attach("test_default_schema_name", TEST_STORAGE_ROOT)
schema = r_p.default_schema
assert schema.name == "default"

assert_table(p, "test", data, info=info)
# check if dlt tables have exactly the required schemas
# TODO: uncomment to check dlt tables schemas
# assert (
# r_p.default_schema.tables[PIPELINE_STATE_TABLE_NAME]["columns"]
# == pipeline_state_table()["columns"]
# )

# assert_table(p, "test", data, info=info)


@pytest.mark.parametrize(
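
The reworked `test_default_schema_name` drives the naming convention through the `SCHEMA__NAMING` environment variable and loads rows with mixed-case keys, so identifier handling varies per convention. A minimal sketch of the same knob outside the test harness (assumes a local duckdb destination and that the `sql_cs_v1` convention is available in the installed dlt version; the resulting column names depend on the convention chosen):

```python
# Minimal sketch: the naming convention is read from config/env when the schema is created.
import os

import dlt

os.environ["SCHEMA__NAMING"] = "sql_cs_v1"  # one of the conventions parametrized above

pipeline = dlt.pipeline("naming_demo", destination="duckdb", dataset_name="naming_demo_data")
pipeline.run([{"id": 0, "CamelInfo": "A", "GEN_ERIC": "x"}], table_name="test")

# with a case-sensitive convention the original casing is preserved in the stored schema
print(list(pipeline.default_schema.tables["test"]["columns"].keys()))
```
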
31 changes: 23 additions & 8 deletions tests/load/pipeline/test_restore_state.py
@@ -6,6 +6,7 @@

import dlt
from dlt.common import pendulum
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.schema.schema import Schema, utils
from dlt.common.schema.utils import normalize_table_identifiers
from dlt.common.utils import uniq_id
@@ -199,7 +200,6 @@ def test_silently_skip_on_invalid_credentials(
[
"tests.common.cases.normalizers.title_case",
"snake_case",
"tests.common.cases.normalizers.sql_upper",
],
)
def test_get_schemas_from_destination(
@@ -213,6 +213,7 @@ def test_get_schemas_from_destination(
dataset_name = "state_test_" + uniq_id()

p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name)
assert_naming_to_caps(destination_config.destination, p.destination.capabilities())
p.config.use_single_dataset = use_single_dataset

def _make_dn_name(schema_name: str) -> str:
@@ -287,7 +288,10 @@ def _make_dn_name(schema_name: str) -> str:
@pytest.mark.parametrize(
"destination_config",
destinations_configs(
default_sql_configs=True, default_vector_configs=True, all_buckets_filesystem_configs=True
default_sql_configs=True,
all_staging_configs=True,
default_vector_configs=True,
all_buckets_filesystem_configs=True,
),
ids=lambda x: x.name,
)
@@ -296,7 +300,6 @@ def _make_dn_name(schema_name: str) -> str:
[
"tests.common.cases.normalizers.title_case",
"snake_case",
"tests.common.cases.normalizers.sql_upper",
],
)
def test_restore_state_pipeline(
@@ -308,6 +311,7 @@ def test_restore_state_pipeline(
pipeline_name = "pipe_" + uniq_id()
dataset_name = "state_test_" + uniq_id()
p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name)
assert_naming_to_caps(destination_config.destination, p.destination.capabilities())

def some_data_gen(param: str) -> Any:
dlt.current.source_state()[param] = param
@@ -735,11 +739,9 @@ def some_data(param: str) -> Any:


def prepare_import_folder(p: Pipeline) -> None:
os.makedirs(p._schema_storage.config.import_schema_path, exist_ok=True)
shutil.copy(
common_yml_case_path("schemas/eth/ethereum_schema_v5"),
os.path.join(p._schema_storage.config.import_schema_path, "ethereum.schema.yaml"),
)
from tests.common.storages.utils import prepare_eth_import_folder

prepare_eth_import_folder(p._schema_storage)


def set_naming_env(destination: str, naming_convention: str) -> None:
@@ -752,3 +754,16 @@ def set_naming_env(destination: str, naming_convention: str) -> None:
else:
naming_convention = "dlt.destinations.impl.weaviate.ci_naming"
os.environ["SCHEMA__NAMING"] = naming_convention


def assert_naming_to_caps(destination: str, caps: DestinationCapabilitiesContext) -> None:
naming = Schema("test").naming
if (
not caps.has_case_sensitive_identifiers
and caps.casefold_identifier is not str
and naming.is_case_sensitive
):
pytest.skip(
f"Skipping for case insensitive destination {destination} with case folding because"
f" naming {naming.name()} is case sensitive"
)
55 changes: 25 additions & 30 deletions tests/load/test_insert_job_client.py
@@ -14,7 +14,7 @@
from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage
from tests.load.pipeline.utils import destinations_configs

DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "mssql", "synapse"]
DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "mssql", "synapse", "motherduck"]


@pytest.fixture
@@ -176,7 +176,6 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage
ids=lambda x: x.name,
)
def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -> None:
mocked_caps = client.sql_client.__class__.capabilities
writer_type = client.capabilities.insert_values_writer_type
insert_sql = prepare_insert_statement(10, writer_type)

@@ -185,10 +184,10 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -
elif writer_type == "select_union":
pre, post, sep = ("SELECT ", "", " UNION ALL\n")

# caps are instance and are attr of sql client instance so it is safe to mock them
client.sql_client.capabilities.max_query_length = 2
# this guarantees that we execute inserts line by line
with patch.object(mocked_caps, "max_query_length", 2), patch.object(
client.sql_client, "execute_fragments"
) as mocked_fragments:
with patch.object(client.sql_client, "execute_fragments") as mocked_fragments:
user_table_name = prepare_table(client)
expect_load_file(client, file_storage, insert_sql, user_table_name)
# print(mocked_fragments.mock_calls)
@@ -211,19 +210,17 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -

# set query length so it reads data until separator ("," or " UNION ALL") (followed by \n)
query_length = (idx - start_idx - 1) * 2
with patch.object(mocked_caps, "max_query_length", query_length), patch.object(
client.sql_client, "execute_fragments"
) as mocked_fragments:
client.sql_client.capabilities.max_query_length = query_length
with patch.object(client.sql_client, "execute_fragments") as mocked_fragments:
user_table_name = prepare_table(client)
expect_load_file(client, file_storage, insert_sql, user_table_name)
# split in 2 on ','
assert mocked_fragments.call_count == 2

# so it reads until "\n"
query_length = (idx - start_idx) * 2
with patch.object(mocked_caps, "max_query_length", query_length), patch.object(
client.sql_client, "execute_fragments"
) as mocked_fragments:
client.sql_client.capabilities.max_query_length = query_length
with patch.object(client.sql_client, "execute_fragments") as mocked_fragments:
user_table_name = prepare_table(client)
expect_load_file(client, file_storage, insert_sql, user_table_name)
# split in 2 on separator ("," or " UNION ALL")
@@ -235,9 +232,8 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -
elif writer_type == "select_union":
offset = 1
query_length = (len(insert_sql) - start_idx - offset) * 2
with patch.object(mocked_caps, "max_query_length", query_length), patch.object(
client.sql_client, "execute_fragments"
) as mocked_fragments:
client.sql_client.capabilities.max_query_length = query_length
with patch.object(client.sql_client, "execute_fragments") as mocked_fragments:
user_table_name = prepare_table(client)
expect_load_file(client, file_storage, insert_sql, user_table_name)
# split in 2 on ','
@@ -251,22 +247,21 @@ def assert_load_with_max_query(
max_query_length: int,
) -> None:
# load and check for real
mocked_caps = client.sql_client.__class__.capabilities
with patch.object(mocked_caps, "max_query_length", max_query_length):
user_table_name = prepare_table(client)
insert_sql = prepare_insert_statement(
insert_lines, client.capabilities.insert_values_writer_type
)
expect_load_file(client, file_storage, insert_sql, user_table_name)
canonical_name = client.sql_client.make_qualified_table_name(user_table_name)
rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0]
assert rows_count == insert_lines
# get all uniq ids in order
rows = client.sql_client.execute_sql(
f"SELECT _dlt_id FROM {canonical_name} ORDER BY timestamp ASC;"
)
v_ids = list(map(lambda i: i[0], rows))
assert list(map(str, range(0, insert_lines))) == v_ids
client.sql_client.capabilities.max_query_length = max_query_length
user_table_name = prepare_table(client)
insert_sql = prepare_insert_statement(
insert_lines, client.capabilities.insert_values_writer_type
)
expect_load_file(client, file_storage, insert_sql, user_table_name)
canonical_name = client.sql_client.make_qualified_table_name(user_table_name)
rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0]
assert rows_count == insert_lines
# get all uniq ids in order
rows = client.sql_client.execute_sql(
f"SELECT _dlt_id FROM {canonical_name} ORDER BY timestamp ASC;"
)
v_ids = list(map(lambda i: i[0], rows))
assert list(map(str, range(0, insert_lines))) == v_ids
client.sql_client.execute_sql(f"DELETE FROM {canonical_name}")


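
The query-split tests stop patching `max_query_length` on the client class's capabilities and just assign to the instance-held capabilities object instead. The difference, reduced to a toy example (not the actual `InsertValuesJobClient`):

```python
# Toy illustration: per-instance capabilities can be tweaked with plain assignment,
# no unittest.mock.patch.object on a shared class attribute required.
from unittest.mock import patch


class Capabilities:
    max_query_length = 1024 * 1024


class OldStyleClient:
    # capabilities shared at class level; patching affects every instance
    capabilities = Capabilities()


class NewStyleClient:
    def __init__(self) -> None:
        # capabilities held per instance; a test can just assign to them
        self.capabilities = Capabilities()


# old style: patch the shared class-level attribute for the duration of the test
with patch.object(OldStyleClient.capabilities, "max_query_length", 2):
    assert OldStyleClient().capabilities.max_query_length == 2

# new style in the tests above: mutate the instance copy directly
client = NewStyleClient()
client.capabilities.max_query_length = 2
assert client.capabilities.max_query_length == 2
assert NewStyleClient().capabilities.max_query_length == 1024 * 1024  # others untouched
```
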
1 change: 1 addition & 0 deletions tests/pipeline/test_dlt_versions.py
@@ -122,6 +122,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None:
)
assert github_schema["engine_version"] == 9
assert "schema_version_hash" in github_schema["tables"][LOADS_TABLE_NAME]["columns"]
# print(github_schema["tables"][PIPELINE_STATE_TABLE_NAME])
# load state
state_dict = json.loads(
test_storage.load(f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/state.json")