Skip to content

Commit

Permalink
fixes table builder tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Jun 11, 2024
1 parent b1e2b09 commit 210be70
Show file tree
Hide file tree
Showing 20 changed files with 230 additions and 39 deletions.
10 changes: 5 additions & 5 deletions dlt/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,12 @@ def get_generic_type_argument_from_instance(
Type[Any]: type argument or Any if not known
"""
orig_param_type = Any
# instance of class deriving from generic
if bases_ := get_original_bases(instance):
cls_ = bases_[0]
else:
if cls_ := getattr(instance, "__orig_class__", None):
# instance of generic class
cls_ = getattr(instance, "__orig_class__", None)
pass
elif bases_ := get_original_bases(instance):
# instance of class deriving from generic
cls_ = bases_[0]
if cls_:
orig_param_type = get_args(cls_)[0]
if orig_param_type is Any and sample_value is not None:
Expand Down
4 changes: 2 additions & 2 deletions tests/.dlt/config.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[runtime]
sentry_dsn="https://[email protected]/4504819859914752"
# [runtime]
# sentry_dsn="https://[email protected]/4504819859914752"

[tests]
bucket_url_gs="gs://ci-test-bucket"
Expand Down
Empty file.
14 changes: 14 additions & 0 deletions tests/common/cases/normalizers/title_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dlt.common.normalizers.naming.direct import NamingConvention as DirectNamingConvention


class NamingConvention(DirectNamingConvention):
    """Test case sensitive naming that capitalizes first and last letter and leaves the rest intact"""

    PATH_SEPARATOR = "__"

    def normalize_identifier(self, identifier: str) -> str:
        """Normalize with the direct convention, then upper-case the first and last characters.

        The literal "_dlt" prefix is returned unchanged so dlt system identifiers keep
        their expected casing.
        """
        # keep prefix
        if identifier == "_dlt":
            return "_dlt"
        identifier = super().normalize_identifier(identifier)
        # guard short identifiers: the slice expression below would duplicate a
        # single character ("a" -> "AA") and raise IndexError on an empty string
        if len(identifier) < 2:
            return identifier.upper()
        return identifier[0].upper() + identifier[1:-1] + identifier[-1].upper()
17 changes: 11 additions & 6 deletions tests/load/bigquery/test_bigquery_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
from dlt.common.schema import Schema
from dlt.common.utils import custom_environ
from dlt.common.utils import uniq_id

from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate
from dlt.destinations.impl.bigquery import capabilities
from dlt.destinations.impl.bigquery.bigquery import BigQueryClient
from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter
from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration

from dlt.extract import DltResource

from tests.load.pipeline.utils import (
destinations_configs,
DestinationTestConfiguration,
Expand Down Expand Up @@ -63,6 +67,7 @@ def gcp_client(empty_schema: Schema) -> BigQueryClient:
BigQueryClientConfiguration(credentials=creds)._bind_dataset_name(
dataset_name=f"test_{uniq_id()}"
),
capabilities(),
)


Expand All @@ -89,9 +94,9 @@ def test_create_table(gcp_client: BigQueryClient) -> None:
sqlfluff.parse(sql, dialect="bigquery")
assert sql.startswith("CREATE TABLE")
assert "event_test_table" in sql
assert "`col1` INTEGER NOT NULL" in sql
assert "`col1` INT64 NOT NULL" in sql
assert "`col2` FLOAT64 NOT NULL" in sql
assert "`col3` BOOLEAN NOT NULL" in sql
assert "`col3` BOOL NOT NULL" in sql
assert "`col4` TIMESTAMP NOT NULL" in sql
assert "`col5` STRING " in sql
assert "`col6` NUMERIC(38,9) NOT NULL" in sql
Expand All @@ -100,7 +105,7 @@ def test_create_table(gcp_client: BigQueryClient) -> None:
assert "`col9` JSON NOT NULL" in sql
assert "`col10` DATE" in sql
assert "`col11` TIME" in sql
assert "`col1_precision` INTEGER NOT NULL" in sql
assert "`col1_precision` INT64 NOT NULL" in sql
assert "`col4_precision` TIMESTAMP NOT NULL" in sql
assert "`col5_precision` STRING(25) " in sql
assert "`col6_precision` NUMERIC(6,2) NOT NULL" in sql
Expand All @@ -119,9 +124,9 @@ def test_alter_table(gcp_client: BigQueryClient) -> None:
assert sql.startswith("ALTER TABLE")
assert sql.count("ALTER TABLE") == 1
assert "event_test_table" in sql
assert "ADD COLUMN `col1` INTEGER NOT NULL" in sql
assert "ADD COLUMN `col1` INT64 NOT NULL" in sql
assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql
assert "ADD COLUMN `col3` BOOLEAN NOT NULL" in sql
assert "ADD COLUMN `col3` BOOL NOT NULL" in sql
assert "ADD COLUMN `col4` TIMESTAMP NOT NULL" in sql
assert "ADD COLUMN `col5` STRING" in sql
assert "ADD COLUMN `col6` NUMERIC(38,9) NOT NULL" in sql
Expand All @@ -130,7 +135,7 @@ def test_alter_table(gcp_client: BigQueryClient) -> None:
assert "ADD COLUMN `col9` JSON NOT NULL" in sql
assert "ADD COLUMN `col10` DATE" in sql
assert "ADD COLUMN `col11` TIME" in sql
assert "ADD COLUMN `col1_precision` INTEGER NOT NULL" in sql
assert "ADD COLUMN `col1_precision` INT64 NOT NULL" in sql
assert "ADD COLUMN `col4_precision` TIMESTAMP NOT NULL" in sql
assert "ADD COLUMN `col5_precision` STRING(25)" in sql
assert "ADD COLUMN `col6_precision` NUMERIC(6,2) NOT NULL" in sql
Expand Down
3 changes: 3 additions & 0 deletions tests/load/clickhouse/test_clickhouse_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from dlt.common.schema import Schema
from dlt.common.utils import custom_environ, digest128
from dlt.common.utils import uniq_id

from dlt.destinations.impl.clickhouse import capabilities
from dlt.destinations.impl.clickhouse.clickhouse import ClickHouseClient
from dlt.destinations.impl.clickhouse.configuration import (
ClickHouseCredentials,
Expand All @@ -21,6 +23,7 @@ def clickhouse_client(empty_schema: Schema) -> ClickHouseClient:
return ClickHouseClient(
empty_schema,
ClickHouseClientConfiguration(credentials=creds)._bind_dataset_name(f"test_{uniq_id()}"),
capabilities(),
)


Expand Down
3 changes: 3 additions & 0 deletions tests/load/dremio/test_dremio_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest

from dlt.common.schema import TColumnSchema, Schema

from dlt.destinations.impl.dremio import capabilities
from dlt.destinations.impl.dremio.configuration import DremioClientConfiguration, DremioCredentials
from dlt.destinations.impl.dremio.dremio import DremioClient
from tests.load.utils import empty_schema
Expand All @@ -15,6 +17,7 @@ def dremio_client(empty_schema: Schema) -> DremioClient:
DremioClientConfiguration(credentials=creds)._bind_dataset_name(
dataset_name="test_dataset"
),
capabilities(),
)


Expand Down
3 changes: 3 additions & 0 deletions tests/load/duckdb/test_duckdb_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dlt.common.utils import uniq_id
from dlt.common.schema import Schema

from dlt.destinations.impl.duckdb import capabilities
from dlt.destinations.impl.duckdb.duck import DuckDbClient
from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration

Expand All @@ -25,6 +26,7 @@ def client(empty_schema: Schema) -> DuckDbClient:
return DuckDbClient(
empty_schema,
DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_" + uniq_id()),
capabilities(),
)


Expand Down Expand Up @@ -122,6 +124,7 @@ def test_create_table_with_hints(client: DuckDbClient) -> None:
DuckDbClientConfiguration(create_indexes=True)._bind_dataset_name(
dataset_name="test_" + uniq_id()
),
capabilities(),
)
sql = client._get_table_update_sql("event_test_table", mod_update, False)[0]
sqlfluff.parse(sql)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,46 @@
import os
import pyodbc
import pytest

from dlt.common.configuration import resolve_configuration, ConfigFieldMissingException
from dlt.common.exceptions import SystemConfigurationException
from dlt.common.schema import Schema

from dlt.destinations.impl.mssql.configuration import MsSqlCredentials
from dlt.destinations import mssql
from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, MsSqlClientConfiguration

# mark all tests as essential, do not remove
pytestmark = pytest.mark.essential


def test_mssql_factory() -> None:
    """mssql() factory must propagate defaults, explicit arguments and env-provided config
    into both the client configuration and the destination capabilities."""
    schema = Schema("schema")

    # defaults
    dest = mssql()
    client = dest.client(schema, MsSqlClientConfiguration()._bind_dataset_name("dataset"))
    assert client.config.create_indexes is False
    assert client.config.has_case_sensitive_identifiers is False
    assert client.capabilities.has_case_sensitive_identifiers is False
    assert client.capabilities.casefold_identifier is str

    # set args explicitly
    dest = mssql(has_case_sensitive_identifiers=True, create_indexes=True)
    client = dest.client(schema, MsSqlClientConfiguration()._bind_dataset_name("dataset"))
    assert client.config.create_indexes is True
    assert client.config.has_case_sensitive_identifiers is True
    assert client.capabilities.has_case_sensitive_identifiers is True
    assert client.capabilities.casefold_identifier is str

    # set args via config; remove the env vars afterwards so they do not leak
    # into tests executed later in the same process
    os.environ["DESTINATION__CREATE_INDEXES"] = "True"
    os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"] = "True"
    try:
        dest = mssql()
        client = dest.client(schema, MsSqlClientConfiguration()._bind_dataset_name("dataset"))
        assert client.config.create_indexes is True
        assert client.config.has_case_sensitive_identifiers is True
        assert client.capabilities.has_case_sensitive_identifiers is True
        assert client.capabilities.casefold_identifier is str
    finally:
        del os.environ["DESTINATION__CREATE_INDEXES"]
        del os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"]


def test_mssql_credentials_defaults() -> None:
creds = MsSqlCredentials()
assert creds.port == 1433
Expand Down
12 changes: 7 additions & 5 deletions tests/load/mssql/test_mssql_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

pytest.importorskip("dlt.destinations.impl.mssql.mssql", reason="MSSQL ODBC driver not installed")

from dlt.destinations.impl.mssql.mssql import MsSqlClient
from dlt.destinations.impl.mssql import capabilities
from dlt.destinations.impl.mssql.mssql import MsSqlJobClient
from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration, MsSqlCredentials

from tests.load.utils import TABLE_UPDATE, empty_schema
Expand All @@ -16,17 +17,18 @@


@pytest.fixture
def client(empty_schema: Schema) -> MsSqlClient:
def client(empty_schema: Schema) -> MsSqlJobClient:
# return client without opening connection
return MsSqlClient(
return MsSqlJobClient(
empty_schema,
MsSqlClientConfiguration(credentials=MsSqlCredentials())._bind_dataset_name(
dataset_name="test_" + uniq_id()
),
capabilities(),
)


def test_create_table(client: MsSqlClient) -> None:
def test_create_table(client: MsSqlJobClient) -> None:
# non existing table
sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0]
sqlfluff.parse(sql, dialect="tsql")
Expand All @@ -50,7 +52,7 @@ def test_create_table(client: MsSqlClient) -> None:
assert '"col11_precision" time(3) NOT NULL' in sql


def test_alter_table(client: MsSqlClient) -> None:
def test_alter_table(client: MsSqlJobClient) -> None:
# existing table has no columns
sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)[0]
sqlfluff.parse(sql, dialect="tsql")
Expand Down
42 changes: 38 additions & 4 deletions tests/load/postgres/test_postgres_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

from dlt.common.exceptions import TerminalValueError
from dlt.common.utils import uniq_id
from dlt.common.schema import Schema
from dlt.common.schema import Schema, utils
from dlt.common.destination import Destination

from dlt.destinations.impl.postgres import capabilities
from dlt.destinations.impl.postgres.postgres import PostgresClient
from dlt.destinations.impl.postgres.configuration import (
PostgresClientConfiguration,
Expand All @@ -25,12 +27,26 @@

@pytest.fixture
def client(empty_schema: Schema) -> PostgresClient:
    # default client with the schema's default (case insensitive) naming convention
    return create_client(empty_schema)


@pytest.fixture
def cs_client(empty_schema: Schema) -> PostgresClient:
    """Client whose schema uses the case sensitive `title_case` test naming convention."""
    # change normalizer to case sensitive
    empty_schema._normalizers_config["names"] = "tests.common.cases.normalizers.title_case"
    empty_schema.update_normalizers()
    return create_client(empty_schema)


def create_client(empty_schema: Schema) -> PostgresClient:
    """Build a PostgresClient over `empty_schema` with a unique dataset name.

    Capabilities are passed through `Destination.adjust_capabilities` with the
    schema's naming convention, so case sensitivity follows the configured normalizer.
    """
    # return client without opening connection
    config = PostgresClientConfiguration(credentials=PostgresCredentials())._bind_dataset_name(
        dataset_name="test_" + uniq_id()
    )
    return PostgresClient(
        empty_schema,
        config,
        Destination.adjust_capabilities(capabilities(), config, empty_schema.naming),
    )


Expand Down Expand Up @@ -125,7 +141,25 @@ def test_create_table_with_hints(client: PostgresClient) -> None:
create_indexes=False,
credentials=PostgresCredentials(),
)._bind_dataset_name(dataset_name="test_" + uniq_id()),
capabilities(),
)
sql = client._get_table_update_sql("event_test_table", mod_update, False)[0]
sqlfluff.parse(sql, dialect="postgres")
assert '"col2" double precision NOT NULL' in sql


def test_create_table_case_sensitive(cs_client: PostgresClient) -> None:
    """Generated CREATE TABLE sql must use identifiers produced by the case sensitive convention."""
    new_table = utils.new_table("event_test_table", columns=deepcopy(TABLE_UPDATE))
    cs_client.schema.update_table(new_table)
    columns = list(cs_client.schema.get_table_columns("Event_test_tablE").values())
    sql = cs_client._get_table_update_sql("Event_test_tablE", columns, False)[0]
    sqlfluff.parse(sql, dialect="postgres")
    # everything capitalized
    assert cs_client.sql_client.fully_qualified_dataset_name(escape=False)[0] == "T"  # Test
    # every line starts with "Col"
    column_lines = sql.split("\n")[1:]
    assert all(line.startswith('"Col') for line in column_lines)
35 changes: 34 additions & 1 deletion tests/load/redshift/test_redshift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@

from dlt.common import json, pendulum
from dlt.common.configuration.resolve import resolve_configuration
from dlt.common.schema.schema import Schema
from dlt.common.schema.typing import VERSION_TABLE_NAME
from dlt.common.storages import FileStorage
from dlt.common.storages.schema_storage import SchemaStorage
from dlt.common.utils import uniq_id

from dlt.destinations.exceptions import DatabaseTerminalException
from dlt.destinations.impl.redshift.configuration import RedshiftCredentials
from dlt.destinations import redshift
from dlt.destinations.impl.redshift.configuration import (
RedshiftCredentials,
RedshiftClientConfiguration,
)
from dlt.destinations.impl.redshift.redshift import RedshiftClient, psycopg2

from tests.common.utils import COMMON_TEST_CASES_PATH
Expand Down Expand Up @@ -42,6 +47,34 @@ def test_postgres_and_redshift_credentials_defaults() -> None:
assert red_cred.port == 5439


def test_redshift_factory() -> None:
    """redshift() factory must propagate defaults, explicit arguments and env-provided config
    into both the client configuration and the destination capabilities."""
    schema = Schema("schema")

    # defaults
    dest = redshift()
    client = dest.client(schema, RedshiftClientConfiguration()._bind_dataset_name("dataset"))
    assert client.config.staging_iam_role is None
    assert client.config.has_case_sensitive_identifiers is False
    assert client.capabilities.has_case_sensitive_identifiers is False
    assert client.capabilities.casefold_identifier is str.lower

    # set args explicitly
    dest = redshift(has_case_sensitive_identifiers=True, staging_iam_role="LOADER")
    client = dest.client(schema, RedshiftClientConfiguration()._bind_dataset_name("dataset"))
    assert client.config.staging_iam_role == "LOADER"
    assert client.config.has_case_sensitive_identifiers is True
    assert client.capabilities.has_case_sensitive_identifiers is True
    assert client.capabilities.casefold_identifier is str

    # set args via config; remove the env vars afterwards so they do not leak
    # into tests executed later in the same process
    os.environ["DESTINATION__STAGING_IAM_ROLE"] = "LOADER"
    os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"] = "True"
    try:
        dest = redshift()
        client = dest.client(schema, RedshiftClientConfiguration()._bind_dataset_name("dataset"))
        assert client.config.staging_iam_role == "LOADER"
        assert client.config.has_case_sensitive_identifiers is True
        assert client.capabilities.has_case_sensitive_identifiers is True
        assert client.capabilities.casefold_identifier is str
    finally:
        del os.environ["DESTINATION__STAGING_IAM_ROLE"]
        del os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"]


@skipifpypy
def test_text_too_long(client: RedshiftClient, file_storage: FileStorage) -> None:
caps = client.capabilities
Expand Down
2 changes: 2 additions & 0 deletions tests/load/redshift/test_redshift_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dlt.common.schema import Schema
from dlt.common.configuration import resolve_configuration

from dlt.destinations.impl.redshift import capabilities
from dlt.destinations.impl.redshift.redshift import RedshiftClient
from dlt.destinations.impl.redshift.configuration import (
RedshiftClientConfiguration,
Expand All @@ -26,6 +27,7 @@ def client(empty_schema: Schema) -> RedshiftClient:
RedshiftClientConfiguration(credentials=RedshiftCredentials())._bind_dataset_name(
dataset_name="test_" + uniq_id()
),
capabilities(),
)


Expand Down
Loading

0 comments on commit 210be70

Please sign in to comment.