diff --git a/.github/workflows/test_destination_lancedb.yml b/.github/workflows/test_destination_lancedb.yml new file mode 100644 index 0000000000..02b5ef66eb --- /dev/null +++ b/.github/workflows/test_destination_lancedb.yml @@ -0,0 +1,81 @@ +name: dest | lancedb + +on: + pull_request: + branches: + - master + - devel + workflow_dispatch: + schedule: + - cron: '0 2 * * *' + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} + + RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 + RUNTIME__LOG_LEVEL: ERROR + RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} + + ACTIVE_DESTINATIONS: "[\"lancedb\"]" + ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" + +jobs: + get_docs_changes: + name: docs changes + uses: ./.github/workflows/get_docs_changes.yml + if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} + + run_loader: + name: dest | lancedb tests + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' + defaults: + run: + shell: bash + runs-on: "ubuntu-latest" + + steps: + - name: Check out + uses: actions/checkout@master + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.11.x" + + - name: Install Poetry + uses: snok/install-poetry@v1.3.2 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-gcp + + - name: create secrets.toml + run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + + - name: Install dependencies + run: poetry install --no-interaction 
-E lancedb -E parquet --with sentry-sdk --with pipeline + + - name: Install embedding provider dependencies + run: poetry run pip install openai + + - run: | + poetry run pytest tests/load -m "essential" + name: Run essential tests Linux + if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} + + - run: | + poetry run pytest tests/load + name: Run all tests Linux + if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_doc_snippets.yml b/.github/workflows/test_doc_snippets.yml index cb6417a4ab..6d4e6dda53 100644 --- a/.github/workflows/test_doc_snippets.yml +++ b/.github/workflows/test_doc_snippets.yml @@ -32,6 +32,26 @@ jobs: # Do not run on forks, unless allowed, secrets are used here if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} + # Service containers to run with `container-job` + services: + # Label used to access the service container + postgres: + # Docker Hub image + image: postgres + # Provide the password for postgres + env: + POSTGRES_DB: dlt_data + POSTGRES_USER: loader + POSTGRES_PASSWORD: loader + ports: + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: - name: Check out @@ -61,7 +81,7 @@ jobs: - name: Install dependencies # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction -E duckdb -E weaviate -E parquet -E qdrant -E bigquery -E postgres --with docs,sentry-sdk --without airflow + run: poetry install --no-interaction -E duckdb -E weaviate -E parquet -E qdrant -E bigquery -E postgres -E lancedb --with docs,sentry-sdk --without airflow - name: create secrets.toml for examples run: pwd && echo "$DLT_SECRETS_TOML" > docs/examples/.dlt/secrets.toml diff 
--git a/Makefile b/Makefile index fd0920d188..15fb895a9f 100644 --- a/Makefile +++ b/Makefile @@ -67,9 +67,9 @@ lint-and-test-snippets: cd docs/website/docs && poetry run pytest --ignore=node_modules lint-and-test-examples: - poetry run mypy --config-file mypy.ini docs/examples - poetry run flake8 --max-line-length=200 docs/examples cd docs/tools && poetry run python prepare_examples_tests.py + poetry run flake8 --max-line-length=200 docs/examples + poetry run mypy --config-file mypy.ini docs/examples cd docs/examples && poetry run pytest diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index ed85aae8ba..dcb78683fb 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -17,7 +17,9 @@ class RunConfiguration(BaseConfiguration): dlthub_telemetry: bool = True # enable or disable dlthub telemetry dlthub_telemetry_endpoint: Optional[str] = "https://telemetry.scalevector.ai" dlthub_telemetry_segment_write_key: Optional[str] = None - log_format: str = "{asctime}|[{levelname:<21}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}" + log_format: str = ( + "{asctime}|[{levelname}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}" + ) log_level: str = "WARNING" request_timeout: float = 60 """Timeout for http requests""" diff --git a/dlt/common/data_writers/__init__.py b/dlt/common/data_writers/__init__.py index 97451d8be7..945e74a37b 100644 --- a/dlt/common/data_writers/__init__.py +++ b/dlt/common/data_writers/__init__.py @@ -3,6 +3,7 @@ DataWriterMetrics, TDataItemFormat, FileWriterSpec, + create_import_spec, resolve_best_writer_spec, get_best_writer_spec, is_native_writer, @@ -11,12 +12,13 @@ from dlt.common.data_writers.escape import ( escape_redshift_literal, escape_redshift_identifier, - escape_bigquery_identifier, + escape_hive_identifier, ) __all__ = [ "DataWriter", "FileWriterSpec", + 
"create_import_spec", "resolve_best_writer_spec", "get_best_writer_spec", "is_native_writer", @@ -26,5 +28,5 @@ "new_file_id", "escape_redshift_literal", "escape_redshift_identifier", - "escape_bigquery_identifier", + "escape_hive_identifier", ] diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index bd32c68c49..8077007edb 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -1,11 +1,13 @@ import gzip import time -from typing import ClassVar, List, IO, Any, Optional, Type, Generic +import contextlib +from typing import ClassVar, Iterator, List, IO, Any, Optional, Type, Generic from dlt.common.typing import TDataItem, TDataItems from dlt.common.data_writers.exceptions import ( BufferedDataWriterClosed, DestinationCapabilitiesRequired, + FileImportNotFound, InvalidFileNameTemplateException, ) from dlt.common.data_writers.writers import TWriter, DataWriter, DataWriterMetrics, FileWriterSpec @@ -138,18 +140,31 @@ def write_empty_file(self, columns: TTableSchemaColumns) -> DataWriterMetrics: self._last_modified = time.time() return self._rotate_file(allow_empty_file=True) - def import_file(self, file_path: str, metrics: DataWriterMetrics) -> DataWriterMetrics: + def import_file( + self, file_path: str, metrics: DataWriterMetrics, with_extension: str = None + ) -> DataWriterMetrics: """Import a file from `file_path` into items storage under a new file name. Does not check the imported file format. Uses counts from `metrics` as a base. Logically closes the imported file The preferred import method is a hard link to avoid copying the data. If current filesystem does not support it, a regular copy is used. + + Alternative extension may be provided via `with_extension` so various file formats may be imported into the same folder. """ # TODO: we should separate file storage from other storages. 
this creates circular deps from dlt.common.storages import FileStorage - self._rotate_file() - FileStorage.link_hard_with_fallback(file_path, self._file_name) + # import file with alternative extension + spec = self.writer_spec + if with_extension: + spec = self.writer_spec._replace(file_extension=with_extension) + with self.alternative_spec(spec): + self._rotate_file() + try: + FileStorage.link_hard_with_fallback(file_path, self._file_name) + except FileNotFoundError as f_ex: + raise FileImportNotFound(file_path, self._file_name) from f_ex + self._last_modified = time.time() metrics = metrics._replace( file_path=self._file_name, @@ -176,6 +191,16 @@ def close(self, skip_flush: bool = False) -> None: def closed(self) -> bool: return self._closed + @contextlib.contextmanager + def alternative_spec(self, spec: FileWriterSpec) -> Iterator[FileWriterSpec]: + """Temporarily changes the writer spec ie. for the moment file is rotated""" + old_spec = self.writer_spec + try: + self.writer_spec = spec + yield spec + finally: + self.writer_spec = old_spec + def __enter__(self) -> "BufferedDataWriter[TWriter]": return self diff --git a/dlt/common/data_writers/configuration.py b/dlt/common/data_writers/configuration.py new file mode 100644 index 0000000000..a837cb47b0 --- /dev/null +++ b/dlt/common/data_writers/configuration.py @@ -0,0 +1,31 @@ +from typing import ClassVar, Literal, Optional +from dlt.common.configuration import configspec, known_sections +from dlt.common.configuration.specs import BaseConfiguration + +CsvQuoting = Literal["quote_all", "quote_needed"] + + +@configspec +class CsvFormatConfiguration(BaseConfiguration): + delimiter: str = "," + include_header: bool = True + quoting: CsvQuoting = "quote_needed" + + # read options + on_error_continue: bool = False + encoding: str = "utf-8" + + __section__: ClassVar[str] = known_sections.DATA_WRITER + + +@configspec +class ParquetFormatConfiguration(BaseConfiguration): + flavor: Optional[str] = None # could be ie. 
"spark" + version: Optional[str] = "2.4" + data_page_size: Optional[int] = None + timestamp_timezone: str = "UTC" + row_group_size: Optional[int] = None + coerce_timestamps: Optional[Literal["s", "ms", "us", "ns"]] = None + allow_truncated_timestamps: bool = False + + __section__: ClassVar[str] = known_sections.DATA_WRITER diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 580b057716..06c8d7a95a 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -124,7 +124,7 @@ def escape_redshift_identifier(v: str) -> str: escape_dremio_identifier = escape_postgres_identifier -def escape_bigquery_identifier(v: str) -> str: +def escape_hive_identifier(v: str) -> str: # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical return "`" + v.replace("\\", "\\\\").replace("`", "\\`") + "`" @@ -132,10 +132,10 @@ def escape_bigquery_identifier(v: str) -> str: def escape_snowflake_identifier(v: str) -> str: # Snowcase uppercase all identifiers unless quoted. 
Match this here so queries on information schema work without issue # See also https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers - return escape_postgres_identifier(v.upper()) + return escape_postgres_identifier(v) -escape_databricks_identifier = escape_bigquery_identifier +escape_databricks_identifier = escape_hive_identifier DATABRICKS_ESCAPE_DICT = {"'": "\\'", "\\": "\\\\", "\n": "\\n", "\r": "\\r"} diff --git a/dlt/common/data_writers/exceptions.py b/dlt/common/data_writers/exceptions.py index 1d5c58f787..3b11ed70fc 100644 --- a/dlt/common/data_writers/exceptions.py +++ b/dlt/common/data_writers/exceptions.py @@ -22,6 +22,16 @@ def __init__(self, file_name: str): super().__init__(f"Writer with recent file name {file_name} is already closed") +class FileImportNotFound(DataWriterException, FileNotFoundError): + def __init__(self, import_file_path: str, local_file_path: str) -> None: + self.import_file_path = import_file_path + self.local_file_path = local_file_path + super().__init__( + f"Attempt to import non existing file {import_file_path} into extract storage file" + f" {local_file_path}" + ) + + class DestinationCapabilitiesRequired(DataWriterException, ValueError): def __init__(self, file_format: TLoaderFileFormat): self.file_format = file_format diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 8936dae605..d324792a83 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -4,7 +4,6 @@ IO, TYPE_CHECKING, Any, - ClassVar, Dict, List, Literal, @@ -17,8 +16,7 @@ ) from dlt.common.json import json -from dlt.common.configuration import configspec, known_sections, with_config -from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.configuration import with_config from dlt.common.data_writers.exceptions import ( SpecLookupFailed, DataWriterNotFound, @@ -26,15 +24,25 @@ FileSpecNotFound, InvalidDataItem, ) -from 
dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat +from dlt.common.data_writers.configuration import ( + CsvFormatConfiguration, + CsvQuoting, + ParquetFormatConfiguration, +) +from dlt.common.destination import ( + DestinationCapabilitiesContext, + TLoaderFileFormat, + ALL_SUPPORTED_FILE_FORMATS, +) from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.typing import StrAny + if TYPE_CHECKING: from dlt.common.libs.pyarrow import pyarrow as pa -TDataItemFormat = Literal["arrow", "object"] +TDataItemFormat = Literal["arrow", "object", "file"] TWriter = TypeVar("TWriter", bound="DataWriter") @@ -124,6 +132,9 @@ def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat: return "object" elif extension == "parquet": return "arrow" + # those files may be imported by normalizer as is + elif extension in ALL_SUPPORTED_FILE_FORMATS: + return "file" else: raise ValueError(f"Cannot figure out data item format for extension {extension}") @@ -132,6 +143,8 @@ def writer_class_from_spec(spec: FileWriterSpec) -> Type["DataWriter"]: try: return WRITER_SPECS[spec] except KeyError: + if spec.data_item_format == "file": + return ImportFileWriter raise FileSpecNotFound(spec.file_format, spec.data_item_format, spec) @staticmethod @@ -147,6 +160,19 @@ def class_factory( raise FileFormatForItemFormatNotFound(file_format, data_item_format) +class ImportFileWriter(DataWriter): + """May only import files, fails on any open/write operations""" + + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + raise NotImplementedError( + "ImportFileWriter cannot write any files. You have bug in your code." 
+ ) + + @classmethod + def writer_spec(cls) -> FileWriterSpec: + raise NotImplementedError("ImportFileWriter has no single spec") + + class JsonlWriter(DataWriter): def write_data(self, rows: Sequence[Any]) -> None: super().write_data(rows) @@ -260,21 +286,8 @@ def writer_spec(cls) -> FileWriterSpec: ) -@configspec -class ParquetDataWriterConfiguration(BaseConfiguration): - flavor: Optional[str] = None # could be ie. "spark" - version: Optional[str] = "2.4" - data_page_size: Optional[int] = None - timestamp_timezone: str = "UTC" - row_group_size: Optional[int] = None - coerce_timestamps: Optional[Literal["s", "ms", "us", "ns"]] = None - allow_truncated_timestamps: bool = False - - __section__: ClassVar[str] = known_sections.DATA_WRITER - - class ParquetDataWriter(DataWriter): - @with_config(spec=ParquetDataWriterConfiguration) + @with_config(spec=ParquetFormatConfiguration) def __init__( self, f: IO[Any], @@ -381,20 +394,8 @@ def writer_spec(cls) -> FileWriterSpec: ) -CsvQuoting = Literal["quote_all", "quote_needed"] - - -@configspec -class CsvDataWriterConfiguration(BaseConfiguration): - delimiter: str = "," - include_header: bool = True - quoting: CsvQuoting = "quote_needed" - - __section__: ClassVar[str] = known_sections.DATA_WRITER - - class CsvWriter(DataWriter): - @with_config(spec=CsvDataWriterConfiguration) + @with_config(spec=CsvFormatConfiguration) def __init__( self, f: IO[Any], @@ -525,7 +526,7 @@ def writer_spec(cls) -> FileWriterSpec: class ArrowToCsvWriter(DataWriter): - @with_config(spec=CsvDataWriterConfiguration) + @with_config(spec=CsvFormatConfiguration) def __init__( self, f: IO[Any], @@ -783,3 +784,16 @@ def get_best_writer_spec( return DataWriter.class_factory(file_format, item_format, native_writers).writer_spec() except DataWriterNotFound: return DataWriter.class_factory(file_format, item_format, ALL_WRITERS).writer_spec() + + +def create_import_spec( + item_file_format: TLoaderFileFormat, + possible_file_formats: 
Sequence[TLoaderFileFormat], +) -> FileWriterSpec: + """Creates writer spec that may be used only to import files""" + # can the item file be directly imported? + if item_file_format not in possible_file_formats: + raise SpecLookupFailed("file", possible_file_formats, item_file_format) + + spec = DataWriter.class_factory(item_file_format, "object", ALL_WRITERS).writer_spec() + return spec._replace(data_item_format="file") diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index d8361d7140..595d3e0d26 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -10,7 +10,8 @@ Protocol, get_args, ) - +from dlt.common.normalizers.typing import TNamingConventionReferenceArg +from dlt.common.typing import TLoaderFileFormat from dlt.common.configuration.utils import serialize_value from dlt.common.configuration import configspec from dlt.common.configuration.specs import ContainerInjectableContext @@ -19,19 +20,11 @@ DestinationLoadingViaStagingNotSupported, DestinationLoadingWithoutStagingNotSupported, ) -from dlt.common.utils import identity - +from dlt.common.normalizers.naming import NamingConvention from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.wei import EVM_DECIMAL_PRECISION -# known loader file formats -# jsonl - new line separated json documents -# typed-jsonl - internal extract -> normalize format bases on jsonl -# insert_values - insert SQL statements -# sql - any sql statement -TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv"] TLoaderParallelismStrategy = Literal["parallel", "table-sequential", "sequential"] - ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) @@ -61,9 +54,15 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): """Recommended file size in bytes when writing extract/load files""" preferred_staging_file_format: 
Optional[TLoaderFileFormat] = None supported_staging_file_formats: Sequence[TLoaderFileFormat] = None + format_datetime_literal: Callable[..., str] = None escape_identifier: Callable[[str], str] = None + "Escapes table name, column name and other identifiers" escape_literal: Callable[[Any], Any] = None - format_datetime_literal: Callable[..., str] = None + "Escapes string literal" + casefold_identifier: Callable[[str], str] = str + """Casing function applied by destination to represent case insensitive identifiers.""" + has_case_sensitive_identifiers: bool = None + """Tells if identifiers in destination are case sensitive, before case_identifier function is applied""" decimal_precision: Tuple[int, int] = None wei_precision: Tuple[int, int] = None max_identifier_length: int = None @@ -74,7 +73,8 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): is_max_text_data_type_length_in_bytes: bool = None supports_transactions: bool = None supports_ddl_transactions: bool = None - naming_convention: str = "snake_case" + # use naming convention in the schema + naming_convention: TNamingConventionReferenceArg = None alter_add_multi_column: bool = True supports_truncate_command: bool = True schema_supports_numeric_precision: bool = True @@ -99,6 +99,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): @staticmethod def generic_capabilities( preferred_loader_file_format: TLoaderFileFormat = None, + naming_convention: TNamingConventionReferenceArg = None, loader_file_format_adapter: LoaderFileFormatAdapter = None, supported_table_formats: Sequence["TTableFormat"] = None, # type: ignore[name-defined] # noqa: F821 ) -> "DestinationCapabilitiesContext": @@ -110,9 +111,12 @@ def generic_capabilities( caps.loader_file_format_adapter = loader_file_format_adapter caps.preferred_staging_file_format = None caps.supported_staging_file_formats = [] + caps.naming_convention = naming_convention or caps.naming_convention + caps.escape_identifier = str 
caps.supported_table_formats = supported_table_formats or [] - caps.escape_identifier = identity caps.escape_literal = serialize_value + caps.casefold_identifier = str + caps.has_case_sensitive_identifiers = True caps.format_datetime_literal = format_datetime_literal caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) caps.wei_precision = (EVM_DECIMAL_PRECISION, 0) diff --git a/dlt/common/destination/exceptions.py b/dlt/common/destination/exceptions.py index cd8f50bcce..c5f30401df 100644 --- a/dlt/common/destination/exceptions.py +++ b/dlt/common/destination/exceptions.py @@ -124,3 +124,16 @@ def __init__(self, schema_name: str, version_hash: str, stored_version_hash: str " schema in load package, you should first save it into schema storage. You can also" " use schema._bump_version() in test code to remove modified flag." ) + + +class DestinationInvalidFileFormat(DestinationTerminalException): + def __init__( + self, destination_type: str, file_format: str, file_name: str, message: str + ) -> None: + self.destination_type = destination_type + self.file_format = file_format + self.message = message + super().__init__( + f"Destination {destination_type} cannot process file {file_name} with format" + f" {file_format}: {message}" + ) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 9bb843a4c5..90f89b85d7 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -25,30 +25,27 @@ import inspect from dlt.common import logger +from dlt.common.configuration.specs.base_configuration import extract_inner_hint +from dlt.common.destination.utils import verify_schema_capabilities +from dlt.common.normalizers.naming import NamingConvention from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.typing import MERGE_STRATEGIES -from dlt.common.schema.exceptions import SchemaException from dlt.common.schema.utils import ( + get_file_format, 
get_write_disposition, get_table_format, - get_columns_names_with_prop, - has_column_with_prop, - get_first_column_name_with_prop, ) from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.destination.exceptions import ( - IdentifierTooLongException, InvalidDestinationReference, UnknownDestinationModule, DestinationSchemaTampered, ) -from dlt.common.schema.utils import is_complete_column from dlt.common.schema.exceptions import UnknownTableException from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName -from dlt.common.storages.load_package import LoadJobInfo +from dlt.common.storages.load_package import LoadJobInfo, TPipelineStateDoc TLoaderReplaceStrategy = Literal["truncate-and-insert", "insert-from-staging", "staging-optimized"] TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration") @@ -67,13 +64,23 @@ class StorageSchemaInfo(NamedTuple): schema: str -class StateInfo(NamedTuple): +@dataclasses.dataclass +class StateInfo: version: int engine_version: int pipeline_name: str state: str created_at: datetime.datetime - dlt_load_id: str = None + version_hash: Optional[str] = None + _dlt_load_id: Optional[str] = None + + def as_doc(self) -> TPipelineStateDoc: + doc: TPipelineStateDoc = dataclasses.asdict(self) # type: ignore[assignment] + if self._dlt_load_id is None: + doc.pop("_dlt_load_id") + if self.version_hash is None: + doc.pop("version_hash") + return doc @configspec @@ -98,6 +105,25 @@ def __str__(self) -> str: def on_resolved(self) -> None: self.destination_name = self.destination_name or self.destination_type + @classmethod + def credentials_type( + cls, config: "DestinationClientConfiguration" = None + ) -> Type[CredentialsConfiguration]: + 
"""Figure out credentials type, using hint resolvers for dynamic types + + For correct type resolution of filesystem, config should have bucket_url populated + """ + key = "credentials" + type_ = cls.get_resolvable_fields()[key] + if key in cls.__hint_resolvers__ and config is not None: + try: + # Type hint for this field is created dynamically + type_ = cls.__hint_resolvers__[key](config) + except Exception: + # we suppress failed hint resolutions + pass + return extract_inner_hint(type_) + @configspec class DestinationClientDwhConfiguration(DestinationClientConfiguration): @@ -253,11 +279,15 @@ class DoNothingFollowupJob(DoNothingJob, FollowupJob): class JobClientBase(ABC): - capabilities: ClassVar[DestinationCapabilitiesContext] = None - - def __init__(self, schema: Schema, config: DestinationClientConfiguration) -> None: + def __init__( + self, + schema: Schema, + config: DestinationClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: self.schema = schema self.config = config + self.capabilities = capabilities @abstractmethod def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: @@ -315,7 +345,7 @@ def should_truncate_table_before_load(self, table: TTableSchema) -> bool: def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], - table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" return [] @@ -336,96 +366,13 @@ def __exit__( pass def _verify_schema(self) -> None: - """Verifies and cleans up a schema before loading - - * Checks all table and column name lengths against destination capabilities and raises on too long identifiers - * Removes and warns on (unbound) incomplete columns - """ - - for table in self.schema.data_tables(): - table_name = table["name"] - if len(table_name) > 
self.capabilities.max_identifier_length: - raise IdentifierTooLongException( - self.config.destination_type, - "table", - table_name, - self.capabilities.max_identifier_length, - ) - if table.get("write_disposition") == "merge": - if "x-merge-strategy" in table and table["x-merge-strategy"] not in MERGE_STRATEGIES: # type: ignore[typeddict-item] - raise SchemaException( - f'"{table["x-merge-strategy"]}" is not a valid merge strategy. ' # type: ignore[typeddict-item] - f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""" - ) - if ( - table.get("x-merge-strategy") == "delete-insert" - and not has_column_with_prop(table, "primary_key") - and not has_column_with_prop(table, "merge_key") - ): - logger.warning( - f"Table {table_name} has `write_disposition` set to `merge`" - " and `merge_strategy` set to `delete-insert`, but no primary or" - " merge keys defined." - " dlt will fall back to `append` for this table." - ) - if has_column_with_prop(table, "hard_delete"): - if len(get_columns_names_with_prop(table, "hard_delete")) > 1: - raise SchemaException( - f'Found multiple "hard_delete" column hints for table "{table_name}" in' - f' schema "{self.schema.name}" while only one is allowed:' - f' {", ".join(get_columns_names_with_prop(table, "hard_delete"))}.' - ) - if table.get("write_disposition") in ("replace", "append"): - logger.warning( - f"""The "hard_delete" column hint for column "{get_first_column_name_with_prop(table, 'hard_delete')}" """ - f'in table "{table_name}" with write disposition' - f' "{table.get("write_disposition")}"' - f' in schema "{self.schema.name}" will be ignored.' - ' The "hard_delete" column hint is only applied when using' - ' the "merge" write disposition.' 
- ) - if has_column_with_prop(table, "dedup_sort"): - if len(get_columns_names_with_prop(table, "dedup_sort")) > 1: - raise SchemaException( - f'Found multiple "dedup_sort" column hints for table "{table_name}" in' - f' schema "{self.schema.name}" while only one is allowed:' - f' {", ".join(get_columns_names_with_prop(table, "dedup_sort"))}.' - ) - if table.get("write_disposition") in ("replace", "append"): - logger.warning( - f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ - f'in table "{table_name}" with write disposition' - f' "{table.get("write_disposition")}"' - f' in schema "{self.schema.name}" will be ignored.' - ' The "dedup_sort" column hint is only applied when using' - ' the "merge" write disposition.' - ) - if table.get("write_disposition") == "merge" and not has_column_with_prop( - table, "primary_key" - ): - logger.warning( - f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ - f'in table "{table_name}" with write disposition' - f' "{table.get("write_disposition")}"' - f' in schema "{self.schema.name}" will be ignored.' - ' The "dedup_sort" column hint is only applied when a' - " primary key has been specified." - ) - for column_name, column in dict(table["columns"]).items(): - if len(column_name) > self.capabilities.max_column_identifier_length: - raise IdentifierTooLongException( - self.config.destination_type, - "column", - f"{table_name}.{column_name}", - self.capabilities.max_column_identifier_length, - ) - if not is_complete_column(column): - logger.warning( - f"A column {column_name} in table {table_name} in schema" - f" {self.schema.name} is incomplete. It was not bound to the data during" - " normalizations stage and its data type is unknown. Did you add this" - " column manually in code ie. as a merge key?" 
- ) + """Verifies schema before loading""" + if exceptions := verify_schema_capabilities( + self.schema, self.capabilities, self.config.destination_type, warnings=False + ): + for exception in exceptions: + logger.error(str(exception)) + raise exceptions[0] def prepare_load_table( self, table_name: str, prepare_for_staging: bool = False @@ -438,9 +385,11 @@ def prepare_load_table( table["write_disposition"] = get_write_disposition(self.schema.tables, table_name) if "table_format" not in table: table["table_format"] = get_table_format(self.schema.tables, table_name) + if "file_format" not in table: + table["file_format"] = get_file_format(self.schema.tables, table_name) return table except KeyError: - raise UnknownTableException(table_name) + raise UnknownTableException(self.schema.name, table_name) class WithStateSync(ABC): @@ -497,7 +446,10 @@ class Destination(ABC, Generic[TDestinationConfig, TDestinationClient]): with credentials and other config params. """ - config_params: Optional[Dict[str, Any]] = None + config_params: Dict[str, Any] + """Explicit config params, overriding any injected or default values.""" + caps_params: Dict[str, Any] + """Explicit capabilities params, overriding any default values for this destination""" def __init__(self, **kwargs: Any) -> None: # Create initial unresolved destination config @@ -505,9 +457,27 @@ def __init__(self, **kwargs: Any) -> None: # to supersede config from the environment or pipeline args sig = inspect.signature(self.__class__.__init__) params = sig.parameters - self.config_params = { - k: v for k, v in kwargs.items() if k not in params or v != params[k].default - } + + # get available args + spec = self.spec + spec_fields = spec.get_resolvable_fields() + caps_fields = DestinationCapabilitiesContext.get_resolvable_fields() + + # remove default kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in params or v != params[k].default} + + # warn on unknown params + for k in list(kwargs): + if k not in 
spec_fields and k not in caps_fields: + logger.warning( + f"When initializing destination factory of type {self.destination_type}," + f" argument {k} is not a valid field in {spec.__name__} or destination" + " capabilities" + ) + kwargs.pop(k) + + self.config_params = {k: v for k, v in kwargs.items() if k in spec_fields} + self.caps_params = {k: v for k, v in kwargs.items() if k in caps_fields} @property @abstractmethod @@ -515,9 +485,37 @@ def spec(self) -> Type[TDestinationConfig]: """A spec of destination configuration that also contains destination credentials""" ... + def capabilities( + self, config: Optional[TDestinationConfig] = None, naming: Optional[NamingConvention] = None + ) -> DestinationCapabilitiesContext: + """Destination capabilities ie. supported loader file formats, identifier name lengths, naming conventions, escape function etc. + Explicit caps arguments passed to the factory init and stored in `caps_params` are applied. + + If `config` is provided, it is used to adjust the capabilities, otherwise the explicit config composed just of `config_params` passed + to factory init is applied + If `naming` is provided, the case sensitivity and case folding are adjusted. + """ + caps = self._raw_capabilities() + caps.update(self.caps_params) + # get explicit config if final config not passed + if config is None: + # create mock credentials to avoid credentials being resolved + init_config = self.spec() + init_config.update(self.config_params) + credentials = self.spec.credentials_type(init_config)() + credentials.__is_resolved__ = True + config = self.spec(credentials=credentials) + try: + config = self.configuration(config, accept_partial=True) + except Exception: + # in rare cases partial may fail ie. 
when invalid native value is present + # in that case we fallback to "empty" config + pass + return self.adjust_capabilities(caps, config, naming) + @abstractmethod - def capabilities(self) -> DestinationCapabilitiesContext: - """Destination capabilities ie. supported loader file formats, identifier name lengths, naming conventions, escape function etc.""" + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + """Returns raw capabilities, before being adjusted with naming convention and config""" ... @property @@ -540,16 +538,61 @@ def client_class(self) -> Type[TDestinationClient]: """A job client class responsible for starting and resuming load jobs""" ... - def configuration(self, initial_config: TDestinationConfig) -> TDestinationConfig: + def configuration( + self, initial_config: TDestinationConfig, accept_partial: bool = False + ) -> TDestinationConfig: """Get a fully resolved destination config from the initial config""" + config = resolve_configuration( - initial_config, + initial_config or self.spec(), sections=(known_sections.DESTINATION, self.destination_name), # Already populated values will supersede resolved env config explicit_value=self.config_params, + accept_partial=accept_partial, ) return config + def client( + self, schema: Schema, initial_config: TDestinationConfig = None + ) -> TDestinationClient: + """Returns a configured instance of the destination's job client""" + config = self.configuration(initial_config) + return self.client_class(schema, config, self.capabilities(config, schema.naming)) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: TDestinationConfig, + naming: Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + """Adjust the capabilities to match the case sensitivity as requested by naming convention.""" + # if naming not provided, skip the adjustment + if not naming or not naming.is_case_sensitive: + # all destinations are configured to be case 
insensitive so there's nothing to adjust + return caps + if not caps.has_case_sensitive_identifiers: + if caps.casefold_identifier is str: + logger.info( + f"Naming convention {naming.name()} is case sensitive but the destination does" + " not support case sensitive identifiers. Nevertheless identifier casing will" + " be preserved in the destination schema." + ) + else: + logger.warn( + f"Naming convention {naming.name()} is case sensitive but the destination does" + " not support case sensitive identifiers. Destination will case fold all the" + f" identifiers with {caps.casefold_identifier}" + ) + else: + # adjust case folding to store casefold identifiers in the schema + if caps.casefold_identifier is not str: + caps.casefold_identifier = str + logger.info( + f"Enabling case sensitive identifiers for naming convention {naming.name()}" + ) + return caps + @staticmethod def to_name(ref: TDestinationReferenceArg) -> str: if ref is None: @@ -562,7 +605,7 @@ def to_name(ref: TDestinationReferenceArg) -> str: @staticmethod def normalize_type(destination_type: str) -> str: - """Normalizes destination type string into a canonical form. Assumes that type names without dots correspond to build in destinations.""" + """Normalizes destination type string into a canonical form. Assumes that type names without dots correspond to built in destinations.""" if "." not in destination_type: destination_type = "dlt.destinations." + destination_type # the next two lines shorten the dlt internal destination paths to dlt.destinations. 
@@ -625,11 +668,5 @@ def from_reference( raise InvalidDestinationReference(ref) from e return dest - def client( - self, schema: Schema, initial_config: TDestinationConfig = None - ) -> TDestinationClient: - """Returns a configured instance of the destination's job client""" - return self.client_class(schema, self.configuration(initial_config)) - TDestination = Destination[DestinationClientConfiguration, JobClientBase] diff --git a/dlt/common/destination/utils.py b/dlt/common/destination/utils.py new file mode 100644 index 0000000000..2c5e97df14 --- /dev/null +++ b/dlt/common/destination/utils.py @@ -0,0 +1,115 @@ +from typing import List + +from dlt.common import logger +from dlt.common.destination.exceptions import IdentifierTooLongException +from dlt.common.schema import Schema +from dlt.common.schema.exceptions import ( + SchemaIdentifierNormalizationCollision, +) +from dlt.common.schema.utils import is_complete_column +from dlt.common.typing import DictStrStr + +from .capabilities import DestinationCapabilitiesContext + + +def verify_schema_capabilities( + schema: Schema, + capabilities: DestinationCapabilitiesContext, + destination_type: str, + warnings: bool = True, +) -> List[Exception]: + """Verifies schema tables before loading against capabilities. Returns a list of exceptions representing critical problems with the schema. + It will log warnings by default. 
It is up to the caller to eventually raise an exception + + * Checks all table and column name lengths against destination capabilities and reports identifiers that are too long + * Checks if schema has collisions due to case sensitivity of the identifiers + """ + + log = logger.warning if warnings else logger.info + # collect all exceptions to show all problems in the schema + exception_log: List[Exception] = [] + # combined casing function + case_identifier = lambda ident: capabilities.casefold_identifier( + (str if capabilities.has_case_sensitive_identifiers else str.casefold)(ident) # type: ignore + ) + table_name_lookup: DictStrStr = {} + # name collision explanation + collision_msg = "Destination is case " + ( + "sensitive" if capabilities.has_case_sensitive_identifiers else "insensitive" + ) + if capabilities.casefold_identifier is not str: + collision_msg += ( + f" but it uses {capabilities.casefold_identifier} to generate case insensitive" + " identifiers. You may try to change the destination capabilities by changing the" + " `casefold_identifier` to `str`" + ) + collision_msg += ( + ". Please clean up your data before loading so the entities have different name. You can" + " also change to case insensitive naming convention. Note that in that case data from both" + " columns will be merged into one." 
+ ) + + # check for any table clashes + for table in schema.data_tables(): + table_name = table["name"] + # detect table name conflict + cased_table_name = case_identifier(table_name) + if cased_table_name in table_name_lookup: + conflict_table_name = table_name_lookup[cased_table_name] + exception_log.append( + SchemaIdentifierNormalizationCollision( + schema.name, + table_name, + "table", + table_name, + conflict_table_name, + schema.naming.name(), + collision_msg, + ) + ) + table_name_lookup[cased_table_name] = table_name + if len(table_name) > capabilities.max_identifier_length: + exception_log.append( + IdentifierTooLongException( + destination_type, + "table", + table_name, + capabilities.max_identifier_length, + ) + ) + + column_name_lookup: DictStrStr = {} + for column_name, column in dict(table["columns"]).items(): + # detect column name conflict + cased_column_name = case_identifier(column_name) + if cased_column_name in column_name_lookup: + conflict_column_name = column_name_lookup[cased_column_name] + exception_log.append( + SchemaIdentifierNormalizationCollision( + schema.name, + table_name, + "column", + column_name, + conflict_column_name, + schema.naming.name(), + collision_msg, + ) + ) + column_name_lookup[cased_column_name] = column_name + if len(column_name) > capabilities.max_column_identifier_length: + exception_log.append( + IdentifierTooLongException( + destination_type, + "column", + f"{table_name}.{column_name}", + capabilities.max_column_identifier_length, + ) + ) + if not is_complete_column(column): + log( + f"A column {column_name} in table {table_name} in schema" + f" {schema.name} is incomplete. It was not bound to the data during" + " normalizations stage and its data type is unknown. Did you add this" + " column manually in code ie. as a merge key?" 
+ ) + return exception_log diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 8a6dc68078..ee249b111c 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -348,13 +348,13 @@ def normalize_py_arrow_item( def get_normalized_arrow_fields_mapping(schema: pyarrow.Schema, naming: NamingConvention) -> StrStr: - """Normalizes schema field names and returns mapping from original to normalized name. Raises on name clashes""" + """Normalizes schema field names and returns mapping from original to normalized name. Raises on name collisions""" norm_f = naming.normalize_identifier name_mapping = {n.name: norm_f(n.name) for n in schema} # verify if names uniquely normalize normalized_names = set(name_mapping.values()) if len(name_mapping) != len(normalized_names): - raise NameNormalizationClash( + raise NameNormalizationCollision( f"Arrow schema fields normalized from {list(name_mapping.keys())} to" f" {list(normalized_names)}" ) @@ -497,7 +497,7 @@ def cast_arrow_schema_types( return schema -class NameNormalizationClash(ValueError): +class NameNormalizationCollision(ValueError): def __init__(self, reason: str) -> None: - msg = f"Arrow column name clash after input data normalization. {reason}" + msg = f"Arrow column name collision after input data normalization. 
{reason}" super().__init__(msg) diff --git a/dlt/common/normalizers/__init__.py b/dlt/common/normalizers/__init__.py index 2ff41d4c12..af6add6a19 100644 --- a/dlt/common/normalizers/__init__.py +++ b/dlt/common/normalizers/__init__.py @@ -1,11 +1,9 @@ -from dlt.common.normalizers.configuration import NormalizersConfiguration from dlt.common.normalizers.typing import TJSONNormalizer, TNormalizersConfig -from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers +from dlt.common.normalizers.naming import NamingConvention + __all__ = [ - "NormalizersConfiguration", + "NamingConvention", "TJSONNormalizer", "TNormalizersConfig", - "explicit_normalizers", - "import_normalizers", ] diff --git a/dlt/common/normalizers/configuration.py b/dlt/common/normalizers/configuration.py index 54b725db1f..6011ba4774 100644 --- a/dlt/common/normalizers/configuration.py +++ b/dlt/common/normalizers/configuration.py @@ -1,9 +1,8 @@ -from typing import ClassVar, Optional, TYPE_CHECKING +from typing import ClassVar, Optional from dlt.common.configuration import configspec from dlt.common.configuration.specs import BaseConfiguration, known_sections -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.normalizers.typing import TJSONNormalizer +from dlt.common.normalizers.typing import TNamingConventionReferenceArg from dlt.common.typing import DictStrAny @@ -12,22 +11,6 @@ class NormalizersConfiguration(BaseConfiguration): # always in section __section__: ClassVar[str] = known_sections.SCHEMA - naming: Optional[str] = None + naming: Optional[TNamingConventionReferenceArg] = None # Union[str, NamingConvention] json_normalizer: Optional[DictStrAny] = None - destination_capabilities: Optional[DestinationCapabilitiesContext] = None # injectable - - def on_resolved(self) -> None: - # get naming from capabilities if not present - if self.naming is None: - if self.destination_capabilities: - self.naming = 
self.destination_capabilities.naming_convention - # if max_table_nesting is set, we need to set the max_table_nesting in the json_normalizer - if ( - self.destination_capabilities - and self.destination_capabilities.max_table_nesting is not None - ): - self.json_normalizer = self.json_normalizer or {} - self.json_normalizer.setdefault("config", {}) - self.json_normalizer["config"][ - "max_nesting" - ] = self.destination_capabilities.max_table_nesting + allow_identifier_change_on_table_with_data: Optional[bool] = None diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index bad275ca4f..91af42a6c5 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -5,7 +5,7 @@ from dlt.common.normalizers.typing import TJSONNormalizer from dlt.common.normalizers.utils import generate_dlt_id, DLT_ID_LENGTH_BYTES -from dlt.common.typing import DictStrAny, DictStrStr, TDataItem, StrAny +from dlt.common.typing import DictStrAny, TDataItem, StrAny from dlt.common.schema import Schema from dlt.common.schema.typing import ( TColumnSchema, @@ -23,28 +23,10 @@ ) from dlt.common.validation import validate_dict -EMPTY_KEY_IDENTIFIER = "_empty" # replace empty keys with this - - -class TDataItemRow(TypedDict, total=False): - _dlt_id: str # unique id of current row - - -class TDataItemRowRoot(TDataItemRow, total=False): - _dlt_load_id: (str) # load id to identify records loaded together that ie. 
need to be processed - # _dlt_meta: TEventDLTMeta # stores metadata, should never be sent to the normalizer - - -class TDataItemRowChild(TDataItemRow, total=False): - _dlt_root_id: str # unique id of top level parent - _dlt_parent_id: str # unique id of parent row - _dlt_list_idx: int # position in the list of rows - value: Any # for lists of simple types - class RelationalNormalizerConfigPropagation(TypedDict, total=False): - root: Optional[Mapping[str, TColumnName]] - tables: Optional[Mapping[str, Mapping[str, TColumnName]]] + root: Optional[Dict[TColumnName, TColumnName]] + tables: Optional[Dict[str, Dict[TColumnName, TColumnName]]] class RelationalNormalizerConfig(TypedDict, total=False): @@ -54,6 +36,23 @@ class RelationalNormalizerConfig(TypedDict, total=False): class DataItemNormalizer(DataItemNormalizerBase[RelationalNormalizerConfig]): + # known normalizer props + C_DLT_ID = "_dlt_id" + """unique id of current row""" + C_DLT_LOAD_ID = "_dlt_load_id" + """load id to identify records loaded together that ie. need to be processed""" + C_DLT_ROOT_ID = "_dlt_root_id" + """unique id of top level parent""" + C_DLT_PARENT_ID = "_dlt_parent_id" + """unique id of parent row""" + C_DLT_LIST_IDX = "_dlt_list_idx" + """position in the list of rows""" + C_VALUE = "value" + """for lists of simple types""" + + # other constants + EMPTY_KEY_IDENTIFIER = "_empty" # replace empty keys with this + normalizer_config: RelationalNormalizerConfig propagation_config: RelationalNormalizerConfigPropagation max_nesting: int @@ -63,12 +62,29 @@ def __init__(self, schema: Schema) -> None: """This item normalizer works with nested dictionaries. It flattens dictionaries and descends into lists. 
It yields row dictionaries at each nesting level.""" self.schema = schema + self.naming = schema.naming self._reset() def _reset(self) -> None: - self.normalizer_config = ( - self.schema._normalizers_config["json"].get("config") or {} # type: ignore[assignment] + # normalize known normalizer column identifiers + self.c_dlt_id: TColumnName = TColumnName(self.naming.normalize_identifier(self.C_DLT_ID)) + self.c_dlt_load_id: TColumnName = TColumnName( + self.naming.normalize_identifier(self.C_DLT_LOAD_ID) + ) + self.c_dlt_root_id: TColumnName = TColumnName( + self.naming.normalize_identifier(self.C_DLT_ROOT_ID) + ) + self.c_dlt_parent_id: TColumnName = TColumnName( + self.naming.normalize_identifier(self.C_DLT_PARENT_ID) + ) + self.c_dlt_list_idx: TColumnName = TColumnName( + self.naming.normalize_identifier(self.C_DLT_LIST_IDX) ) + self.c_value: TColumnName = TColumnName(self.naming.normalize_identifier(self.C_VALUE)) + + # normalize config + + self.normalizer_config = self.schema._normalizers_config["json"].get("config") or {} # type: ignore[assignment] self.propagation_config = self.normalizer_config.get("propagation", None) self.max_nesting = self.normalizer_config.get("max_nesting", 1000) self._skip_primary_key = {} @@ -103,8 +119,8 @@ def _is_complex_type(self, table_name: str, field_name: str, _r_lvl: int) -> boo return data_type == "complex" def _flatten( - self, table: str, dict_row: TDataItemRow, _r_lvl: int - ) -> Tuple[TDataItemRow, Dict[Tuple[str, ...], Sequence[Any]]]: + self, table: str, dict_row: DictStrAny, _r_lvl: int + ) -> Tuple[DictStrAny, Dict[Tuple[str, ...], Sequence[Any]]]: out_rec_row: DictStrAny = {} out_rec_list: Dict[Tuple[str, ...], Sequence[Any]] = {} schema_naming = self.schema.naming @@ -115,7 +131,7 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] 
= ()) - norm_k = schema_naming.normalize_identifier(k) else: # for empty keys in the data use _ - norm_k = EMPTY_KEY_IDENTIFIER + norm_k = self.EMPTY_KEY_IDENTIFIER # if norm_k != k: # print(f"{k} -> {norm_k}") child_name = ( @@ -139,7 +155,7 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) - out_rec_row[child_name] = v norm_row_dicts(dict_row, _r_lvl) - return cast(TDataItemRow, out_rec_row), out_rec_list + return out_rec_row, out_rec_list @staticmethod def get_row_hash(row: Dict[str, Any]) -> str: @@ -160,7 +176,7 @@ def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> return digest128(f"{parent_row_id}_{child_table}_{list_idx}", DLT_ID_LENGTH_BYTES) @staticmethod - def _link_row(row: TDataItemRowChild, parent_row_id: str, list_idx: int) -> TDataItemRowChild: + def _link_row(row: DictStrAny, parent_row_id: str, list_idx: int) -> DictStrAny: assert parent_row_id row["_dlt_parent_id"] = parent_row_id row["_dlt_list_idx"] = list_idx @@ -168,11 +184,11 @@ def _link_row(row: TDataItemRowChild, parent_row_id: str, list_idx: int) -> TDat return row @staticmethod - def _extend_row(extend: DictStrAny, row: TDataItemRow) -> None: - row.update(extend) # type: ignore + def _extend_row(extend: DictStrAny, row: DictStrAny) -> None: + row.update(extend) def _add_row_id( - self, table: str, row: TDataItemRow, parent_row_id: str, pos: int, _r_lvl: int + self, table: str, row: DictStrAny, parent_row_id: str, pos: int, _r_lvl: int ) -> str: # row_id is always random, no matter if primary_key is present or not row_id = generate_dlt_id() @@ -182,17 +198,17 @@ def _add_row_id( # child table row deterministic hash row_id = DataItemNormalizer._get_child_row_hash(parent_row_id, table, pos) # link to parent table - DataItemNormalizer._link_row(cast(TDataItemRowChild, row), parent_row_id, pos) - row["_dlt_id"] = row_id + DataItemNormalizer._link_row(row, parent_row_id, pos) + row[self.c_dlt_id] = row_id return row_id - def 
_get_propagated_values(self, table: str, row: TDataItemRow, _r_lvl: int) -> StrAny: + def _get_propagated_values(self, table: str, row: DictStrAny, _r_lvl: int) -> StrAny: extend: DictStrAny = {} config = self.propagation_config if config: # mapping(k:v): propagate property with name "k" as property with name "v" in child table - mappings: DictStrStr = {} + mappings: Dict[TColumnName, TColumnName] = {} if _r_lvl == 0: mappings.update(config.get("root") or {}) if table in (config.get("tables") or {}): @@ -200,7 +216,7 @@ def _get_propagated_values(self, table: str, row: TDataItemRow, _r_lvl: int) -> # look for keys and create propagation as values for prop_from, prop_as in mappings.items(): if prop_from in row: - extend[prop_as] = row[prop_from] # type: ignore + extend[prop_as] = row[prop_from] return extend @@ -214,7 +230,7 @@ def _normalize_list( parent_row_id: Optional[str] = None, _r_lvl: int = 0, ) -> TNormalizedRowIterator: - v: TDataItemRowChild = None + v: DictStrAny = None table = self.schema.naming.shorten_fragments(*parent_path, *ident_path) for idx, v in enumerate(seq): @@ -238,14 +254,14 @@ def _normalize_list( # list of simple types child_row_hash = DataItemNormalizer._get_child_row_hash(parent_row_id, table, idx) wrap_v = wrap_in_dict(v) - wrap_v["_dlt_id"] = child_row_hash + wrap_v[self.c_dlt_id] = child_row_hash e = DataItemNormalizer._link_row(wrap_v, parent_row_id, idx) DataItemNormalizer._extend_row(extend, e) yield (table, self.schema.naming.shorten_fragments(*parent_path)), e def _normalize_row( self, - dict_row: TDataItemRow, + dict_row: DictStrAny, extend: DictStrAny, ident_path: Tuple[str, ...], parent_path: Tuple[str, ...] 
= (), @@ -258,14 +274,14 @@ def _normalize_row( table = schema.naming.shorten_fragments(*parent_path, *ident_path) # compute row hash and set as row id if row_hash: - row_id = self.get_row_hash(dict_row) # type: ignore[arg-type] - dict_row["_dlt_id"] = row_id + row_id = self.get_row_hash(dict_row) + dict_row[self.c_dlt_id] = row_id # flatten current row and extract all lists to recur into flattened_row, lists = self._flatten(table, dict_row, _r_lvl) # always extend row DataItemNormalizer._extend_row(extend, flattened_row) # infer record hash or leave existing primary key if present - row_id = flattened_row.get("_dlt_id", None) + row_id = flattened_row.get(self.c_dlt_id, None) if not row_id: row_id = self._add_row_id(table, flattened_row, parent_row_id, pos, _r_lvl) @@ -292,43 +308,55 @@ def _normalize_row( ) def extend_schema(self) -> None: - # validate config + """Extends Schema with normalizer-specific hints and settings. + + This method is called by Schema when instance is created or restored from storage. + """ config = cast( RelationalNormalizerConfig, self.schema._normalizers_config["json"].get("config") or {}, ) DataItemNormalizer._validate_normalizer_config(self.schema, config) - # quick check to see if hints are applied - default_hints = self.schema.settings.get("default_hints") or {} - if "not_null" in default_hints and "^_dlt_id$" in default_hints["not_null"]: - return - # add hints - self.schema.merge_hints( + # add hints, do not compile. 
+ self.schema._merge_hints( { "not_null": [ - TSimpleRegex("_dlt_id"), - TSimpleRegex("_dlt_root_id"), - TSimpleRegex("_dlt_parent_id"), - TSimpleRegex("_dlt_list_idx"), - TSimpleRegex("_dlt_load_id"), + TSimpleRegex(self.c_dlt_id), + TSimpleRegex(self.c_dlt_root_id), + TSimpleRegex(self.c_dlt_parent_id), + TSimpleRegex(self.c_dlt_list_idx), + TSimpleRegex(self.c_dlt_load_id), ], - "foreign_key": [TSimpleRegex("_dlt_parent_id")], - "root_key": [TSimpleRegex("_dlt_root_id")], - "unique": [TSimpleRegex("_dlt_id")], - } + "foreign_key": [TSimpleRegex(self.c_dlt_parent_id)], + "root_key": [TSimpleRegex(self.c_dlt_root_id)], + "unique": [TSimpleRegex(self.c_dlt_id)], + }, + normalize_identifiers=False, # already normalized ) for table_name in self.schema.tables.keys(): self.extend_table(table_name) def extend_table(self, table_name: str) -> None: - # if the table has a merge w_d, add propagation info to normalizer + """If the table has a merge write disposition, add propagation info to normalizer + + Called by Schema when new table is added to schema or table is updated with partial table. + Table name should be normalized. 
+ """ table = self.schema.tables.get(table_name) if not table.get("parent") and table.get("write_disposition") == "merge": DataItemNormalizer.update_normalizer_config( self.schema, - {"propagation": {"tables": {table_name: {"_dlt_id": TColumnName("_dlt_root_id")}}}}, + { + "propagation": { + "tables": { + table_name: { + TColumnName(self.c_dlt_id): TColumnName(self.c_dlt_root_id) + } + } + } + }, ) def normalize_data_item( @@ -338,18 +366,20 @@ def normalize_data_item( if not isinstance(item, dict): item = wrap_in_dict(item) # we will extend event with all the fields necessary to load it as root row - row = cast(TDataItemRowRoot, item) + row = cast(DictStrAny, item) # identify load id if loaded data must be processed after loading incrementally - row["_dlt_load_id"] = load_id + row[self.c_dlt_load_id] = load_id + # determine if row hash should be used as dlt id row_hash = False if self._is_scd2_table(self.schema, table_name): - row_hash = self._dlt_id_is_row_hash(self.schema, table_name) + row_hash = self._dlt_id_is_row_hash(self.schema, table_name, self.c_dlt_id) self._validate_validity_column_names( - self._get_validity_column_names(self.schema, table_name), item + self.schema.name, self._get_validity_column_names(self.schema, table_name), item ) + yield from self._normalize_row( - cast(TDataItemRowChild, row), + row, {}, (self.schema.naming.normalize_table_identifier(table_name),), row_hash=row_hash, @@ -365,12 +395,12 @@ def ensure_this_normalizer(cls, norm_config: TJSONNormalizer) -> None: @classmethod def update_normalizer_config(cls, schema: Schema, config: RelationalNormalizerConfig) -> None: cls._validate_normalizer_config(schema, config) - norm_config = schema._normalizers_config["json"] - cls.ensure_this_normalizer(norm_config) - if "config" in norm_config: - update_dict_nested(norm_config["config"], config) # type: ignore + existing_config = schema._normalizers_config["json"] + cls.ensure_this_normalizer(existing_config) + if "config" in 
existing_config: + update_dict_nested(existing_config["config"], config) # type: ignore else: - norm_config["config"] = config + existing_config["config"] = config @classmethod def get_normalizer_config(cls, schema: Schema) -> RelationalNormalizerConfig: @@ -380,6 +410,29 @@ def get_normalizer_config(cls, schema: Schema) -> RelationalNormalizerConfig: @staticmethod def _validate_normalizer_config(schema: Schema, config: RelationalNormalizerConfig) -> None: + """Normalizes all known column identifiers according to the schema and then validates the configuration""" + + def _normalize_prop( + mapping: Mapping[TColumnName, TColumnName] + ) -> Dict[TColumnName, TColumnName]: + return { + TColumnName(schema.naming.normalize_path(from_col)): TColumnName( + schema.naming.normalize_path(to_col) + ) + for from_col, to_col in mapping.items() + } + + # normalize the identifiers first + propagation_config = config.get("propagation") + if propagation_config: + if "root" in propagation_config: + propagation_config["root"] = _normalize_prop(propagation_config["root"]) + if "tables" in propagation_config: + for table_name in propagation_config["tables"]: + propagation_config["tables"][table_name] = _normalize_prop( + propagation_config["tables"][table_name] + ) + validate_dict( RelationalNormalizerConfig, config, @@ -410,21 +463,22 @@ def _get_validity_column_names(schema: Schema, table_name: str) -> List[Optional @staticmethod @lru_cache(maxsize=None) - def _dlt_id_is_row_hash(schema: Schema, table_name: str) -> bool: + def _dlt_id_is_row_hash(schema: Schema, table_name: str, c_dlt_id: str) -> bool: return ( schema.get_table(table_name)["columns"] # type: ignore[return-value] - .get("_dlt_id", {}) + .get(c_dlt_id, {}) .get("x-row-version", False) ) @staticmethod def _validate_validity_column_names( - validity_column_names: List[Optional[str]], item: TDataItem + schema_name: str, validity_column_names: List[Optional[str]], item: TDataItem ) -> None: """Raises exception if 
configured validity column name appears in data item.""" for validity_column_name in validity_column_names: if validity_column_name in item.keys(): raise ColumnNameConflictException( + schema_name, "Found column in data item with same name as validity column" - f' "{validity_column_name}".' + f' "{validity_column_name}".', ) diff --git a/dlt/common/normalizers/naming/__init__.py b/dlt/common/normalizers/naming/__init__.py index 967fb9643e..2b3ecd74d0 100644 --- a/dlt/common/normalizers/naming/__init__.py +++ b/dlt/common/normalizers/naming/__init__.py @@ -1,3 +1,3 @@ -from .naming import SupportsNamingConvention, NamingConvention +from .naming import NamingConvention -__all__ = ["SupportsNamingConvention", "NamingConvention"] +__all__ = ["NamingConvention"] diff --git a/dlt/common/normalizers/naming/direct.py b/dlt/common/normalizers/naming/direct.py index 0998650852..fc146dbc4c 100644 --- a/dlt/common/normalizers/naming/direct.py +++ b/dlt/common/normalizers/naming/direct.py @@ -1,20 +1,23 @@ -from typing import Any, Sequence +from typing import ClassVar from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention class NamingConvention(BaseNamingConvention): - PATH_SEPARATOR = "▶" + """Case sensitive naming convention that maps source identifiers to destination identifiers with + only minimal changes. New line characters, double and single quotes are replaced with underscores. - _CLEANUP_TABLE = str.maketrans(".\n\r'\"▶", "______") + Uses ▶ as path separator. 
+ """ + + PATH_SEPARATOR: ClassVar[str] = "▶" + _CLEANUP_TABLE = str.maketrans("\n\r'\"▶", "_____") def normalize_identifier(self, identifier: str) -> str: identifier = super().normalize_identifier(identifier) norm_identifier = identifier.translate(self._CLEANUP_TABLE) return self.shorten_identifier(norm_identifier, identifier, self.max_length) - def make_path(self, *identifiers: Any) -> str: - return self.PATH_SEPARATOR.join(filter(lambda x: x.strip(), identifiers)) - - def break_path(self, path: str) -> Sequence[str]: - return [ident for ident in path.split(self.PATH_SEPARATOR) if ident.strip()] + @property + def is_case_sensitive(self) -> bool: + return True diff --git a/dlt/common/normalizers/naming/duck_case.py b/dlt/common/normalizers/naming/duck_case.py index 063482a799..3801660ba8 100644 --- a/dlt/common/normalizers/naming/duck_case.py +++ b/dlt/common/normalizers/naming/duck_case.py @@ -5,8 +5,15 @@ class NamingConvention(SnakeCaseNamingConvention): + """Case sensitive naming convention preserving all unicode characters except new line(s). Uses __ for path + separation and will replace multiple underscores with a single one. 
+ """ + _CLEANUP_TABLE = str.maketrans('\n\r"', "___") - _RE_LEADING_DIGITS = None # do not remove leading digits + + @property + def is_case_sensitive(self) -> bool: + return True @staticmethod @lru_cache(maxsize=None) @@ -17,5 +24,5 @@ def _normalize_identifier(identifier: str, max_length: int) -> str: # shorten identifier return NamingConvention.shorten_identifier( - NamingConvention._RE_UNDERSCORES.sub("_", normalized_ident), identifier, max_length + NamingConvention.RE_UNDERSCORES.sub("_", normalized_ident), identifier, max_length ) diff --git a/dlt/common/normalizers/naming/exceptions.py b/dlt/common/normalizers/naming/exceptions.py index 572fc7e0d0..0b22ae2dd5 100644 --- a/dlt/common/normalizers/naming/exceptions.py +++ b/dlt/common/normalizers/naming/exceptions.py @@ -5,21 +5,33 @@ class NormalizersException(DltException): pass -class UnknownNamingModule(NormalizersException): +class UnknownNamingModule(ImportError, NormalizersException): def __init__(self, naming_module: str) -> None: self.naming_module = naming_module if "." 
in naming_module: msg = f"Naming module {naming_module} could not be found and imported" else: - msg = f"Naming module {naming_module} is not one of the standard dlt naming convention" + msg = ( + f"Naming module {naming_module} is not one of the standard dlt naming conventions" + " and could not be locally imported" + ) super().__init__(msg) -class InvalidNamingModule(NormalizersException): - def __init__(self, naming_module: str) -> None: +class NamingTypeNotFound(ImportError, NormalizersException): + def __init__(self, naming_module: str, naming_class: str) -> None: + self.naming_module = naming_module + self.naming_class = naming_class + msg = f"In naming module '{naming_module}' type '{naming_class}' does not exist" + super().__init__(msg) + + +class InvalidNamingType(NormalizersException): + def __init__(self, naming_module: str, naming_class: str) -> None: self.naming_module = naming_module + self.naming_class = naming_class msg = ( - f"Naming module {naming_module} does not implement required SupportsNamingConvention" - " protocol" + f"In naming module '{naming_module}' the class '{naming_class}' is not a" + " NamingConvention" ) super().__init__(msg) diff --git a/dlt/common/normalizers/naming/naming.py b/dlt/common/normalizers/naming/naming.py index fccb147981..5ae5847963 100644 --- a/dlt/common/normalizers/naming/naming.py +++ b/dlt/common/normalizers/naming/naming.py @@ -3,16 +3,28 @@ from functools import lru_cache import math import hashlib -from typing import Any, List, Protocol, Sequence, Type +from typing import Sequence, ClassVar class NamingConvention(ABC): - _TR_TABLE = bytes.maketrans(b"/+", b"ab") - _DEFAULT_COLLISION_PROB = 0.001 + """Initializes naming convention to generate identifier with `max_length` if specified. 
Base naming convention + is case sensitive by default + """ + + _TR_TABLE: ClassVar[bytes] = bytes.maketrans(b"/+", b"ab") + _DEFAULT_COLLISION_PROB: ClassVar[float] = 0.001 + PATH_SEPARATOR: ClassVar[str] = "__" + """Subsequent nested fields will be separated with the string below, applies both to field and table names""" def __init__(self, max_length: int = None) -> None: self.max_length = max_length + @property + @abstractmethod + def is_case_sensitive(self) -> bool: + """Tells if given naming convention is producing case insensitive or case sensitive identifiers.""" + pass + @abstractmethod def normalize_identifier(self, identifier: str) -> str: """Normalizes and shortens the identifier according to naming convention in this function code""" @@ -27,15 +39,13 @@ def normalize_table_identifier(self, identifier: str) -> str: """Normalizes and shortens identifier that will function as a dataset, table or schema name, defaults to `normalize_identifier`""" return self.normalize_identifier(identifier) - @abstractmethod def make_path(self, *identifiers: str) -> str: """Builds path out of identifiers. 
Identifiers are neither normalized nor shortened""" - pass + return self.PATH_SEPARATOR.join(filter(lambda x: x.strip(), identifiers)) - @abstractmethod def break_path(self, path: str) -> Sequence[str]: """Breaks path into sequence of identifiers""" - pass + return [ident for ident in path.split(self.PATH_SEPARATOR) if ident.strip()] def normalize_path(self, path: str) -> str: """Breaks path into identifiers, normalizes components, reconstitutes and shortens the path""" @@ -58,6 +68,21 @@ def shorten_fragments(self, *normalized_idents: str) -> str: path_str = self.make_path(*normalized_idents) return self.shorten_identifier(path_str, path_str, self.max_length) + @classmethod + def name(cls) -> str: + """Naming convention name is the name of the module in which NamingConvention is defined""" + if cls.__module__.startswith("dlt.common.normalizers.naming."): + # return last component + return cls.__module__.split(".")[-1] + return cls.__module__ + + def __str__(self) -> str: + name = self.name() + name += "_cs" if self.is_case_sensitive else "_ci" + if self.max_length: + name += f"_{self.max_length}" + return name + @staticmethod @lru_cache(maxsize=None) def shorten_identifier( @@ -100,10 +125,3 @@ def _trim_and_tag(identifier: str, tag: str, max_length: int) -> str: ) assert len(identifier) == max_length return identifier - - -class SupportsNamingConvention(Protocol): - """Expected of modules defining naming convention""" - - NamingConvention: Type[NamingConvention] - """A class with a name NamingConvention deriving from normalizers.naming.NamingConvention""" diff --git a/dlt/common/normalizers/naming/snake_case.py b/dlt/common/normalizers/naming/snake_case.py index b3c65e9b8d..d38841a238 100644 --- a/dlt/common/normalizers/naming/snake_case.py +++ b/dlt/common/normalizers/naming/snake_case.py @@ -1,42 +1,54 @@ import re -from typing import Any, List, Sequence from functools import lru_cache +from typing import ClassVar from dlt.common.normalizers.naming.naming 
import NamingConvention as BaseNamingConvention +from dlt.common.normalizers.naming.sql_cs_v1 import ( + RE_UNDERSCORES, + RE_LEADING_DIGITS, + RE_NON_ALPHANUMERIC, +) +from dlt.common.typing import REPattern class NamingConvention(BaseNamingConvention): - _RE_UNDERSCORES = re.compile("__+") - _RE_LEADING_DIGITS = re.compile(r"^\d+") - # _RE_ENDING_UNDERSCORES = re.compile(r"_+$") - _RE_NON_ALPHANUMERIC = re.compile(r"[^a-zA-Z\d_]+") + """Case insensitive naming convention, converting source identifiers into lower case snake case with reduced alphabet. + + - Spaces around identifier are trimmed + - Removes all ascii characters except ascii alphanumerics and underscores + - Prepends `_` if name starts with number. + - Multiples of `_` are converted into single `_`. + - Replaces all trailing `_` with `x` + - Replaces `+` and `*` with `x`, `-` with `_`, `@` with `a` and `|` with `l` + + Uses __ as parent-child separator for tables and flattened column names. + """ + + RE_UNDERSCORES: ClassVar[REPattern] = RE_UNDERSCORES + RE_LEADING_DIGITS: ClassVar[REPattern] = RE_LEADING_DIGITS + RE_NON_ALPHANUMERIC: ClassVar[REPattern] = RE_NON_ALPHANUMERIC + _SNAKE_CASE_BREAK_1 = re.compile("([^_])([A-Z][a-z]+)") _SNAKE_CASE_BREAK_2 = re.compile("([a-z0-9])([A-Z])") _REDUCE_ALPHABET = ("+-*@|", "x_xal") _TR_REDUCE_ALPHABET = str.maketrans(_REDUCE_ALPHABET[0], _REDUCE_ALPHABET[1]) - # subsequent nested fields will be separated with the string below, applies both to field and table names - PATH_SEPARATOR = "__" + @property + def is_case_sensitive(self) -> bool: + return False def normalize_identifier(self, identifier: str) -> str: identifier = super().normalize_identifier(identifier) # print(f"{identifier} -> {self.shorten_identifier(identifier, self.max_length)} ({self.max_length})") return self._normalize_identifier(identifier, self.max_length) - def make_path(self, *identifiers: str) -> str: - # only non empty identifiers participate - return 
self.PATH_SEPARATOR.join(filter(lambda x: x.strip(), identifiers)) - - def break_path(self, path: str) -> Sequence[str]: - return [ident for ident in path.split(self.PATH_SEPARATOR) if ident.strip()] - @staticmethod @lru_cache(maxsize=None) def _normalize_identifier(identifier: str, max_length: int) -> str: """Normalizes the identifier according to naming convention represented by this function""" # all characters that are not letters digits or a few special chars are replaced with underscore normalized_ident = identifier.translate(NamingConvention._TR_REDUCE_ALPHABET) - normalized_ident = NamingConvention._RE_NON_ALPHANUMERIC.sub("_", normalized_ident) + normalized_ident = NamingConvention.RE_NON_ALPHANUMERIC.sub("_", normalized_ident) # shorten identifier return NamingConvention.shorten_identifier( @@ -50,7 +62,7 @@ def _to_snake_case(cls, identifier: str) -> str: identifier = cls._SNAKE_CASE_BREAK_2.sub(r"\1_\2", identifier).lower() # leading digits will be prefixed (if regex is defined) - if cls._RE_LEADING_DIGITS and cls._RE_LEADING_DIGITS.match(identifier): + if cls.RE_LEADING_DIGITS and cls.RE_LEADING_DIGITS.match(identifier): identifier = "_" + identifier # replace trailing _ with x @@ -59,5 +71,5 @@ def _to_snake_case(cls, identifier: str) -> str: stripped_ident += "x" * strip_count # identifier = cls._RE_ENDING_UNDERSCORES.sub("x", identifier) - # replace consecutive underscores with single one to prevent name clashes with PATH_SEPARATOR - return cls._RE_UNDERSCORES.sub("_", stripped_ident) + # replace consecutive underscores with single one to prevent name collisions with PATH_SEPARATOR + return cls.RE_UNDERSCORES.sub("_", stripped_ident) diff --git a/dlt/common/normalizers/naming/sql_ci_v1.py b/dlt/common/normalizers/naming/sql_ci_v1.py new file mode 100644 index 0000000000..4fff52ffd6 --- /dev/null +++ b/dlt/common/normalizers/naming/sql_ci_v1.py @@ -0,0 +1,12 @@ +from dlt.common.normalizers.naming.sql_cs_v1 import NamingConvention as 
SqlCsNamingConvention + + +class NamingConvention(SqlCsNamingConvention): + """A variant of sql_cs which lower cases all identifiers.""" + + def normalize_identifier(self, identifier: str) -> str: + return super().normalize_identifier(identifier).lower() + + @property + def is_case_sensitive(self) -> bool: + return False diff --git a/dlt/common/normalizers/naming/sql_cs_v1.py b/dlt/common/normalizers/naming/sql_cs_v1.py new file mode 100644 index 0000000000..788089fa7d --- /dev/null +++ b/dlt/common/normalizers/naming/sql_cs_v1.py @@ -0,0 +1,44 @@ +import re +from typing import ClassVar + +from dlt.common.typing import REPattern +from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention + + +RE_UNDERSCORES = re.compile("__+") +RE_LEADING_DIGITS = re.compile(r"^\d+") +RE_ENDING_UNDERSCORES = re.compile(r"_+$") +RE_NON_ALPHANUMERIC = re.compile(r"[^a-zA-Z\d_]+") + + +class NamingConvention(BaseNamingConvention): + """Generates case sensitive SQL safe identifiers, preserving the source casing. + + - Spaces around identifier are trimmed + - Removes all ascii characters except ascii alphanumerics and underscores + - Prepends `_` if name starts with number. + - Removes all trailing underscores. + - Multiples of `_` are converted into single `_`. 
+ """ + + RE_NON_ALPHANUMERIC: ClassVar[REPattern] = RE_NON_ALPHANUMERIC + RE_UNDERSCORES: ClassVar[REPattern] = RE_UNDERSCORES + RE_ENDING_UNDERSCORES: ClassVar[REPattern] = RE_ENDING_UNDERSCORES + + def normalize_identifier(self, identifier: str) -> str: + identifier = super().normalize_identifier(identifier) + # remove non alpha characters + norm_identifier = self.RE_NON_ALPHANUMERIC.sub("_", identifier) + # remove leading digits + if RE_LEADING_DIGITS.match(norm_identifier): + norm_identifier = "_" + norm_identifier + # remove trailing underscores to not mess with how we break paths + if norm_identifier != "_": + norm_identifier = self.RE_ENDING_UNDERSCORES.sub("", norm_identifier) + # contract multiple __ + norm_identifier = self.RE_UNDERSCORES.sub("_", norm_identifier) + return self.shorten_identifier(norm_identifier, identifier, self.max_length) + + @property + def is_case_sensitive(self) -> bool: + return True diff --git a/dlt/common/normalizers/typing.py b/dlt/common/normalizers/typing.py index 599426259f..9ea6f3cf11 100644 --- a/dlt/common/normalizers/typing.py +++ b/dlt/common/normalizers/typing.py @@ -1,14 +1,19 @@ -from typing import List, Optional, TypedDict +from types import ModuleType +from typing import List, Optional, Type, TypedDict, Union from dlt.common.typing import StrAny +from dlt.common.normalizers.naming import NamingConvention + +TNamingConventionReferenceArg = Union[str, Type[NamingConvention], ModuleType] class TJSONNormalizer(TypedDict, total=False): module: str - config: Optional[StrAny] # config is a free form and is consumed by `module` + config: Optional[StrAny] # config is a free form and is validated by `module` class TNormalizersConfig(TypedDict, total=False): names: str + allow_identifier_change_on_table_with_data: Optional[bool] detections: Optional[List[str]] json: TJSONNormalizer diff --git a/dlt/common/normalizers/utils.py b/dlt/common/normalizers/utils.py index 645bad2bea..beacf03e4e 100644 --- 
a/dlt/common/normalizers/utils.py +++ b/dlt/common/normalizers/utils.py @@ -1,60 +1,164 @@ from importlib import import_module -from typing import Any, Type, Tuple, cast, List +from types import ModuleType +from typing import Any, Dict, Optional, Type, Tuple, cast, List import dlt +from dlt.common import logger from dlt.common.configuration.inject import with_config +from dlt.common.configuration.specs import known_sections from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.normalizers.configuration import NormalizersConfiguration +from dlt.common.normalizers.exceptions import InvalidJsonNormalizer from dlt.common.normalizers.json import SupportsDataItemNormalizer, DataItemNormalizer -from dlt.common.normalizers.naming import NamingConvention, SupportsNamingConvention -from dlt.common.normalizers.naming.exceptions import UnknownNamingModule, InvalidNamingModule -from dlt.common.normalizers.typing import TJSONNormalizer, TNormalizersConfig -from dlt.common.utils import uniq_id_base64, many_uniq_ids_base64 +from dlt.common.normalizers.naming import NamingConvention +from dlt.common.normalizers.naming.exceptions import ( + NamingTypeNotFound, + UnknownNamingModule, + InvalidNamingType, +) +from dlt.common.normalizers.typing import ( + TJSONNormalizer, + TNormalizersConfig, + TNamingConventionReferenceArg, +) +from dlt.common.typing import is_subclass +from dlt.common.utils import get_full_class_name, uniq_id_base64, many_uniq_ids_base64 -DEFAULT_NAMING_MODULE = "dlt.common.normalizers.naming.snake_case" +DEFAULT_NAMING_NAMESPACE = "dlt.common.normalizers.naming" DLT_ID_LENGTH_BYTES = 10 +DEFAULT_NAMING_MODULE = "snake_case" -@with_config(spec=NormalizersConfiguration) +def _section_for_schema(kwargs: Dict[str, Any]) -> Tuple[str, ...]: + """Uses the schema name to generate dynamic section normalizer settings""" + if schema_name := kwargs.get("schema_name"): + return (known_sections.SOURCES, schema_name) + else: + return 
(known_sections.SOURCES,) + + +@with_config(spec=NormalizersConfiguration, sections=_section_for_schema) # type: ignore[call-overload] def explicit_normalizers( - naming: str = dlt.config.value, json_normalizer: TJSONNormalizer = dlt.config.value + naming: TNamingConventionReferenceArg = dlt.config.value, + json_normalizer: TJSONNormalizer = dlt.config.value, + allow_identifier_change_on_table_with_data: bool = None, + schema_name: Optional[str] = None, ) -> TNormalizersConfig: - """Gets explicitly configured normalizers - via config or destination caps. May return None as naming or normalizer""" - return {"names": naming, "json": json_normalizer} + """Gets explicitly configured normalizers without any defaults or capabilities injection. If `naming` + is a module or a type it will get converted into string form via import. + + If `schema_name` is present, a section ("sources", schema_name, "schema") is used to inject the config + """ + + norm_conf: TNormalizersConfig = {"names": serialize_reference(naming), "json": json_normalizer} + if allow_identifier_change_on_table_with_data is not None: + norm_conf["allow_identifier_change_on_table_with_data"] = ( + allow_identifier_change_on_table_with_data + ) + return norm_conf @with_config def import_normalizers( - normalizers_config: TNormalizersConfig, + explicit_normalizers: TNormalizersConfig, + default_normalizers: TNormalizersConfig = None, destination_capabilities: DestinationCapabilitiesContext = None, ) -> Tuple[TNormalizersConfig, NamingConvention, Type[DataItemNormalizer[Any]]]: """Imports the normalizers specified in `normalizers_config` or taken from defaults. Returns the updated config and imported modules. - `destination_capabilities` are used to get max length of the identifier. + `destination_capabilities` are used to get naming convention, max length of the identifier and max nesting level. 
""" + if default_normalizers is None: + default_normalizers = {} # add defaults to normalizer_config - normalizers_config["names"] = names = normalizers_config["names"] or "snake_case" - # set default json normalizer module - normalizers_config["json"] = item_normalizer = normalizers_config.get("json") or {} - if "module" not in item_normalizer: - item_normalizer["module"] = "dlt.common.normalizers.json.relational" - - try: - if "." in names: + naming: TNamingConventionReferenceArg = explicit_normalizers.get("names") + if naming is None: + if destination_capabilities: + naming = destination_capabilities.naming_convention + if naming is None: + naming = default_normalizers.get("names") or DEFAULT_NAMING_MODULE + naming_convention = naming_from_reference(naming, destination_capabilities) + explicit_normalizers["names"] = serialize_reference(naming) + + item_normalizer = explicit_normalizers.get("json") or default_normalizers.get("json") or {} + item_normalizer.setdefault("module", "dlt.common.normalizers.json.relational") + # if max_table_nesting is set, we need to set the max_table_nesting in the json_normalizer + if destination_capabilities and destination_capabilities.max_table_nesting is not None: + # TODO: this is a hack, we need a better method to do this + from dlt.common.normalizers.json.relational import DataItemNormalizer + + try: + DataItemNormalizer.ensure_this_normalizer(item_normalizer) + item_normalizer.setdefault("config", {}) + item_normalizer["config"]["max_nesting"] = destination_capabilities.max_table_nesting # type: ignore[index] + except InvalidJsonNormalizer: + # not a right normalizer + logger.warning(f"JSON Normalizer {item_normalizer} does not support max_nesting") + pass + json_module = cast(SupportsDataItemNormalizer, import_module(item_normalizer["module"])) + explicit_normalizers["json"] = item_normalizer + return ( + explicit_normalizers, + naming_convention, + json_module.DataItemNormalizer, + ) + + +def naming_from_reference( + names: 
TNamingConventionReferenceArg, + destination_capabilities: DestinationCapabilitiesContext = None, +) -> NamingConvention: + """Resolves naming convention from reference in `names` and applies max length from `destination_capabilities` + + Reference may be: (1) shorthand name pointing to `dlt.common.normalizers.naming` namespace + (2) a type name which is a module containing `NamingConvention` attribute (3) a type of class deriving from NamingConvention + """ + + def _import_naming(module: str) -> ModuleType: + if "." in module: # TODO: bump schema engine version and migrate schema. also change the name in TNormalizersConfig from names to naming - if names == "dlt.common.normalizers.names.snake_case": - names = DEFAULT_NAMING_MODULE + if module == "dlt.common.normalizers.names.snake_case": + module = f"{DEFAULT_NAMING_NAMESPACE}.{DEFAULT_NAMING_MODULE}" # this is full module name - naming_module = cast(SupportsNamingConvention, import_module(names)) + naming_module = import_module(module) else: # from known location - naming_module = cast( - SupportsNamingConvention, import_module(f"dlt.common.normalizers.naming.{names}") - ) - except ImportError: - raise UnknownNamingModule(names) - if not hasattr(naming_module, "NamingConvention"): - raise InvalidNamingModule(names) + try: + naming_module = import_module(f"{DEFAULT_NAMING_NAMESPACE}.{module}") + except ImportError: + # also import local module + naming_module = import_module(module) + return naming_module + + def _get_type(naming_module: ModuleType, cls: str) -> Type[NamingConvention]: + class_: Type[NamingConvention] = getattr(naming_module, cls, None) + if class_ is None: + raise NamingTypeNotFound(naming_module.__name__, cls) + if is_subclass(class_, NamingConvention): + return class_ + raise InvalidNamingType(naming_module.__name__, cls) + + if is_subclass(names, NamingConvention): + class_: Type[NamingConvention] = names # type: ignore[assignment] + elif isinstance(names, ModuleType): + class_ = 
_get_type(names, "NamingConvention") + elif isinstance(names, str): + try: + class_ = _get_type(_import_naming(names), "NamingConvention") + except ImportError: + parts = names.rsplit(".", 1) + # we have no more options to try + if len(parts) <= 1: + raise UnknownNamingModule(names) + try: + class_ = _get_type(_import_naming(parts[0]), parts[1]) + except UnknownNamingModule: + raise + except ImportError: + raise UnknownNamingModule(names) + else: + raise ValueError(names) + # get max identifier length if destination_capabilities: max_length = min( @@ -63,13 +167,18 @@ def import_normalizers( ) else: max_length = None - json_module = cast(SupportsDataItemNormalizer, import_module(item_normalizer["module"])) - return ( - normalizers_config, - naming_module.NamingConvention(max_length), - json_module.DataItemNormalizer, - ) + return class_(max_length) + + +def serialize_reference(naming: Optional[TNamingConventionReferenceArg]) -> Optional[str]: + """Serializes generic `naming` reference to importable string.""" + if naming is None: + return naming + if isinstance(naming, str): + return naming + # import reference and use naming to get valid path to type + return get_full_class_name(naming_from_reference(naming)) def generate_dlt_ids(n_ids: int) -> List[str]: diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 6cefdd9e6c..c6ee27e58b 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -260,9 +260,6 @@ def asstr(self, verbosity: int = 0) -> str: return self._load_packages_asstr(self.load_packages, verbosity) -# reveal_type(ExtractInfo) - - class NormalizeMetrics(StepMetrics): job_metrics: Dict[str, DataWriterMetrics] """Metrics collected per job id during writing of job file""" diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 678f4de15e..2f016577ce 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -7,37 +7,45 @@ TSchemaContractEntities, TSchemaEvolutionMode, ) +from 
dlt.common.normalizers.naming import NamingConvention class SchemaException(DltException): - pass + def __init__(self, schema_name: str, msg: str) -> None: + self.schema_name = schema_name + if schema_name: + msg = f"In schema: {schema_name}: " + msg + super().__init__(msg) class InvalidSchemaName(ValueError, SchemaException): MAXIMUM_SCHEMA_NAME_LENGTH = 64 - def __init__(self, name: str) -> None: - self.name = name + def __init__(self, schema_name: str) -> None: + self.name = schema_name super().__init__( - f"{name} is an invalid schema/source name. The source or schema name must be a valid" - " Python identifier ie. a snake case function name and have maximum" + schema_name, + f"{schema_name} is an invalid schema/source name. The source or schema name must be a" + " valid Python identifier ie. a snake case function name and have maximum" f" {self.MAXIMUM_SCHEMA_NAME_LENGTH} characters. Ideally should contain only small" - " letters, numbers and underscores." + " letters, numbers and underscores.", ) -class InvalidDatasetName(ValueError, SchemaException): - def __init__(self, destination_name: str) -> None: - self.destination_name = destination_name - super().__init__( - f"Destination {destination_name} does not accept empty datasets. Please pass the" - " dataset name to the destination configuration ie. via dlt pipeline." - ) +# TODO: does not look like a SchemaException +# class InvalidDatasetName(ValueError, SchemaException): +# def __init__(self, destination_name: str) -> None: +# self.destination_name = destination_name +# super().__init__( +# f"Destination {destination_name} does not accept empty datasets. Please pass the" +# " dataset name to the destination configuration ie. via dlt pipeline." 
+# ) class CannotCoerceColumnException(SchemaException): def __init__( self, + schema_name: str, table_name: str, column_name: str, from_type: TDataType, @@ -50,37 +58,43 @@ def __init__( self.to_type = to_type self.coerced_value = coerced_value super().__init__( + schema_name, f"Cannot coerce type in table {table_name} column {column_name} existing type" - f" {from_type} coerced type {to_type} value: {coerced_value}" + f" {from_type} coerced type {to_type} value: {coerced_value}", ) class TablePropertiesConflictException(SchemaException): - def __init__(self, table_name: str, prop_name: str, val1: str, val2: str): + def __init__(self, schema_name: str, table_name: str, prop_name: str, val1: str, val2: str): self.table_name = table_name self.prop_name = prop_name self.val1 = val1 self.val2 = val2 super().__init__( + schema_name, f"Cannot merge partial tables for {table_name} due to property {prop_name}: {val1} !=" - f" {val2}" + f" {val2}", ) class ParentTableNotFoundException(SchemaException): - def __init__(self, table_name: str, parent_table_name: str, explanation: str = "") -> None: + def __init__( + self, schema_name: str, table_name: str, parent_table_name: str, explanation: str = "" + ) -> None: self.table_name = table_name self.parent_table_name = parent_table_name super().__init__( + schema_name, f"Parent table {parent_table_name} for {table_name} was not found in the" - f" schema.{explanation}" + f" schema.{explanation}", ) class CannotCoerceNullException(SchemaException): - def __init__(self, table_name: str, column_name: str) -> None: + def __init__(self, schema_name: str, table_name: str, column_name: str) -> None: super().__init__( - f"Cannot coerce NULL in table {table_name} column {column_name} which is not nullable" + schema_name, + f"Cannot coerce NULL in table {table_name} column {column_name} which is not nullable", ) @@ -88,19 +102,48 @@ class SchemaCorruptedException(SchemaException): pass +class 
SchemaIdentifierNormalizationCollision(SchemaCorruptedException): + def __init__( + self, + schema_name: str, + table_name: str, + identifier_type: str, + identifier_name: str, + conflict_identifier_name: str, + naming_name: str, + collision_msg: str, + ) -> None: + if identifier_type == "column": + table_info = f"in table {table_name} " + else: + table_info = "" + msg = ( + f"A {identifier_type} name {identifier_name} {table_info}collides with" + f" {conflict_identifier_name} after normalization with {naming_name} naming" + " convention. " + + collision_msg + ) + self.table_name = table_name + self.identifier_type = identifier_type + self.identifier_name = identifier_name + self.conflict_identifier_name = conflict_identifier_name + self.naming_name = naming_name + super().__init__(schema_name, msg) + + class SchemaEngineNoUpgradePathException(SchemaException): def __init__( self, schema_name: str, init_engine: int, from_engine: int, to_engine: int ) -> None: - self.schema_name = schema_name self.init_engine = init_engine self.from_engine = from_engine self.to_engine = to_engine super().__init__( + schema_name, f"No engine upgrade path in schema {schema_name} from {init_engine} to {to_engine}," f" stopped at {from_engine}. You possibly tried to run an older dlt" " version against a destination you have previously loaded data to with a newer dlt" - " version." + " version.", ) @@ -133,8 +176,7 @@ def __init__( + f" . Contract on {schema_entity} with mode {contract_mode} is violated. 
" + (extended_info or "") ) - super().__init__(msg) - self.schema_name = schema_name + super().__init__(schema_name, msg) self.table_name = table_name self.column_name = column_name @@ -148,10 +190,43 @@ def __init__( self.data_item = data_item -class UnknownTableException(SchemaException): - def __init__(self, table_name: str) -> None: +class UnknownTableException(KeyError, SchemaException): + def __init__(self, schema_name: str, table_name: str) -> None: self.table_name = table_name - super().__init__(f"Trying to access unknown table {table_name}.") + super().__init__(schema_name, f"Trying to access unknown table {table_name}.") + + +class TableIdentifiersFrozen(SchemaException): + def __init__( + self, + schema_name: str, + table_name: str, + to_naming: NamingConvention, + from_naming: NamingConvention, + details: str, + ) -> None: + self.table_name = table_name + self.to_naming = to_naming + self.from_naming = from_naming + msg = ( + f"Attempt to normalize identifiers for a table {table_name} from naming" + f" {from_naming.name()} to {to_naming.name()} changed one or more identifiers. " + ) + msg += ( + " This table already received data and tables were created at the destination. By" + " default changing the identifiers is not allowed. " + ) + msg += ( + " Such changes may result in creation of a new table or a new columns while the old" + " columns with data will still be kept. " + ) + msg += ( + " You may disable this behavior by setting" + " schema.allow_identifier_change_on_table_with_data to True or removing `x-normalizer`" + " hints from particular tables. 
" + ) + msg += f" Details: {details}" + super().__init__(schema_name, msg) class ColumnNameConflictException(SchemaException): diff --git a/dlt/common/schema/migrations.py b/dlt/common/schema/migrations.py index 9b206d61a6..b64714ba19 100644 --- a/dlt/common/schema/migrations.py +++ b/dlt/common/schema/migrations.py @@ -1,7 +1,7 @@ from typing import Dict, List, cast from dlt.common.data_types import TDataType -from dlt.common.normalizers import explicit_normalizers +from dlt.common.normalizers.utils import explicit_normalizers from dlt.common.typing import DictStrAny from dlt.common.schema.typing import ( LOADS_TABLE_NAME, @@ -14,7 +14,7 @@ from dlt.common.schema.exceptions import SchemaEngineNoUpgradePathException from dlt.common.normalizers.utils import import_normalizers -from dlt.common.schema.utils import new_table, version_table, load_table +from dlt.common.schema.utils import new_table, version_table, loads_table def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema: @@ -29,7 +29,8 @@ def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> # current version of the schema current = cast(TStoredSchema, schema_dict) # add default normalizers and root hash propagation - current["normalizers"], _, _ = import_normalizers(explicit_normalizers()) + normalizers = explicit_normalizers() + current["normalizers"], _, _ = import_normalizers(normalizers, normalizers) current["normalizers"]["json"]["config"] = { "propagation": {"root": {"_dlt_id": "_dlt_root_id"}} } @@ -92,11 +93,11 @@ def migrate_filters(group: str, filters: List[str]) -> None: if from_engine == 4 and to_engine > 4: # replace schema versions table schema_dict["tables"][VERSION_TABLE_NAME] = version_table() - schema_dict["tables"][LOADS_TABLE_NAME] = load_table() + schema_dict["tables"][LOADS_TABLE_NAME] = loads_table() from_engine = 5 if from_engine == 5 and to_engine > 5: # replace loads table - schema_dict["tables"][LOADS_TABLE_NAME] = 
load_table() + schema_dict["tables"][LOADS_TABLE_NAME] = loads_table() from_engine = 6 if from_engine == 6 and to_engine > 6: # migrate from sealed properties to schema evolution settings diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index 6d5dc48907..52f8545587 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -1,5 +1,16 @@ from copy import copy, deepcopy -from typing import ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple, Any, cast, Literal +from typing import ( + Callable, + ClassVar, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Any, + cast, +) from dlt.common.schema.migrations import migrate_schema from dlt.common.utils import extend_list_deduplicated @@ -11,8 +22,8 @@ VARIANT_FIELD_FORMAT, TDataItem, ) -from dlt.common.normalizers import TNormalizersConfig, explicit_normalizers, import_normalizers -from dlt.common.normalizers.naming import NamingConvention +from dlt.common.normalizers import TNormalizersConfig, NamingConvention +from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers from dlt.common.normalizers.json import DataItemNormalizer, TNormalizedRowIterator from dlt.common.schema import utils from dlt.common.data_types import py_type_to_sc_type, coerce_value, TDataType @@ -22,7 +33,7 @@ SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, VERSION_TABLE_NAME, - STATE_TABLE_NAME, + PIPELINE_STATE_TABLE_NAME, TPartialTableSchema, TSchemaContractEntities, TSchemaEvolutionMode, @@ -45,6 +56,7 @@ InvalidSchemaName, ParentTableNotFoundException, SchemaCorruptedException, + TableIdentifiersFrozen, ) from dlt.common.validation import validate_dict from dlt.common.schema.exceptions import DataValidationError @@ -102,13 +114,18 @@ def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: self._reset_schema(name, normalizers) @classmethod - def from_dict(cls, d: DictStrAny, bump_version: bool = True) -> "Schema": + def from_dict( + cls, d: DictStrAny, 
remove_processing_hints: bool = False, bump_version: bool = True + ) -> "Schema": # upgrade engine if needed stored_schema = migrate_schema(d, d["engine_version"], cls.ENGINE_VERSION) # verify schema utils.validate_stored_schema(stored_schema) # add defaults stored_schema = utils.apply_defaults(stored_schema) + # remove processing hints that could be created by normalize and load steps + if remove_processing_hints: + utils.remove_processing_hints(stored_schema["tables"]) # bump version if modified if bump_version: @@ -141,30 +158,6 @@ def replace_schema_content( self._reset_schema(schema.name, schema._normalizers_config) self._from_stored_schema(stored_schema) - def to_dict(self, remove_defaults: bool = False, bump_version: bool = True) -> TStoredSchema: - stored_schema: TStoredSchema = { - "version": self._stored_version, - "version_hash": self._stored_version_hash, - "engine_version": Schema.ENGINE_VERSION, - "name": self._schema_name, - "tables": self._schema_tables, - "settings": self._settings, - "normalizers": self._normalizers_config, - "previous_hashes": self._stored_previous_hashes, - } - if self._imported_version_hash and not remove_defaults: - stored_schema["imported_version_hash"] = self._imported_version_hash - if self._schema_description: - stored_schema["description"] = self._schema_description - - # bump version if modified - if bump_version: - utils.bump_version_if_modified(stored_schema) - # remove defaults after bumping version - if remove_defaults: - utils.remove_defaults(stored_schema) - return stored_schema - def normalize_data_item( self, item: TDataItem, load_id: str, table_name: str ) -> TNormalizedRowIterator: @@ -317,7 +310,7 @@ def apply_schema_contract( column_mode, data_mode = schema_contract["columns"], schema_contract["data_type"] # allow to add new columns when table is new or if columns are allowed to evolve once - if is_new_table or existing_table.get("x-normalizer", {}).get("evolve-columns-once", False): # type: 
ignore[attr-defined] + if is_new_table or existing_table.get("x-normalizer", {}).get("evolve-columns-once", False): column_mode = "evolve" # check if we should filter any columns, partial table below contains only new columns @@ -402,14 +395,20 @@ def resolve_contract_settings_for_table( # expand settings, empty settings will expand into default settings return Schema.expand_schema_contract_settings(settings) - def update_table(self, partial_table: TPartialTableSchema) -> TPartialTableSchema: - """Adds or merges `partial_table` into the schema. Identifiers are not normalized""" + def update_table( + self, partial_table: TPartialTableSchema, normalize_identifiers: bool = True + ) -> TPartialTableSchema: + """Adds or merges `partial_table` into the schema. Identifiers are normalized by default""" + if normalize_identifiers: + partial_table = utils.normalize_table_identifiers(partial_table, self.naming) + table_name = partial_table["name"] parent_table_name = partial_table.get("parent") # check if parent table present if parent_table_name is not None: if self._schema_tables.get(parent_table_name) is None: raise ParentTableNotFoundException( + self.name, table_name, parent_table_name, " This may be due to misconfigured excludes filter that fully deletes content" @@ -422,21 +421,20 @@ def update_table(self, partial_table: TPartialTableSchema) -> TPartialTableSchem self._schema_tables[table_name] = partial_table else: # merge tables performing additional checks - partial_table = utils.merge_table(table, partial_table) + partial_table = utils.merge_table(self.name, table, partial_table) self.data_item_normalizer.extend_table(table_name) return partial_table def update_schema(self, schema: "Schema") -> None: """Updates this schema from an incoming schema. 
Normalizes identifiers after updating normalizers.""" - # update all tables - for table in schema.tables.values(): - self.update_table(table) # pass normalizer config - self._configure_normalizers(schema._normalizers_config) - # update and compile settings self._settings = deepcopy(schema.settings) + self._configure_normalizers(schema._normalizers_config) self._compile_settings() + # update all tables + for table in schema.tables.values(): + self.update_table(table) def drop_tables( self, table_names: Sequence[str], seen_data_only: bool = False @@ -467,67 +465,60 @@ def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: Str # dicts are ordered and we will return the rows with hints in the same order as they appear in the columns return rv_row - def merge_hints(self, new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]]) -> None: - # validate regexes - validate_dict( - TSchemaSettings, - {"default_hints": new_hints}, - ".", - validator_f=utils.simple_regex_validator, - ) - # prepare hints to be added - default_hints = self._settings.setdefault("default_hints", {}) - # add `new_hints` to existing hints - for h, l in new_hints.items(): - if h in default_hints: - extend_list_deduplicated(default_hints[h], l) - else: - # set new hint type - default_hints[h] = l # type: ignore + def merge_hints( + self, + new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]], + normalize_identifiers: bool = True, + ) -> None: + """Merges existing default hints with `new_hints`. Normalizes names in column regexes if possible. Compiles setting at the end + + NOTE: you can manipulate default hints collection directly via `Schema.settings` as long as you call Schema._compile_settings() at the end. 
+ """ + self._merge_hints(new_hints, normalize_identifiers) self._compile_settings() - def normalize_table_identifiers(self, table: TTableSchema) -> TTableSchema: - """Normalizes all table and column names in `table` schema according to current schema naming convention and returns - new normalized TTableSchema instance. + def update_preferred_types( + self, + new_preferred_types: Mapping[TSimpleRegex, TDataType], + normalize_identifiers: bool = True, + ) -> None: + """Updates preferred types dictionary with `new_preferred_types`. Normalizes names in column regexes if possible. Compiles setting at the end - Naming convention like snake_case may produce name clashes with the column names. Clashing column schemas are merged - where the column that is defined later in the dictionary overrides earlier column. + NOTE: you can manipulate preferred hints collection directly via `Schema.settings` as long as you call Schema._compile_settings() at the end. + """ + self._update_preferred_types(new_preferred_types, normalize_identifiers) + self._compile_settings() - Note that resource name is not normalized. + def add_type_detection(self, detection: TTypeDetections) -> None: + """Add type auto detection to the schema.""" + if detection not in self.settings["detections"]: + self.settings["detections"].append(detection) + self._compile_settings() - """ - # normalize all identifiers in table according to name normalizer of the schema - table["name"] = self.naming.normalize_tables_path(table["name"]) - parent = table.get("parent") - if parent: - table["parent"] = self.naming.normalize_tables_path(parent) - columns = table.get("columns") - if columns: - new_columns: TTableSchemaColumns = {} - for c in columns.values(): - new_col_name = c["name"] = self.naming.normalize_path(c["name"]) - # re-index columns as the name changed, if name space was reduced then - # some columns now clash with each other. 
so make sure that we merge columns that are already there - if new_col_name in new_columns: - new_columns[new_col_name] = utils.merge_column( - new_columns[new_col_name], c, merge_defaults=False - ) - else: - new_columns[new_col_name] = c - table["columns"] = new_columns - return table + def remove_type_detection(self, detection: TTypeDetections) -> None: + """Removes type auto detection from the schema.""" + if detection in self.settings["detections"]: + self.settings["detections"].remove(detection) + self._compile_settings() def get_new_table_columns( self, table_name: str, - exiting_columns: TTableSchemaColumns, + existing_columns: TTableSchemaColumns, + case_sensitive: bool = True, include_incomplete: bool = False, ) -> List[TColumnSchema]: - """Gets new columns to be added to `exiting_columns` to bring them up to date with `table_name` schema. Optionally includes incomplete columns (without data type)""" + """Gets new columns to be added to `existing_columns` to bring them up to date with `table_name` schema. + Column names are compared case-sensitively by default. 
+ Optionally includes incomplete columns (without data type)""" + casefold_f: Callable[[str], str] = str.casefold if not case_sensitive else str # type: ignore[assignment] + casefold_existing = { + casefold_f(col_name): col for col_name, col in existing_columns.items() + } diff_c: List[TColumnSchema] = [] s_t = self.get_table_columns(table_name, include_incomplete=include_incomplete) for c in s_t.values(): - if c["name"] not in exiting_columns: + if casefold_f(c["name"]) not in casefold_existing: diff_c.append(c) return diff_c @@ -651,20 +642,70 @@ def tables(self) -> TSchemaTables: def settings(self) -> TSchemaSettings: return self._settings - def to_pretty_json(self, remove_defaults: bool = True) -> str: - d = self.to_dict(remove_defaults=remove_defaults) + def to_dict( + self, + remove_defaults: bool = False, + remove_processing_hints: bool = False, + bump_version: bool = True, + ) -> TStoredSchema: + stored_schema: TStoredSchema = { + "version": self._stored_version, + "version_hash": self._stored_version_hash, + "engine_version": Schema.ENGINE_VERSION, + "name": self._schema_name, + "tables": self._schema_tables, + "settings": self._settings, + "normalizers": self._normalizers_config, + "previous_hashes": self._stored_previous_hashes, + } + if self._imported_version_hash and not remove_defaults: + stored_schema["imported_version_hash"] = self._imported_version_hash + if self._schema_description: + stored_schema["description"] = self._schema_description + + # remove processing hints that could be created by normalize and load steps + if remove_processing_hints: + stored_schema["tables"] = utils.remove_processing_hints( + deepcopy(stored_schema["tables"]) + ) + + # bump version if modified + if bump_version: + utils.bump_version_if_modified(stored_schema) + # remove defaults after bumping version + if remove_defaults: + utils.remove_defaults(stored_schema) + return stored_schema + + def to_pretty_json( + self, remove_defaults: bool = True, 
remove_processing_hints: bool = False + ) -> str: + d = self.to_dict( + remove_defaults=remove_defaults, remove_processing_hints=remove_processing_hints + ) return utils.to_pretty_json(d) - def to_pretty_yaml(self, remove_defaults: bool = True) -> str: - d = self.to_dict(remove_defaults=remove_defaults) + def to_pretty_yaml( + self, remove_defaults: bool = True, remove_processing_hints: bool = False + ) -> str: + d = self.to_dict( + remove_defaults=remove_defaults, remove_processing_hints=remove_processing_hints + ) return utils.to_pretty_yaml(d) - def clone(self, with_name: str = None, update_normalizers: bool = False) -> "Schema": - """Make a deep copy of the schema, optionally changing the name, and updating normalizers and identifiers in the schema if `update_normalizers` is True - - Note that changing of name will set the schema as new + def clone( + self, + with_name: str = None, + remove_processing_hints: bool = False, + update_normalizers: bool = False, + ) -> "Schema": + """Make a deep copy of the schema, optionally changing the name, removing processing markers and updating normalizers and identifiers in the schema if `update_normalizers` is True + Processing markers are `x-` hints created by normalizer (`x-normalizer`) and loader (`x-loader`) to ie. mark newly inferred tables and tables that seen data. 
+ Note that changing the name will break the previous version chain """ - d = deepcopy(self.to_dict(bump_version=False)) + d = deepcopy( + self.to_dict(bump_version=False, remove_processing_hints=remove_processing_hints) + ) if with_name is not None: d["version"] = d["version_hash"] = None d.pop("imported_version_hash", None) @@ -677,12 +718,15 @@ def clone(self, with_name: str = None, update_normalizers: bool = False) -> "Sch return schema def update_normalizers(self) -> None: - """Looks for new normalizer configuration or for destination capabilities context and updates all identifiers in the schema""" - normalizers = explicit_normalizers() - # set the current values as defaults - normalizers["names"] = normalizers["names"] or self._normalizers_config["names"] - normalizers["json"] = normalizers["json"] or self._normalizers_config["json"] - self._configure_normalizers(normalizers) + """Looks for new normalizer configuration or for destination capabilities context and updates all identifiers in the schema + + Table and column names will be normalized with new naming convention, except tables that have seen data (`x-normalizer`) which will + raise if any identifier is to be changed. + Default hints, preferred data types and normalize configs (ie. column propagation) are normalized as well. Regexes are included as long + as textual parts can be extracted from an expression. 
+ """ + self._configure_normalizers(explicit_normalizers(schema_name=self._schema_name)) + self._compile_settings() def set_schema_contract(self, settings: TSchemaContract) -> None: if not settings: @@ -690,18 +734,6 @@ def set_schema_contract(self, settings: TSchemaContract) -> None: else: self._settings["schema_contract"] = settings - def add_type_detection(self, detection: TTypeDetections) -> None: - """Add type auto detection to the schema.""" - if detection not in self.settings["detections"]: - self.settings["detections"].append(detection) - self._compile_settings() - - def remove_type_detection(self, detection: TTypeDetections) -> None: - """Adds type auto detection to the schema.""" - if detection in self.settings["detections"]: - self.settings["detections"].remove(detection) - self._compile_settings() - def _infer_column( self, k: str, v: Any, data_type: TDataType = None, is_variant: bool = False ) -> TColumnSchema: @@ -727,7 +759,7 @@ def _coerce_null_value( if col_name in table_columns: existing_column = table_columns[col_name] if not existing_column.get("nullable", True): - raise CannotCoerceNullException(table_name, col_name) + raise CannotCoerceNullException(self.name, table_name, col_name) def _coerce_non_null_value( self, @@ -759,7 +791,12 @@ def _coerce_non_null_value( if is_variant: # this is final call: we cannot generate any more auto-variants raise CannotCoerceColumnException( - table_name, col_name, py_type, table_columns[col_name]["data_type"], v + self.name, + table_name, + col_name, + py_type, + table_columns[col_name]["data_type"], + v, ) # otherwise we must create variant extension to the table # pass final=True so no more auto-variants can be created recursively @@ -816,6 +853,57 @@ def _infer_hint(self, hint_type: TColumnHint, _: Any, col_name: str) -> bool: else: return False + def _merge_hints( + self, + new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]], + normalize_identifiers: bool = True, + ) -> None: + """Used by `merge_hints` 
method, does not compile settings at the end""" + # validate regexes + validate_dict( + TSchemaSettings, + {"default_hints": new_hints}, + ".", + validator_f=utils.simple_regex_validator, + ) + if normalize_identifiers: + new_hints = self._normalize_default_hints(new_hints) + # prepare hints to be added + default_hints = self._settings.setdefault("default_hints", {}) + # add `new_hints` to existing hints + for h, l in new_hints.items(): + if h in default_hints: + extend_list_deduplicated(default_hints[h], l, utils.canonical_simple_regex) + else: + # set new hint type + default_hints[h] = l # type: ignore + + def _update_preferred_types( + self, + new_preferred_types: Mapping[TSimpleRegex, TDataType], + normalize_identifiers: bool = True, + ) -> None: + # validate regexes + validate_dict( + TSchemaSettings, + {"preferred_types": new_preferred_types}, + ".", + validator_f=utils.simple_regex_validator, + ) + if normalize_identifiers: + new_preferred_types = self._normalize_preferred_types(new_preferred_types) + preferred_types = self._settings.setdefault("preferred_types", {}) + # we must update using canonical simple regex + canonical_preferred = { + utils.canonical_simple_regex(rx): rx for rx in preferred_types.keys() + } + for new_rx, new_dt in new_preferred_types.items(): + canonical_new_rx = utils.canonical_simple_regex(new_rx) + if canonical_new_rx not in canonical_preferred: + preferred_types[new_rx] = new_dt + else: + preferred_types[canonical_preferred[canonical_new_rx]] = new_dt + def _bump_version(self) -> Tuple[int, str]: """Computes schema hash in order to check if schema content was modified. In such case the schema ``stored_version`` and ``stored_version_hash`` are updated. 
@@ -839,40 +927,126 @@ def _drop_version(self) -> None: self._stored_version_hash = self._stored_previous_hashes.pop(0) def _add_standard_tables(self) -> None: - self._schema_tables[self.version_table_name] = self.normalize_table_identifiers( - utils.version_table() + self._schema_tables[self.version_table_name] = utils.normalize_table_identifiers( + utils.version_table(), self.naming ) - self._schema_tables[self.loads_table_name] = self.normalize_table_identifiers( - utils.load_table() + self._schema_tables[self.loads_table_name] = utils.normalize_table_identifiers( + utils.loads_table(), self.naming ) def _add_standard_hints(self) -> None: - default_hints = utils.standard_hints() + default_hints = utils.default_hints() if default_hints: - self._settings["default_hints"] = default_hints + self._merge_hints(default_hints, normalize_identifiers=False) type_detections = utils.standard_type_detections() if type_detections: self._settings["detections"] = type_detections - def _configure_normalizers(self, normalizers: TNormalizersConfig) -> None: - # import desired modules - self._normalizers_config, naming_module, item_normalizer_class = import_normalizers( - normalizers - ) - # print(f"{self.name}: {type(self.naming)} {type(naming_module)}") - if self.naming and type(self.naming) is not type(naming_module): - self.naming = naming_module + def _normalize_default_hints( + self, default_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]] + ) -> Dict[TColumnHint, List[TSimpleRegex]]: + """Normalizes the column names in default hints. In case of column names that are regexes, normalization is skipped""" + return { + hint: [utils.normalize_simple_regex_column(self.naming, regex) for regex in regexes] + for hint, regexes in default_hints.items() + } + + def _normalize_preferred_types( + self, preferred_types: Mapping[TSimpleRegex, TDataType] + ) -> Dict[TSimpleRegex, TDataType]: + """Normalizes the column names in preferred types mapping. 
In case of column names that are regexes, normalization is skipped""" + return { + utils.normalize_simple_regex_column(self.naming, regex): data_type + for regex, data_type in preferred_types.items() + } + + def _verify_update_normalizers( + self, + normalizers_config: TNormalizersConfig, + to_naming: NamingConvention, + from_naming: NamingConvention, + ) -> TSchemaTables: + """Verifies if normalizers can be updated before schema is changed""" + # print(f"{self.name}: {type(to_naming)} {type(naming_module)}") + if from_naming and type(from_naming) is not type(to_naming): + schema_tables = {} for table in self._schema_tables.values(): - self.normalize_table_identifiers(table) + norm_table = utils.normalize_table_identifiers(table, to_naming) + if utils.has_table_seen_data(norm_table) and not normalizers_config.get( + "allow_identifier_change_on_table_with_data", False + ): + # make sure no identifier got changed in table + if norm_table["name"] != table["name"]: + raise TableIdentifiersFrozen( + self.name, + table["name"], + to_naming, + from_naming, + f"Attempt to rename table name to {norm_table['name']}.", + ) + if len(norm_table["columns"]) != len(table["columns"]): + raise TableIdentifiersFrozen( + self.name, + table["name"], + to_naming, + from_naming, + "Number of columns changed after normalization. 
Some columns must have" + " merged.", + ) + col_diff = set(norm_table["columns"].keys()).difference(table["columns"].keys()) + if len(col_diff) > 0: + raise TableIdentifiersFrozen( + self.name, + table["name"], + to_naming, + from_naming, + f"Some columns got renamed to {col_diff}.", + ) + schema_tables[norm_table["name"]] = norm_table # re-index the table names - self._schema_tables = {t["name"]: t for t in self._schema_tables.values()} + return schema_tables + else: + return self._schema_tables + def _renormalize_schema_identifiers( + self, + normalizers_config: TNormalizersConfig, + to_naming: NamingConvention, + from_naming: NamingConvention, + ) -> None: + """Normalizes all identifiers in the schema in place""" + self._schema_tables = self._verify_update_normalizers( + normalizers_config, to_naming, from_naming + ) + self._normalizers_config = normalizers_config + self.naming = to_naming # name normalization functions - self.naming = naming_module - self._dlt_tables_prefix = self.naming.normalize_table_identifier(DLT_NAME_PREFIX) - self.version_table_name = self.naming.normalize_table_identifier(VERSION_TABLE_NAME) - self.loads_table_name = self.naming.normalize_table_identifier(LOADS_TABLE_NAME) - self.state_table_name = self.naming.normalize_table_identifier(STATE_TABLE_NAME) + self._dlt_tables_prefix = to_naming.normalize_table_identifier(DLT_NAME_PREFIX) + self.version_table_name = to_naming.normalize_table_identifier(VERSION_TABLE_NAME) + self.loads_table_name = to_naming.normalize_table_identifier(LOADS_TABLE_NAME) + self.state_table_name = to_naming.normalize_table_identifier(PIPELINE_STATE_TABLE_NAME) + # do a sanity check - dlt tables must start with dlt prefix + for table_name in [self.version_table_name, self.loads_table_name, self.state_table_name]: + if not table_name.startswith(self._dlt_tables_prefix): + raise SchemaCorruptedException( + self.name, + f"A naming convention {self.naming.name()} mangles _dlt table prefix to" + f" 
'{self._dlt_tables_prefix}'. A table '{table_name}' does not start with it.", + ) + # normalize default hints + if default_hints := self._settings.get("default_hints"): + self._settings["default_hints"] = self._normalize_default_hints(default_hints) + # normalized preferred types + if preferred_types := self.settings.get("preferred_types"): + self._settings["preferred_types"] = self._normalize_preferred_types(preferred_types) + + def _configure_normalizers(self, explicit_normalizers: TNormalizersConfig) -> None: + """Gets naming and item normalizer from schema yaml, config providers and destination capabilities and applies them to schema.""" + # import desired modules + normalizers_config, to_naming, item_normalizer_class = import_normalizers( + explicit_normalizers, self._normalizers_config + ) + self._renormalize_schema_identifiers(normalizers_config, to_naming, self.naming) # data item normalization function self.data_item_normalizer = item_normalizer_class(self) self.data_item_normalizer.extend_schema() @@ -903,7 +1077,7 @@ def _reset_schema(self, name: str, normalizers: TNormalizersConfig = None) -> No self._add_standard_hints() # configure normalizers, including custom config if present if not normalizers: - normalizers = explicit_normalizers() + normalizers = explicit_normalizers(schema_name=self._schema_name) self._configure_normalizers(normalizers) # add version tables self._add_standard_tables() @@ -913,9 +1087,13 @@ def _reset_schema(self, name: str, normalizers: TNormalizersConfig = None) -> No def _from_stored_schema(self, stored_schema: TStoredSchema) -> None: self._schema_tables = stored_schema.get("tables") or {} if self.version_table_name not in self._schema_tables: - raise SchemaCorruptedException(f"Schema must contain table {self.version_table_name}") + raise SchemaCorruptedException( + stored_schema["name"], f"Schema must contain table {self.version_table_name}" + ) if self.loads_table_name not in self._schema_tables: - raise 
SchemaCorruptedException(f"Schema must contain table {self.loads_table_name}") + raise SchemaCorruptedException( + stored_schema["name"], f"Schema must contain table {self.loads_table_name}" + ) self._stored_version = stored_schema["version"] self._stored_version_hash = stored_schema["version_hash"] self._imported_version_hash = stored_schema.get("imported_version_hash") diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index fb360b38d3..b5081c5ff4 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -17,8 +17,7 @@ from dlt.common.data_types import TDataType from dlt.common.normalizers.typing import TNormalizersConfig -from dlt.common.typing import TSortOrder, TAnyDateTime -from dlt.common.pendulum import pendulum +from dlt.common.typing import TSortOrder, TAnyDateTime, TLoaderFileFormat try: from pydantic import BaseModel as _PydanticBaseModel @@ -32,7 +31,7 @@ # dlt tables VERSION_TABLE_NAME = "_dlt_version" LOADS_TABLE_NAME = "_dlt_loads" -STATE_TABLE_NAME = "_dlt_pipeline_state" +PIPELINE_STATE_TABLE_NAME = "_dlt_pipeline_state" DLT_NAME_PREFIX = "_dlt" TColumnProp = Literal[ @@ -47,6 +46,7 @@ "unique", "merge_key", "root_key", + "hard_delete", "dedup_sort", ] """Known properties and hints of the column""" @@ -59,12 +59,16 @@ "foreign_key", "sort", "unique", - "root_key", "merge_key", + "root_key", + "hard_delete", "dedup_sort", ] """Known hints of a column used to declare hint regexes.""" + +TWriteDisposition = Literal["skip", "append", "replace", "merge"] TTableFormat = Literal["iceberg", "delta"] +TFileFormat = Literal[Literal["preferred"], TLoaderFileFormat] TTypeDetections = Literal[ "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" ] @@ -72,7 +76,7 @@ TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" -COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) +# COLUMN_PROPS: Set[TColumnProp] = 
set(get_args(TColumnProp)) COLUMN_HINTS: Set[TColumnHint] = set( [ "partition", @@ -153,9 +157,18 @@ class NormalizerInfo(TypedDict, total=True): new_table: bool -TWriteDisposition = Literal["skip", "append", "replace", "merge"] -TLoaderMergeStrategy = Literal["delete-insert", "scd2"] +# Part of Table containing processing hints added by pipeline stages +TTableProcessingHints = TypedDict( + "TTableProcessingHints", + { + "x-normalizer": Optional[Dict[str, Any]], + "x-loader": Optional[Dict[str, Any]], + "x-extractor": Optional[Dict[str, Any]], + }, + total=False, +) +TLoaderMergeStrategy = Literal["delete-insert", "scd2"] WRITE_DISPOSITIONS: Set[TWriteDisposition] = set(get_args(TWriteDisposition)) MERGE_STRATEGIES: Set[TLoaderMergeStrategy] = set(get_args(TLoaderMergeStrategy)) @@ -178,7 +191,8 @@ class TMergeDispositionDict(TWriteDispositionDict, total=False): TWriteDispositionConfig = Union[TWriteDisposition, TWriteDispositionDict, TMergeDispositionDict] -class TTableSchema(TypedDict, total=False): +# TypedDict that defines properties of a table +class TTableSchema(TTableProcessingHints, total=False): """TypedDict that defines properties of a table""" name: Optional[str] @@ -191,6 +205,7 @@ class TTableSchema(TypedDict, total=False): columns: TTableSchemaColumns resource: Optional[str] table_format: Optional[TTableFormat] + file_format: Optional[TFileFormat] class TPartialTableSchema(TTableSchema): diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 51269cbb38..f5765be351 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -7,6 +7,7 @@ from dlt.common.pendulum import pendulum from dlt.common.time import ensure_pendulum_datetime +from dlt.common import logger from dlt.common.json import json from dlt.common.data_types import TDataType from dlt.common.exceptions import DictValidationException @@ -21,12 +22,15 @@ LOADS_TABLE_NAME, SIMPLE_REGEX_PREFIX, VERSION_TABLE_NAME, + PIPELINE_STATE_TABLE_NAME, TColumnName, + 
TFileFormat, TPartialTableSchema, TSchemaTables, TSchemaUpdate, TSimpleRegex, TStoredSchema, + TTableProcessingHints, TTableSchema, TColumnSchemaBase, TColumnSchema, @@ -96,7 +100,8 @@ def apply_defaults(stored_schema: TStoredSchema) -> TStoredSchema: def remove_defaults(stored_schema: TStoredSchema) -> TStoredSchema: """Removes default values from `stored_schema` in place, returns the input for chaining - Default values are removed from table schemas and complete column schemas. Incomplete columns are preserved intact. + * removes column and table names from the value + * removes resource name if same as table name """ clean_tables = deepcopy(stored_schema["tables"]) for table_name, t in clean_tables.items(): @@ -202,6 +207,33 @@ def verify_schema_hash( return hash_ == stored_schema["version_hash"] +def normalize_simple_regex_column(naming: NamingConvention, regex: TSimpleRegex) -> TSimpleRegex: + """Assumes that regex applies to column name and normalizes it.""" + + def _normalize(r_: str) -> str: + is_exact = len(r_) >= 2 and r_[0] == "^" and r_[-1] == "$" + if is_exact: + r_ = r_[1:-1] + # if this is a simple string then normalize it + if r_ == re.escape(r_): + r_ = naming.normalize_path(r_) + if is_exact: + r_ = "^" + r_ + "$" + return r_ + + if regex.startswith(SIMPLE_REGEX_PREFIX): + return cast(TSimpleRegex, SIMPLE_REGEX_PREFIX + _normalize(regex[3:])) + else: + return cast(TSimpleRegex, _normalize(regex)) + + +def canonical_simple_regex(regex: str) -> TSimpleRegex: + if regex.startswith(SIMPLE_REGEX_PREFIX): + return cast(TSimpleRegex, regex) + else: + return cast(TSimpleRegex, SIMPLE_REGEX_PREFIX + "^" + regex + "$") + + def simple_regex_validator(path: str, pk: str, pv: Any, t: Any) -> bool: # custom validator on type TSimpleRegex if t is TSimpleRegex: @@ -237,7 +269,7 @@ def simple_regex_validator(path: str, pk: str, pv: Any, t: Any) -> bool: # we know how to validate that type return True else: - # don't know how to validate t + # don't know how to 
validate this return False @@ -299,7 +331,9 @@ def validate_stored_schema(stored_schema: TStoredSchema) -> None: parent_table_name = table.get("parent") if parent_table_name: if parent_table_name not in stored_schema["tables"]: - raise ParentTableNotFoundException(table_name, parent_table_name) + raise ParentTableNotFoundException( + stored_schema["name"], table_name, parent_table_name + ) def autodetect_sc_type(detection_fs: Sequence[TTypeDetections], t: Type[Any], v: Any) -> TDataType: @@ -370,7 +404,9 @@ def merge_columns( return columns_a -def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTableSchema: +def diff_table( + schema_name: str, tab_a: TTableSchema, tab_b: TPartialTableSchema +) -> TPartialTableSchema: """Creates a partial table that contains properties found in `tab_b` that are not present or different in `tab_a`. The name is always present in returned partial. It returns new columns (not present in tab_a) and merges columns from tab_b into tab_a (overriding non-default hint values). 
@@ -384,7 +420,7 @@ def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTable # check if table properties can be merged if tab_a.get("parent") != tab_b.get("parent"): raise TablePropertiesConflictException( - table_name, "parent", tab_a.get("parent"), tab_b.get("parent") + schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent") ) # get new columns, changes in the column data type or other properties are not allowed @@ -398,6 +434,7 @@ def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTable if not compare_complete_columns(tab_a_columns[col_b_name], col_b): # attempt to update to incompatible columns raise CannotCoerceColumnException( + schema_name, table_name, col_b_name, col_b["data_type"], @@ -426,7 +463,7 @@ def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTable # this should not really happen if tab_a.get("parent") is not None and (resource := tab_b.get("resource")): raise TablePropertiesConflictException( - table_name, "resource", resource, tab_a.get("parent") + schema_name, table_name, "resource", resource, tab_a.get("parent") ) return partial_table @@ -444,7 +481,9 @@ def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTable # return False -def merge_table(table: TTableSchema, partial_table: TPartialTableSchema) -> TPartialTableSchema: +def merge_table( + schema_name: str, table: TTableSchema, partial_table: TPartialTableSchema +) -> TPartialTableSchema: """Merges "partial_table" into "table". `table` is merged in place. Returns the diff partial table. `table` and `partial_table` names must be identical. 
A table diff is generated and applied to `table`: @@ -456,9 +495,10 @@ def merge_table(table: TTableSchema, partial_table: TPartialTableSchema) -> TPar if table["name"] != partial_table["name"]: raise TablePropertiesConflictException( - table["name"], "name", table["name"], partial_table["name"] + schema_name, table["name"], "name", table["name"], partial_table["name"] ) - diff = diff_table(table, partial_table) + diff = diff_table(schema_name, table, partial_table) + # add new columns when all checks passed updated_columns = merge_columns(table["columns"], diff["columns"]) table.update(diff) table["columns"] = updated_columns @@ -466,9 +506,67 @@ def merge_table(table: TTableSchema, partial_table: TPartialTableSchema) -> TPar return diff +def normalize_table_identifiers(table: TTableSchema, naming: NamingConvention) -> TTableSchema: + """Normalizes all table and column names in `table` schema according to current schema naming convention and returns + new instance with modified table schema. + + Naming convention like snake_case may produce name collisions with the column names. Colliding column schemas are merged + where the column that is defined later in the dictionary overrides earlier column. + + Note that resource name is not normalized. + """ + + table = copy(table) + table["name"] = naming.normalize_tables_path(table["name"]) + parent = table.get("parent") + if parent: + table["parent"] = naming.normalize_tables_path(parent) + columns = table.get("columns") + if columns: + new_columns: TTableSchemaColumns = {} + for c in columns.values(): + c = copy(c) + origin_c_name = c["name"] + new_col_name = c["name"] = naming.normalize_path(c["name"]) + # re-index columns as the name changed, if name space was reduced then + # some columns now collide with each other. 
so make sure that we merge columns that are already there + if new_col_name in new_columns: + new_columns[new_col_name] = merge_column( + new_columns[new_col_name], c, merge_defaults=False + ) + logger.warning( + f"In schema {naming} column {origin_c_name} got normalized into" + f" {new_col_name} which collides with other column. Both columns got merged" + " into one." + ) + else: + new_columns[new_col_name] = c + table["columns"] = new_columns + return table + + def has_table_seen_data(table: TTableSchema) -> bool: """Checks if normalizer has seen data coming to the table.""" - return "x-normalizer" in table and table["x-normalizer"].get("seen-data", None) is True # type: ignore[typeddict-item] + return "x-normalizer" in table and table["x-normalizer"].get("seen-data", None) is True + + +def remove_processing_hints(tables: TSchemaTables) -> TSchemaTables: + "Removes processing hints like x-normalizer and x-loader from schema tables. Modifies the input tables and returns it for convenience" + for table_name, hints in get_processing_hints(tables).items(): + for hint in hints: + del tables[table_name][hint] # type: ignore[misc] + return tables + + +def get_processing_hints(tables: TSchemaTables) -> Dict[str, List[str]]: + """Finds processing hints in a set of tables and returns table_name: [hints] mapping""" + hints: Dict[str, List[str]] = {} + for table in tables.values(): + for hint in TTableProcessingHints.__annotations__.keys(): + if hint in table: + table_hints = hints.setdefault(table["name"], []) + table_hints.append(hint) + return hints def hint_to_column_prop(h: TColumnHint) -> TColumnProp: @@ -581,6 +679,12 @@ def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat: ) +def get_file_format(tables: TSchemaTables, table_name: str) -> TFileFormat: + return cast( + TFileFormat, get_inherited_table_hint(tables, table_name, "file_format", allow_none=True) + ) + + def fill_hints_from_parent_and_clone_table( tables: TSchemaTables, table: 
TTableSchema ) -> TTableSchema: @@ -592,6 +696,8 @@ def fill_hints_from_parent_and_clone_table( table["write_disposition"] = get_write_disposition(tables, table["name"]) if "table_format" not in table: table["table_format"] = get_table_format(tables, table["name"]) + if "file_format" not in table: + table["file_format"] = get_file_format(tables, table["name"]) return table @@ -650,6 +756,8 @@ def group_tables_by_resource( def version_table() -> TTableSchema: # NOTE: always add new columns at the end of the table so we have identical layout # after an update of existing tables (always at the end) + # set to nullable so we can migrate existing tables + # WARNING: do not reorder the columns table = new_table( VERSION_TABLE_NAME, columns=[ @@ -670,9 +778,11 @@ def version_table() -> TTableSchema: return table -def load_table() -> TTableSchema: +def loads_table() -> TTableSchema: # NOTE: always add new columns at the end of the table so we have identical layout # after an update of existing tables (always at the end) + # set to nullable so we can migrate existing tables + # WARNING: do not reorder the columns table = new_table( LOADS_TABLE_NAME, columns=[ @@ -692,6 +802,30 @@ def load_table() -> TTableSchema: return table +def pipeline_state_table() -> TTableSchema: + # NOTE: always add new columns at the end of the table so we have identical layout + # after an update of existing tables (always at the end) + # set to nullable so we can migrate existing tables + # WARNING: do not reorder the columns + table = new_table( + PIPELINE_STATE_TABLE_NAME, + write_disposition="append", + columns=[ + {"name": "version", "data_type": "bigint", "nullable": False}, + {"name": "engine_version", "data_type": "bigint", "nullable": False}, + {"name": "pipeline_name", "data_type": "text", "nullable": False}, + {"name": "state", "data_type": "text", "nullable": False}, + {"name": "created_at", "data_type": "timestamp", "nullable": False}, + {"name": "version_hash", "data_type": "text", 
"nullable": True}, + {"name": "_dlt_load_id", "data_type": "text", "nullable": False}, + ], + # always use caps preferred file format for processing + file_format="preferred", + ) + table["description"] = "Created by DLT. Tracks pipeline state" + return table + + def new_table( table_name: str, parent_table_name: str = None, @@ -701,6 +835,7 @@ def new_table( resource: str = None, schema_contract: TSchemaContract = None, table_format: TTableFormat = None, + file_format: TFileFormat = None, ) -> TTableSchema: table: TTableSchema = { "name": table_name, @@ -719,6 +854,8 @@ def new_table( table["schema_contract"] = schema_contract if table_format: table["table_format"] = table_format + if file_format: + table["file_format"] = file_format if validate_schema: validate_dict_ignoring_xkeys( spec=TColumnSchema, @@ -754,7 +891,7 @@ def new_column( return column -def standard_hints() -> Dict[TColumnHint, List[TSimpleRegex]]: +def default_hints() -> Dict[TColumnHint, List[TSimpleRegex]]: return None diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index f6072c0260..29a9da8acf 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -60,15 +60,18 @@ def import_items_file( table_name: str, file_path: str, metrics: DataWriterMetrics, + with_extension: str = None, ) -> DataWriterMetrics: """Import a file from `file_path` into items storage under a new file name. Does not check the imported file format. Uses counts from `metrics` as a base. Logically closes the imported file The preferred import method is a hard link to avoid copying the data. If current filesystem does not support it, a regular copy is used. + + Alternative extension may be provided via `with_extension` so various file formats may be imported into the same folder. 
""" writer = self._get_writer(load_id, schema_name, table_name) - return writer.import_file(file_path, metrics) + return writer.import_file(file_path, metrics, with_extension) def close_writers(self, load_id: str, skip_flush: bool = False) -> None: """Flush, write footers (skip_flush), write metrics and close files in all diff --git a/dlt/common/storages/exceptions.py b/dlt/common/storages/exceptions.py index 26a76bb5c0..028491dd9c 100644 --- a/dlt/common/storages/exceptions.py +++ b/dlt/common/storages/exceptions.py @@ -79,6 +79,23 @@ def __init__(self, load_id: str) -> None: super().__init__(f"Package with load id {load_id} could not be found") +class LoadPackageAlreadyCompleted(LoadStorageException): + def __init__(self, load_id: str) -> None: + self.load_id = load_id + super().__init__( + f"Package with load id {load_id} is already completed, but another complete was" + " requested" + ) + + +class LoadPackageNotCompleted(LoadStorageException): + def __init__(self, load_id: str) -> None: + self.load_id = load_id + super().__init__( + f"Package with load id {load_id} is not yet completed, but method required that" + ) + + class SchemaStorageException(StorageException): pass diff --git a/dlt/common/storages/file_storage.py b/dlt/common/storages/file_storage.py index d768ec720a..7d14b8f7f7 100644 --- a/dlt/common/storages/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -6,7 +6,7 @@ import tempfile import shutil import pathvalidate -from typing import IO, Any, Optional, List, cast, overload +from typing import IO, Any, Optional, List, cast from dlt.common.typing import AnyFun from dlt.common.utils import encoding_for_mode, uniq_id @@ -18,7 +18,7 @@ class FileStorage: def __init__(self, storage_path: str, file_type: str = "t", makedirs: bool = False) -> None: # make it absolute path - self.storage_path = os.path.realpath(storage_path) # os.path.join(, '') + self.storage_path = os.path.realpath(storage_path) self.file_type = file_type if makedirs: 
os.makedirs(storage_path, exist_ok=True) @@ -243,7 +243,8 @@ def atomic_import( FileStorage.move_atomic_to_file(external_file_path, dest_file_path) ) - def in_storage(self, path: str) -> bool: + def is_path_in_storage(self, path: str) -> bool: + """Checks if a given path is below storage root, without checking for item existence""" assert path is not None # all paths are relative to root if not os.path.isabs(path): @@ -256,25 +257,30 @@ def in_storage(self, path: str) -> bool: def to_relative_path(self, path: str) -> str: if path == "": return "" - if not self.in_storage(path): + if not self.is_path_in_storage(path): raise ValueError(path) if not os.path.isabs(path): path = os.path.realpath(os.path.join(self.storage_path, path)) # for abs paths find the relative return os.path.relpath(path, start=self.storage_path) - def make_full_path(self, path: str) -> str: + def make_full_path_safe(self, path: str) -> str: + """Verifies that path is under storage root and then returns normalized absolute path""" # try to make a relative path if paths are absolute or overlapping path = self.to_relative_path(path) # then assume that it is a path relative to storage root return os.path.realpath(os.path.join(self.storage_path, path)) + def make_full_path(self, path: str) -> str: + """Joins path with storage root. 
Intended for path known to be relative to storage root""" + return os.path.join(self.storage_path, path) + def from_wd_to_relative_path(self, wd_relative_path: str) -> str: path = os.path.realpath(wd_relative_path) return self.to_relative_path(path) def from_relative_path_to_wd(self, relative_path: str) -> str: - return os.path.relpath(self.make_full_path(relative_path), start=".") + return os.path.relpath(self.make_full_path_safe(relative_path), start=".") @staticmethod def get_file_name_from_file_path(file_path: str) -> str: diff --git a/dlt/common/storages/fsspec_filesystem.py b/dlt/common/storages/fsspec_filesystem.py index a21f0f2c0c..f419baed03 100644 --- a/dlt/common/storages/fsspec_filesystem.py +++ b/dlt/common/storages/fsspec_filesystem.py @@ -5,7 +5,6 @@ import pathlib import posixpath from io import BytesIO -from gzip import GzipFile from typing import ( Literal, cast, diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index fd4ecc968e..1ecc491174 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -32,20 +32,6 @@ def remove_schema(self, name: str) -> None: # also remove the live schema self.live_schemas.pop(name, None) - def save_import_schema_if_not_exists(self, schema: Schema) -> bool: - """Saves import schema, if not exists. 
If schema was saved, link itself as imported from""" - if self.config.import_schema_path: - try: - self._load_import_schema(schema.name) - except FileNotFoundError: - # save import schema only if it not exist - self._export_schema(schema, self.config.import_schema_path) - # if import schema got saved then add own version hash as import version hash - schema._imported_version_hash = schema.version_hash - return True - - return False - def commit_live_schema(self, name: str) -> str: """Saves live schema in storage if it was modified""" if not self.is_live_schema_committed(name): diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 4d72458e3e..9e3185221d 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -5,7 +5,7 @@ import datetime # noqa: 251 import humanize -from pathlib import Path +from pathlib import PurePath from pendulum.datetime import DateTime from typing import ( ClassVar, @@ -37,7 +37,12 @@ from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns, TTableSchema from dlt.common.storages import FileStorage -from dlt.common.storages.exceptions import LoadPackageNotFound, CurrentLoadPackageStateNotAvailable +from dlt.common.storages.exceptions import ( + LoadPackageAlreadyCompleted, + LoadPackageNotCompleted, + LoadPackageNotFound, + CurrentLoadPackageStateNotAvailable, +) from dlt.common.typing import DictStrAny, SupportsHumanize from dlt.common.utils import flatten_list_or_items from dlt.common.versioned_state import ( @@ -52,6 +57,7 @@ TJobFileFormat = Literal["sql", "reference", TLoaderFileFormat] """Loader file formats with internal job types""" +JOB_EXCEPTION_EXTENSION = ".exception" class TPipelineStateDoc(TypedDict, total=False): @@ -61,9 +67,9 @@ class TPipelineStateDoc(TypedDict, total=False): engine_version: int pipeline_name: str state: str - version_hash: str created_at: datetime.datetime - 
dlt_load_id: NotRequired[str] + version_hash: str + _dlt_load_id: NotRequired[str] class TLoadPackageState(TVersionedState, total=False): @@ -165,7 +171,7 @@ def with_retry(self) -> "ParsedLoadJobFileName": @staticmethod def parse(file_name: str) -> "ParsedLoadJobFileName": - p = Path(file_name) + p = PurePath(file_name) parts = p.name.split(".") if len(parts) != 4: raise TerminalValueError(parts) @@ -319,13 +325,16 @@ def __init__(self, storage: FileStorage, initial_state: TLoadPackageStatus) -> N # def get_package_path(self, load_id: str) -> str: + """Gets path of the package relative to storage root""" return load_id - def get_job_folder_path(self, load_id: str, folder: TJobState) -> str: - return os.path.join(self.get_package_path(load_id), folder) + def get_job_state_folder_path(self, load_id: str, state: TJobState) -> str: + """Gets path to the jobs in `state` in package `load_id`, relative to the storage root""" + return os.path.join(self.get_package_path(load_id), state) - def get_job_file_path(self, load_id: str, folder: TJobState, file_name: str) -> str: - return os.path.join(self.get_job_folder_path(load_id, folder), file_name) + def get_job_file_path(self, load_id: str, state: TJobState, file_name: str) -> str: + """Get path to job with `file_name` in `state` in package `load_id`, relative to the storage root""" + return os.path.join(self.get_job_state_folder_path(load_id, state), file_name) def list_packages(self) -> Sequence[str]: """Lists all load ids in storage, earliest first @@ -338,29 +347,42 @@ def list_packages(self) -> Sequence[str]: def list_new_jobs(self, load_id: str) -> Sequence[str]: new_jobs = self.storage.list_folder_files( - self.get_job_folder_path(load_id, PackageStorage.NEW_JOBS_FOLDER) + self.get_job_state_folder_path(load_id, PackageStorage.NEW_JOBS_FOLDER) ) return new_jobs def list_started_jobs(self, load_id: str) -> Sequence[str]: return self.storage.list_folder_files( - self.get_job_folder_path(load_id, 
PackageStorage.STARTED_JOBS_FOLDER) + self.get_job_state_folder_path(load_id, PackageStorage.STARTED_JOBS_FOLDER) ) def list_failed_jobs(self, load_id: str) -> Sequence[str]: - return self.storage.list_folder_files( - self.get_job_folder_path(load_id, PackageStorage.FAILED_JOBS_FOLDER) - ) - - def list_jobs_for_table(self, load_id: str, table_name: str) -> Sequence[LoadJobInfo]: - return self.filter_jobs_for_table(self.list_all_jobs(load_id), table_name) - - def list_all_jobs(self, load_id: str) -> Sequence[LoadJobInfo]: - info = self.get_load_package_info(load_id) - return [job for job in flatten_list_or_items(iter(info.jobs.values()))] # type: ignore + return [ + file + for file in self.storage.list_folder_files( + self.get_job_state_folder_path(load_id, PackageStorage.FAILED_JOBS_FOLDER) + ) + if not file.endswith(JOB_EXCEPTION_EXTENSION) + ] + + def list_job_with_states_for_table( + self, load_id: str, table_name: str + ) -> Sequence[Tuple[TJobState, ParsedLoadJobFileName]]: + return self.filter_jobs_for_table(self.list_all_jobs_with_states(load_id), table_name) + + def list_all_jobs_with_states( + self, load_id: str + ) -> Sequence[Tuple[TJobState, ParsedLoadJobFileName]]: + info = self.get_load_package_jobs(load_id) + state_jobs = [] + for state, jobs in info.items(): + state_jobs.extend([(state, job) for job in jobs]) + return state_jobs def list_failed_jobs_infos(self, load_id: str) -> Sequence[LoadJobInfo]: """List all failed jobs and associated error messages for a load package with `load_id`""" + if not self.is_package_completed(load_id): + raise LoadPackageNotCompleted(load_id) failed_jobs: List[LoadJobInfo] = [] package_path = self.get_package_path(load_id) package_created_at = pendulum.from_timestamp( @@ -371,12 +393,19 @@ def list_failed_jobs_infos(self, load_id: str) -> Sequence[LoadJobInfo]: ) ) for file in self.list_failed_jobs(load_id): - if not file.endswith(".exception"): - failed_jobs.append( - self._read_job_file_info("failed_jobs", file, 
package_created_at) + failed_jobs.append( + self._read_job_file_info( + load_id, "failed_jobs", ParsedLoadJobFileName.parse(file), package_created_at ) + ) return failed_jobs + def is_package_completed(self, load_id: str) -> bool: + package_path = self.get_package_path(load_id) + return self.storage.has_file( + os.path.join(package_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME) + ) + # # Move jobs # @@ -385,7 +414,9 @@ def import_job( self, load_id: str, job_file_path: str, job_state: TJobState = "new_jobs" ) -> None: """Adds new job by moving the `job_file_path` into `new_jobs` of package `load_id`""" - self.storage.atomic_import(job_file_path, self.get_job_folder_path(load_id, job_state)) + self.storage.atomic_import( + job_file_path, self.get_job_state_folder_path(load_id, job_state) + ) def start_job(self, load_id: str, file_name: str) -> str: return self._move_job( @@ -397,7 +428,7 @@ def fail_job(self, load_id: str, file_name: str, failed_message: Optional[str]) if failed_message: self.storage.save( self.get_job_file_path( - load_id, PackageStorage.FAILED_JOBS_FOLDER, file_name + ".exception" + load_id, PackageStorage.FAILED_JOBS_FOLDER, file_name + JOB_EXCEPTION_EXTENSION ), failed_message, ) @@ -455,6 +486,8 @@ def create_package(self, load_id: str, initial_state: TLoadPackageState = None) def complete_loading_package(self, load_id: str, load_state: TLoadPackageStatus) -> str: """Completes loading the package by writing marker file with`package_state. 
Returns path to the completed package""" load_path = self.get_package_path(load_id) + if self.is_package_completed(load_id): + raise LoadPackageAlreadyCompleted(load_id) # save marker file self.storage.save( os.path.join(load_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME), load_state @@ -468,7 +501,7 @@ def remove_completed_jobs(self, load_id: str) -> None: # delete completed jobs if not has_failed_jobs: self.storage.delete_folder( - self.get_job_folder_path(load_id, PackageStorage.COMPLETED_JOBS_FOLDER), + self.get_job_state_folder_path(load_id, PackageStorage.COMPLETED_JOBS_FOLDER), recursively=True, ) @@ -533,11 +566,32 @@ def get_load_package_state_path(self, load_id: str) -> str: # Get package info # - def get_load_package_info(self, load_id: str) -> LoadPackageInfo: - """Gets information on normalized/completed package with given load_id, all jobs and their statuses.""" + def get_load_package_jobs(self, load_id: str) -> Dict[TJobState, List[ParsedLoadJobFileName]]: + """Gets all jobs in a package and returns them as lists assigned to a particular state.""" package_path = self.get_package_path(load_id) if not self.storage.has_folder(package_path): raise LoadPackageNotFound(load_id) + all_jobs: Dict[TJobState, List[ParsedLoadJobFileName]] = {} + for state in WORKING_FOLDERS: + jobs: List[ParsedLoadJobFileName] = [] + with contextlib.suppress(FileNotFoundError): + # we ignore if load package lacks one of working folders. completed_jobs may be deleted on archiving + for file in self.storage.list_folder_files( + self.get_job_state_folder_path(load_id, state), to_root=False + ): + if not file.endswith(JOB_EXCEPTION_EXTENSION): + jobs.append(ParsedLoadJobFileName.parse(file)) + all_jobs[state] = jobs + return all_jobs + + def get_load_package_info(self, load_id: str) -> LoadPackageInfo: + """Gets information on normalized/completed package with given load_id, all jobs and their statuses. 
+ + Will reach to the file system to get additional stats, mtime, also collects exceptions for failed jobs. + NOTE: do not call this function often. it should be used only to generate metrics + """ + package_path = self.get_package_path(load_id) + package_jobs = self.get_load_package_jobs(load_id) package_created_at: DateTime = None package_state = self.initial_state @@ -560,15 +614,11 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: schema = Schema.from_dict(self._load_schema(load_id)) # read jobs with all statuses - all_jobs: Dict[TJobState, List[LoadJobInfo]] = {} - for state in WORKING_FOLDERS: - jobs: List[LoadJobInfo] = [] - with contextlib.suppress(FileNotFoundError): - # we ignore if load package lacks one of working folders. completed_jobs may be deleted on archiving - for file in self.storage.list_folder_files(os.path.join(package_path, state)): - if not file.endswith(".exception"): - jobs.append(self._read_job_file_info(state, file, package_created_at)) - all_jobs[state] = jobs + all_job_infos: Dict[TJobState, List[LoadJobInfo]] = {} + for state, jobs in package_jobs.items(): + all_job_infos[state] = [ + self._read_job_file_info(load_id, state, job, package_created_at) for job in jobs + ] return LoadPackageInfo( load_id, @@ -577,15 +627,46 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: schema, applied_update, package_created_at, - all_jobs, + all_job_infos, ) - def _read_job_file_info(self, state: TJobState, file: str, now: DateTime = None) -> LoadJobInfo: - try: - failed_message = self.storage.load(file + ".exception") - except FileNotFoundError: - failed_message = None - full_path = self.storage.make_full_path(file) + def get_job_failed_message(self, load_id: str, job: ParsedLoadJobFileName) -> str: + """Get exception message of a failed job.""" + rel_path = self.get_job_file_path(load_id, "failed_jobs", job.file_name()) + if not self.storage.has_file(rel_path): + raise FileNotFoundError(rel_path) + 
failed_message: str = None + with contextlib.suppress(FileNotFoundError): + failed_message = self.storage.load(rel_path + JOB_EXCEPTION_EXTENSION) + return failed_message + + def job_to_job_info( + self, load_id: str, state: TJobState, job: ParsedLoadJobFileName + ) -> LoadJobInfo: + """Creates partial job info by converting job object. size, mtime and failed message will not be populated""" + full_path = os.path.join( + self.storage.storage_path, self.get_job_file_path(load_id, state, job.file_name()) + ) + return LoadJobInfo( + state, + full_path, + 0, + None, + 0, + job, + None, + ) + + def _read_job_file_info( + self, load_id: str, state: TJobState, job: ParsedLoadJobFileName, now: DateTime = None + ) -> LoadJobInfo: + """Creates job info by reading additional props from storage""" + failed_message = None + if state == "failed_jobs": + failed_message = self.get_job_failed_message(load_id, job) + full_path = os.path.join( + self.storage.storage_path, self.get_job_file_path(load_id, state, job.file_name()) + ) st = os.stat(full_path) return LoadJobInfo( state, @@ -593,7 +674,7 @@ def _read_job_file_info(self, state: TJobState, file: str, now: DateTime = None) st.st_size, pendulum.from_timestamp(st.st_mtime), PackageStorage._job_elapsed_time_seconds(full_path, now.timestamp() if now else None), - ParsedLoadJobFileName.parse(file), + job, failed_message, ) @@ -611,10 +692,11 @@ def _move_job( ) -> str: # ensure we move file names, not paths assert file_name == FileStorage.get_file_name_from_file_path(file_name) - load_path = self.get_package_path(load_id) - dest_path = os.path.join(load_path, dest_folder, new_file_name or file_name) - self.storage.atomic_rename(os.path.join(load_path, source_folder, file_name), dest_path) - # print(f"{join(load_path, source_folder, file_name)} -> {dest_path}") + + dest_path = self.get_job_file_path(load_id, dest_folder, new_file_name or file_name) + self.storage.atomic_rename( + self.get_job_file_path(load_id, source_folder, 
file_name), dest_path + ) return self.storage.make_full_path(dest_path) def _load_schema(self, load_id: str) -> DictStrAny: @@ -659,9 +741,9 @@ def _job_elapsed_time_seconds(file_path: str, now_ts: float = None) -> float: @staticmethod def filter_jobs_for_table( - all_jobs: Iterable[LoadJobInfo], table_name: str - ) -> Sequence[LoadJobInfo]: - return [job for job in all_jobs if job.job_file_info.table_name == table_name] + all_jobs: Iterable[Tuple[TJobState, ParsedLoadJobFileName]], table_name: str + ) -> Sequence[Tuple[TJobState, ParsedLoadJobFileName]]: + return [job for job in all_jobs if job[1].table_name == table_name] @configspec diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 1afed18929..0544de696f 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -5,7 +5,7 @@ from dlt.common.json import json from dlt.common.configuration import with_config from dlt.common.configuration.accessors import config -from dlt.common.schema.utils import to_pretty_json, to_pretty_yaml +from dlt.common.schema.utils import get_processing_hints, to_pretty_json, to_pretty_yaml from dlt.common.storages.configuration import ( SchemaStorageConfiguration, TSchemaFileFormat, @@ -57,6 +57,14 @@ def load_schema(self, name: str) -> Schema: return Schema.from_dict(storage_schema) def save_schema(self, schema: Schema) -> str: + """Saves schema to the storage and returns the path relative to storage. + + If import schema path is configured and import schema with schema.name exits, it + will be linked to `schema` via `_imported_version_hash`. Such hash is used in `load_schema` to + detect if import schema changed and thus to overwrite the storage schema. + + If export schema path is configured, `schema` will be exported to it. 
+ """ # check if there's schema to import if self.config.import_schema_path: try: @@ -66,11 +74,25 @@ def save_schema(self, schema: Schema) -> str: except FileNotFoundError: # just save the schema pass - path = self._save_schema(schema) - if self.config.export_schema_path: - self._export_schema(schema, self.config.export_schema_path) + path = self._save_and_export_schema(schema) return path + def save_import_schema_if_not_exists(self, schema: Schema) -> bool: + """Saves import schema, if not exists. If schema was saved, link itself as imported from""" + if self.config.import_schema_path: + try: + self._load_import_schema(schema.name) + except FileNotFoundError: + # save import schema only if it not exist + self._export_schema( + schema, self.config.import_schema_path, remove_processing_hints=True + ) + # if import schema got saved then add own version hash as import version hash + schema._imported_version_hash = schema.version_hash + return True + + return False + def remove_schema(self, name: str) -> None: schema_file = self._file_name_in_store(name, "json") self.storage.delete(schema_file) @@ -116,25 +138,32 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> f" {rv_schema._imported_version_hash}" ) # if schema was imported, overwrite storage schema - self._save_schema(rv_schema) - if self.config.export_schema_path: - self._export_schema(rv_schema, self.config.export_schema_path) + self._save_and_export_schema(rv_schema, check_processing_hints=True) else: # import schema when imported schema was modified from the last import rv_schema = Schema.from_dict(storage_schema) i_s = Schema.from_dict(imported_schema) if i_s.version_hash != rv_schema._imported_version_hash: + logger.warning( + f"Schema {name} was present in schema storage at" + f" {self.storage.storage_path} but will be overwritten with imported schema" + f" version {i_s.version} and imported hash {i_s.version_hash}" + ) + tables_seen_data = 
rv_schema.data_tables(seen_data_only=True) + if tables_seen_data: + logger.warning( + f"Schema {name} in schema storage contains tables" + f" ({', '.join(t['name'] for t in tables_seen_data)}) that are present" + " in the destination. If you changed schema of those tables in import" + " schema, consider using one of the refresh options:" + " https://dlthub.com/devel/general-usage/pipeline#refresh-pipeline-data-and-state" + ) + rv_schema.replace_schema_content(i_s, link_to_replaced_schema=True) rv_schema._imported_version_hash = i_s.version_hash - logger.info( - f"Schema {name} was present in {self.storage.storage_path} but is" - f" overwritten with imported schema version {i_s.version} and" - f" imported hash {i_s.version_hash}" - ) + # if schema was imported, overwrite storage schema - self._save_schema(rv_schema) - if self.config.export_schema_path: - self._export_schema(rv_schema, self.config.export_schema_path) + self._save_and_export_schema(rv_schema, check_processing_hints=True) except FileNotFoundError: # no schema to import -> skip silently and return the original if storage_schema is None: @@ -156,8 +185,13 @@ def _load_import_schema(self, name: str) -> DictStrAny: import_storage.load(schema_file), self.config.external_schema_format ) - def _export_schema(self, schema: Schema, export_path: str) -> None: - stored_schema = schema.to_dict(remove_defaults=True) + def _export_schema( + self, schema: Schema, export_path: str, remove_processing_hints: bool = False + ) -> None: + stored_schema = schema.to_dict( + remove_defaults=self.config.external_schema_format_remove_defaults, + remove_processing_hints=remove_processing_hints, + ) if self.config.external_schema_format == "json": exported_schema_s = to_pretty_json(stored_schema) elif self.config.external_schema_format == "yaml": @@ -175,7 +209,7 @@ def _export_schema(self, schema: Schema, export_path: str) -> None: ) def _save_schema(self, schema: Schema) -> str: - # save a schema to schema store + """Saves 
schema to schema store and bumps the version""" schema_file = self._file_name_in_store(schema.name, "json") stored_schema = schema.to_dict() saved_path = self.storage.save(schema_file, to_pretty_json(stored_schema)) @@ -184,16 +218,45 @@ def _save_schema(self, schema: Schema) -> str: schema._bump_version() return saved_path + def _save_and_export_schema(self, schema: Schema, check_processing_hints: bool = False) -> str: + """Saves schema to schema store and then exports it. If the export path is the same as import + path, processing hints will be removed. + """ + saved_path = self._save_schema(schema) + if self.config.export_schema_path: + self._export_schema( + schema, + self.config.export_schema_path, + self.config.export_schema_path == self.config.import_schema_path, + ) + # if any processing hints are found we should warn the user + if check_processing_hints and (processing_hints := get_processing_hints(schema.tables)): + msg = ( + f"Imported schema {schema.name} contains processing hints for some tables." + " Processing hints are used by normalizer (x-normalizer) to mark tables that got" + " materialized and that prevents destructive changes to the schema. In most cases" + " import schema should not contain processing hints because it is mostly used to" + " initialize tables in a new dataset. " + ) + msg += "Affected tables are: " + ", ".join(processing_hints.keys()) + logger.warning(msg) + return saved_path + @staticmethod def load_schema_file( - path: str, name: str, extensions: Tuple[TSchemaFileFormat, ...] = SchemaFileExtensions + path: str, + name: str, + extensions: Tuple[TSchemaFileFormat, ...] 
= SchemaFileExtensions, + remove_processing_hints: bool = False, ) -> Schema: storage = FileStorage(path) for extension in extensions: file = SchemaStorage._file_name_in_store(name, extension) if storage.has_file(file): parsed_schema = SchemaStorage._parse_schema_str(storage.load(file), extension) - schema = Schema.from_dict(parsed_schema) + schema = Schema.from_dict( + parsed_schema, remove_processing_hints=remove_processing_hints + ) if schema.name != name: raise UnexpectedSchemaName(name, path, schema.name) return schema diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 29c1b01d80..fdd27161f7 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -4,7 +4,7 @@ import os from re import Pattern as _REPattern import sys -from types import FunctionType, MethodType, ModuleType +from types import FunctionType from typing import ( ForwardRef, Callable, @@ -39,6 +39,7 @@ Concatenate, get_args, get_origin, + get_original_bases, ) try: @@ -105,6 +106,8 @@ VARIANT_FIELD_FORMAT = "v_%s" TFileOrPath = Union[str, PathLike, IO[Any]] TSortOrder = Literal["asc", "desc"] +TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv"] +"""known loader file formats""" class ConfigValueSentinel(NamedTuple): @@ -257,6 +260,25 @@ def is_literal_type(hint: Type[Any]) -> bool: return False +def get_literal_args(literal: Type[Any]) -> List[Any]: + """Recursively get arguments from nested Literal types and return an unified list.""" + if not hasattr(literal, "__origin__") or literal.__origin__ is not Literal: + raise ValueError("Provided type is not a Literal") + + unified_args = [] + + def _get_args(literal: Type[Any]) -> None: + for arg in get_args(literal): + if hasattr(arg, "__origin__") and arg.__origin__ is Literal: + _get_args(arg) + else: + unified_args.append(arg) + + _get_args(literal) + + return unified_args + + def is_newtype_type(t: Type[Any]) -> bool: if hasattr(t, "__supertype__"): return True @@ -362,7 +384,7 @@ def 
is_subclass(subclass: Any, cls: Any) -> bool: def get_generic_type_argument_from_instance( - instance: Any, sample_value: Optional[Any] + instance: Any, sample_value: Optional[Any] = None ) -> Type[Any]: """Infers type argument of a Generic class from an `instance` of that class using optional `sample_value` of the argument type @@ -376,8 +398,14 @@ def get_generic_type_argument_from_instance( Type[Any]: type argument or Any if not known """ orig_param_type = Any - if hasattr(instance, "__orig_class__"): - orig_param_type = get_args(instance.__orig_class__)[0] + if cls_ := getattr(instance, "__orig_class__", None): + # instance of generic class + pass + elif bases_ := get_original_bases(instance.__class__): + # instance of class deriving from generic + cls_ = bases_[0] + if cls_: + orig_param_type = get_args(cls_)[0] if orig_param_type is Any and sample_value is not None: orig_param_type = type(sample_value) return orig_param_type # type: ignore diff --git a/dlt/common/utils.py b/dlt/common/utils.py index cb2ec4c3d9..8e89556c39 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -13,6 +13,7 @@ from typing import ( Any, + Callable, ContextManager, Dict, MutableMapping, @@ -141,42 +142,6 @@ def flatten_list_of_str_or_dicts(seq: Sequence[Union[StrAny, str]]) -> DictStrAn return o -# def flatten_dicts_of_dicts(dicts: Mapping[str, Any]) -> Sequence[Any]: -# """ -# Transform and object {K: {...}, L: {...}...} -> [{key:K, ....}, {key: L, ...}, ...] 
-# """ -# o: List[Any] = [] -# for k, v in dicts.items(): -# if isinstance(v, list): -# # if v is a list then add "key" to each list element -# for lv in v: -# lv["key"] = k -# else: -# # add as "key" to dict -# v["key"] = k - -# o.append(v) -# return o - - -# def tuplify_list_of_dicts(dicts: Sequence[DictStrAny]) -> Sequence[DictStrAny]: -# """ -# Transform list of dictionaries with single key into single dictionary of {"key": orig_key, "value": orig_value} -# """ -# for d in dicts: -# if len(d) > 1: -# raise ValueError(f"Tuplify requires one key dicts {d}") -# if len(d) == 1: -# key = next(iter(d)) -# # delete key first to avoid name clashes -# value = d[key] -# del d[key] -# d["key"] = key -# d["value"] = value - -# return dicts - - def flatten_list_or_items(_iter: Union[Iterable[TAny], Iterable[List[TAny]]]) -> Iterator[TAny]: for items in _iter: if isinstance(items, List): @@ -503,11 +468,15 @@ def merge_row_counts(row_counts_1: RowCounts, row_counts_2: RowCounts) -> None: row_counts_1[counter_name] = row_counts_1.get(counter_name, 0) + row_counts_2[counter_name] -def extend_list_deduplicated(original_list: List[Any], extending_list: Iterable[Any]) -> List[Any]: +def extend_list_deduplicated( + original_list: List[Any], + extending_list: Iterable[Any], + normalize_f: Callable[[str], str] = str.__call__, +) -> List[Any]: """extends the first list by the second, but does not add duplicates""" - list_keys = set(original_list) + list_keys = set(normalize_f(s) for s in original_list) for item in extending_list: - if item not in list_keys: + if normalize_f(item) not in list_keys: original_list.append(item) return original_list diff --git a/dlt/common/validation.py b/dlt/common/validation.py index 0a8bced287..8862c10024 100644 --- a/dlt/common/validation.py +++ b/dlt/common/validation.py @@ -7,6 +7,7 @@ from dlt.common.exceptions import DictValidationException from dlt.common.typing import ( StrAny, + get_literal_args, get_type_name, is_callable_type, 
is_literal_type, @@ -114,7 +115,7 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: failed_validations, ) elif is_literal_type(t): - a_l = get_args(t) + a_l = get_literal_args(t) if pv not in a_l: raise DictValidationException( f"field '{pk}' with value {pv} is not one of: {a_l}", path, t, pk, pv diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 302de24a6b..0546d16bcd 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -8,6 +8,7 @@ from dlt.destinations.impl.athena.factory import athena from dlt.destinations.impl.redshift.factory import redshift from dlt.destinations.impl.qdrant.factory import qdrant +from dlt.destinations.impl.lancedb.factory import lancedb from dlt.destinations.impl.motherduck.factory import motherduck from dlt.destinations.impl.weaviate.factory import weaviate from dlt.destinations.impl.destination.factory import destination @@ -28,6 +29,7 @@ "athena", "redshift", "qdrant", + "lancedb", "motherduck", "weaviate", "synapse", diff --git a/dlt/destinations/adapters.py b/dlt/destinations/adapters.py index 1c3e094e19..0cf04b7b59 100644 --- a/dlt/destinations/adapters.py +++ b/dlt/destinations/adapters.py @@ -1,17 +1,20 @@ """This module collects all destination adapters present in `impl` namespace""" -from dlt.destinations.impl.weaviate import weaviate_adapter -from dlt.destinations.impl.qdrant import qdrant_adapter -from dlt.destinations.impl.bigquery import bigquery_adapter -from dlt.destinations.impl.synapse import synapse_adapter -from dlt.destinations.impl.clickhouse import clickhouse_adapter -from dlt.destinations.impl.athena import athena_adapter +from dlt.destinations.impl.weaviate.weaviate_adapter import weaviate_adapter +from dlt.destinations.impl.qdrant.qdrant_adapter import qdrant_adapter +from dlt.destinations.impl.lancedb import lancedb_adapter +from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter +from dlt.destinations.impl.synapse.synapse_adapter 
import synapse_adapter +from dlt.destinations.impl.clickhouse.clickhouse_adapter import clickhouse_adapter +from dlt.destinations.impl.athena.athena_adapter import athena_adapter, athena_partition __all__ = [ "weaviate_adapter", "qdrant_adapter", + "lancedb_adapter", "bigquery_adapter", "synapse_adapter", "clickhouse_adapter", "athena_adapter", + "athena_partition", ] diff --git a/dlt/destinations/fs_client.py b/dlt/destinations/fs_client.py index 5153659614..3233446594 100644 --- a/dlt/destinations/fs_client.py +++ b/dlt/destinations/fs_client.py @@ -1,3 +1,4 @@ +import gzip from typing import Iterable, cast, Any, List from abc import ABC, abstractmethod from fsspec import AbstractFileSystem @@ -38,10 +39,19 @@ def read_bytes(self, path: str, start: Any = None, end: Any = None, **kwargs: An def read_text( self, path: str, - encoding: Any = None, + encoding: Any = "utf-8", errors: Any = None, newline: Any = None, + compression: str = None, **kwargs: Any ) -> str: - """reads given file into string""" - return cast(str, self.fs_client.read_text(path, encoding, errors, newline, **kwargs)) + """reads given file into string, tries gzip and pure text""" + if compression is None: + try: + return self.read_text(path, encoding, errors, newline, "gzip", **kwargs) + except (gzip.BadGzipFile, OSError): + pass + with self.fs_client.open( + path, mode="rt", compression=compression, encoding=encoding, newline=newline + ) as f: + return cast(str, f.read()) diff --git a/dlt/destinations/impl/athena/__init__.py b/dlt/destinations/impl/athena/__init__.py index 87a11f9f41..e69de29bb2 100644 --- a/dlt/destinations/impl/athena/__init__.py +++ b/dlt/destinations/impl/athena/__init__.py @@ -1,33 +0,0 @@ -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_writers.escape import ( - escape_athena_identifier, - format_bigquery_datetime_literal, -) -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() 
-> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - # athena only supports loading from staged files on s3 for now - caps.preferred_loader_file_format = None - caps.supported_loader_file_formats = [] - caps.supported_table_formats = ["iceberg"] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["parquet", "jsonl"] - caps.escape_identifier = escape_athena_identifier - caps.format_datetime_literal = format_bigquery_datetime_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - caps.max_query_length = 16 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 262144 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - caps.supports_transactions = False - caps.alter_add_multi_column = True - caps.schema_supports_numeric_precision = False - caps.timestamp_precision = 3 - caps.supports_truncate_command = False - return caps diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 60ea64a4e7..8d0ffb1d0c 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -34,21 +34,18 @@ from dlt.common import logger from dlt.common.exceptions import TerminalValueError from dlt.common.utils import without_none -from dlt.common.data_types import TDataType -from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema +from dlt.common.schema import TColumnSchema, Schema, TTableSchema from dlt.common.schema.typing import ( TTableSchema, TColumnType, - TWriteDisposition, TTableFormat, TSortOrder, ) -from dlt.common.schema.utils import table_schema_has_type, get_table_format +from dlt.common.schema.utils import table_schema_has_type from dlt.common.destination import 
DestinationCapabilitiesContext from dlt.common.destination.reference import LoadJob, DoNothingFollowupJob, DoNothingJob -from dlt.common.destination.reference import TLoadJobState, NewLoadJob, SupportsStagingDestination -from dlt.common.storages import FileStorage -from dlt.common.data_writers.escape import escape_bigquery_identifier +from dlt.common.destination.reference import NewLoadJob, SupportsStagingDestination +from dlt.common.data_writers.escape import escape_hive_identifier from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob from dlt.destinations.typing import DBApi, DBTransaction @@ -58,7 +55,6 @@ DatabaseUndefinedRelation, LoadJobTerminalException, ) -from dlt.destinations.impl.athena import capabilities from dlt.destinations.sql_client import ( SqlClientBase, DBApiCursorImpl, @@ -221,11 +217,15 @@ def requires_temp_table_for_delete(cls) -> bool: class AthenaSQLClient(SqlClientBase[Connection]): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() dbapi: ClassVar[DBApi] = pyathena - def __init__(self, dataset_name: str, config: AthenaClientConfiguration) -> None: - super().__init__(None, dataset_name) + def __init__( + self, + dataset_name: str, + config: AthenaClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(None, dataset_name, capabilities) self._conn: Connection = None self.config = config self.credentials = config.credentials @@ -254,8 +254,9 @@ def escape_ddl_identifier(self, v: str) -> str: # Athena uses HIVE to create tables but for querying it uses PRESTO (so normal escaping) if not v: return v + v = self.capabilities.casefold_identifier(v) # bigquery uses hive escaping - return escape_bigquery_identifier(v) + return escape_hive_identifier(v) def fully_qualified_ddl_dataset_name(self) -> str: return self.escape_ddl_identifier(self.dataset_name) @@ -271,11 +272,6 @@ def create_dataset(self) -> None: def drop_dataset(self) -> None: self.execute_sql(f"DROP 
DATABASE {self.fully_qualified_ddl_dataset_name()} CASCADE;") - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - def drop_tables(self, *tables: str) -> None: if not tables: return @@ -366,17 +362,14 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB yield DBApiCursorImpl(cursor) # type: ignore - def has_dataset(self) -> bool: - # PRESTO escaping for queries - query = f"""SHOW DATABASES LIKE {self.fully_qualified_dataset_name()};""" - rows = self.execute_sql(query) - return len(rows) > 0 - class AthenaClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None: + def __init__( + self, + schema: Schema, + config: AthenaClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: # verify if staging layout is valid for Athena # this will raise if the table prefix is not properly defined # we actually that {table_name} is first, no {schema_name} is allowed @@ -386,7 +379,7 @@ def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None: table_needs_own_folder=True, ) - sql_client = AthenaSQLClient(config.normalize_dataset_name(schema), config) + sql_client = AthenaSQLClient(config.normalize_dataset_name(schema), config, capabilities) super().__init__(schema, config, sql_client) self.sql_client: AthenaSQLClient = sql_client # type: ignore self.config: AthenaClientConfiguration = config diff --git a/dlt/destinations/impl/athena/factory.py b/dlt/destinations/impl/athena/factory.py index 5b37607cca..d4c29a641f 100644 --- a/dlt/destinations/impl/athena/factory.py +++ b/dlt/destinations/impl/athena/factory.py @@ -1,9 +1,14 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext -from 
dlt.destinations.impl.athena.configuration import AthenaClientConfiguration from dlt.common.configuration.specs import AwsCredentials -from dlt.destinations.impl.athena import capabilities +from dlt.common.data_writers.escape import ( + escape_athena_identifier, + format_bigquery_datetime_literal, +) +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + +from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration if t.TYPE_CHECKING: from dlt.destinations.impl.athena.athena import AthenaClient @@ -12,8 +17,36 @@ class athena(Destination[AthenaClientConfiguration, "AthenaClient"]): spec = AthenaClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + # athena only supports loading from staged files on s3 for now + caps.preferred_loader_file_format = None + caps.supported_loader_file_formats = [] + caps.supported_table_formats = ["iceberg"] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet", "jsonl"] + # athena is storing all identifiers in lower case and is case insensitive + # it also uses lower case in all the queries + # https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html + caps.escape_identifier = escape_athena_identifier + caps.casefold_identifier = str.lower + caps.has_case_sensitive_identifiers = False + caps.format_datetime_literal = format_bigquery_datetime_literal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + caps.max_query_length = 16 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 262144 + caps.is_max_text_data_type_length_in_bytes = True + 
caps.supports_ddl_transactions = False + caps.supports_transactions = False + caps.alter_add_multi_column = True + caps.schema_supports_numeric_precision = False + caps.timestamp_precision = 3 + caps.supports_truncate_command = False + return caps @property def client_class(self) -> t.Type["AthenaClient"]: diff --git a/dlt/destinations/impl/bigquery/__init__.py b/dlt/destinations/impl/bigquery/__init__.py index 39322b43a0..e69de29bb2 100644 --- a/dlt/destinations/impl/bigquery/__init__.py +++ b/dlt/destinations/impl/bigquery/__init__.py @@ -1,31 +0,0 @@ -from dlt.common.data_writers.escape import ( - escape_bigquery_identifier, - format_bigquery_datetime_literal, -) -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl", "parquet"] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["parquet", "jsonl"] - # BQ limit is 4GB but leave a large headroom since buffered writer does not preemptively check size - caps.recommended_file_size = int(1024 * 1024 * 1024) - caps.escape_identifier = escape_bigquery_identifier - caps.escape_literal = None - caps.format_datetime_literal = format_bigquery_datetime_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (76, 38) - caps.max_identifier_length = 1024 - caps.max_column_identifier_length = 300 - caps.max_query_length = 1024 * 1024 - caps.is_max_query_length_in_bytes = False - caps.max_text_data_type_length = 10 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - caps.supports_clone_table = True - - return caps diff --git a/dlt/destinations/impl/bigquery/bigquery.py 
b/dlt/destinations/impl/bigquery/bigquery.py index f26e6f42ee..c3a1be4174 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -1,7 +1,7 @@ import functools import os from pathlib import Path -from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, cast +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, cast import google.cloud.bigquery as bigquery # noqa: I250 from google.api_core import exceptions as api_core_exceptions @@ -35,7 +35,6 @@ LoadJobNotExistsException, LoadJobTerminalException, ) -from dlt.destinations.impl.bigquery import capabilities from dlt.destinations.impl.bigquery.bigquery_adapter import ( PARTITION_HINT, CLUSTER_HINT, @@ -50,6 +49,7 @@ from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_jobs import SqlMergeJob from dlt.destinations.type_mapping import TypeMapper +from dlt.destinations.utils import parse_db_data_type_str_with_precision from dlt.pipeline.current import destination_state @@ -58,10 +58,10 @@ class BigQueryTypeMapper(TypeMapper): "complex": "JSON", "text": "STRING", "double": "FLOAT64", - "bool": "BOOLEAN", + "bool": "BOOL", "date": "DATE", "timestamp": "TIMESTAMP", - "bigint": "INTEGER", + "bigint": "INT64", "binary": "BYTES", "wei": "BIGNUMERIC", # non-parametrized should hold wei values "time": "TIME", @@ -74,11 +74,11 @@ class BigQueryTypeMapper(TypeMapper): dbt_to_sct = { "STRING": "text", - "FLOAT": "double", - "BOOLEAN": "bool", + "FLOAT64": "double", + "BOOL": "bool", "DATE": "date", "TIMESTAMP": "timestamp", - "INTEGER": "bigint", + "INT64": "bigint", "BYTES": "binary", "NUMERIC": "decimal", "BIGNUMERIC": "decimal", @@ -97,9 +97,10 @@ def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> def from_db_type( self, db_type: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - if db_type == "BIGNUMERIC" and precision is None: + # precision is present in the 
type name + if db_type == "BIGNUMERIC": return dict(data_type="wei") - return super().from_db_type(db_type, precision, scale) + return super().from_db_type(*parse_db_data_type_str_with_precision(db_type)) class BigQueryLoadJob(LoadJob, FollowupJob): @@ -173,12 +174,16 @@ def gen_key_table_clauses( class BigQueryClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: + def __init__( + self, + schema: Schema, + config: BigQueryClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: sql_client = BigQuerySqlClient( config.normalize_dataset_name(schema), config.credentials, + capabilities, config.get_location(), config.http_timeout, config.retry_deadline, @@ -266,7 +271,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> reason = BigQuerySqlClient._get_reason_from_errors(gace) if reason == "notFound": # google.api_core.exceptions.NotFound: 404 – table not found - raise UnknownTableException(table["name"]) from gace + raise UnknownTableException(self.schema.name, table["name"]) from gace elif ( reason == "duplicate" ): # google.api_core.exceptions.Conflict: 409 PUT – already exists @@ -292,15 +297,15 @@ def _get_table_update_sql( c for c in new_columns if c.get("partition") or c.get(PARTITION_HINT, False) ]: if len(partition_list) > 1: - col_names = [self.capabilities.escape_identifier(c["name"]) for c in partition_list] + col_names = [self.sql_client.escape_column_name(c["name"]) for c in partition_list] raise DestinationSchemaWillNotUpdate( canonical_name, col_names, "Partition requested for more than one column" ) elif (c := partition_list[0])["data_type"] == "date": - sql[0] += f"\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}" + sql[0] += f"\nPARTITION BY {self.sql_client.escape_column_name(c['name'])}" elif (c := 
partition_list[0])["data_type"] == "timestamp": sql[0] = ( - f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" + f"{sql[0]}\nPARTITION BY DATE({self.sql_client.escape_column_name(c['name'])})" ) # Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp. # This is due to the bounds requirement of GENERATE_ARRAY function for partitioning. @@ -309,12 +314,12 @@ def _get_table_update_sql( # See: https://dlthub.com/devel/dlt-ecosystem/destinations/bigquery#supported-column-hints elif (c := partition_list[0])["data_type"] == "bigint": sql[0] += ( - f"\nPARTITION BY RANGE_BUCKET({self.capabilities.escape_identifier(c['name'])}," + f"\nPARTITION BY RANGE_BUCKET({self.sql_client.escape_column_name(c['name'])}," " GENERATE_ARRAY(-172800000, 691200000, 86400))" ) if cluster_list := [ - self.capabilities.escape_identifier(c["name"]) + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("cluster") or c.get(CLUSTER_HINT, False) ]: @@ -365,8 +370,57 @@ def prepare_load_table( ) return table + def get_storage_tables( + self, table_names: Iterable[str] + ) -> Iterable[Tuple[str, TTableSchemaColumns]]: + """Gets table schemas from BigQuery using INFORMATION_SCHEMA or get_table for hidden datasets""" + if not self.sql_client.is_hidden_dataset: + return super().get_storage_tables(table_names) + + # use the api to get storage tables for hidden dataset + schema_tables: List[Tuple[str, TTableSchemaColumns]] = [] + for table_name in table_names: + try: + schema_table: TTableSchemaColumns = {} + table = self.sql_client.native_connection.get_table( + self.sql_client.make_qualified_table_name(table_name, escape=False), + retry=self.sql_client._default_retry, + timeout=self.config.http_timeout, + ) + for c in table.schema: + schema_c: TColumnSchema = { + "name": c.name, + "nullable": c.is_nullable, + **self._from_db_type(c.field_type, c.precision, c.scale), + } + 
schema_table[c.name] = schema_c + schema_tables.append((table_name, schema_table)) + except gcp_exceptions.NotFound: + # table is not present + schema_tables.append((table_name, {})) + return schema_tables + + def _get_info_schema_columns_query( + self, catalog_name: Optional[str], schema_name: str, folded_table_names: List[str] + ) -> Tuple[str, List[Any]]: + """Bigquery needs to scope the INFORMATION_SCHEMA.COLUMNS with project and dataset name so standard query generator cannot be used.""" + # escape schema and catalog names + catalog_name = self.capabilities.escape_identifier(catalog_name) + schema_name = self.capabilities.escape_identifier(schema_name) + + query = f""" +SELECT {",".join(self._get_storage_table_query_columns())} + FROM {catalog_name}.{schema_name}.INFORMATION_SCHEMA.COLUMNS +WHERE """ + + # placeholder for each table + table_placeholders = ",".join(["%s"] * len(folded_table_names)) + query += f"table_name IN ({table_placeholders}) ORDER BY table_name, ordinal_position;" + + return query, folded_table_names + def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: - name = self.capabilities.escape_identifier(column["name"]) + name = self.sql_client.escape_column_name(column["name"]) column_def_sql = ( f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" ) @@ -376,32 +430,6 @@ def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_AWAY_FROM_ZERO')" return column_def_sql - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - schema_table: TTableSchemaColumns = {} - try: - table = self.sql_client.native_connection.get_table( - self.sql_client.make_qualified_table_name(table_name, escape=False), - retry=self.sql_client._default_retry, - timeout=self.config.http_timeout, - ) - partition_field = table.time_partitioning.field if 
table.time_partitioning else None - for c in table.schema: - schema_c: TColumnSchema = { - "name": c.name, - "nullable": c.is_nullable, - "unique": False, - "sort": False, - "primary_key": False, - "foreign_key": False, - "cluster": c.name in (table.clustering_fields or []), - "partition": c.name == partition_field, - **self._from_db_type(c.field_type, c.precision, c.scale), - } - schema_table[c.name] = schema_c - return True, schema_table - except gcp_exceptions.NotFound: - return False, schema_table - def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.LoadJob: # append to table for merge loads (append to stage) and regular appends. table_name = table["name"] diff --git a/dlt/destinations/impl/bigquery/configuration.py b/dlt/destinations/impl/bigquery/configuration.py index f69e85ca3d..0e2403f7d9 100644 --- a/dlt/destinations/impl/bigquery/configuration.py +++ b/dlt/destinations/impl/bigquery/configuration.py @@ -14,6 +14,7 @@ class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration): destination_type: Final[str] = dataclasses.field(default="bigquery", init=False, repr=False, compare=False) # type: ignore credentials: GcpServiceAccountCredentials = None location: str = "US" + has_case_sensitive_identifiers: bool = True http_timeout: float = 15.0 # connection timeout for http request to BigQuery api file_upload_timeout: float = 30 * 60.0 # a timeout for file upload when loading local files diff --git a/dlt/destinations/impl/bigquery/factory.py b/dlt/destinations/impl/bigquery/factory.py index bee55fa164..db61a6042a 100644 --- a/dlt/destinations/impl/bigquery/factory.py +++ b/dlt/destinations/impl/bigquery/factory.py @@ -1,10 +1,13 @@ import typing as t -from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration +from dlt.common.normalizers.naming import NamingConvention from dlt.common.configuration.specs import GcpServiceAccountCredentials -from dlt.destinations.impl.bigquery import 
capabilities +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.data_writers.escape import escape_hive_identifier, format_bigquery_datetime_literal from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration + if t.TYPE_CHECKING: from dlt.destinations.impl.bigquery.bigquery import BigQueryClient @@ -13,8 +16,34 @@ class bigquery(Destination[BigQueryClientConfiguration, "BigQueryClient"]): spec = BigQueryClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl", "parquet"] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet", "jsonl"] + # BigQuery is by default case sensitive but that cannot be turned off for a dataset + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity + caps.escape_identifier = escape_hive_identifier + caps.escape_literal = None + caps.has_case_sensitive_identifiers = True + caps.casefold_identifier = str + # BQ limit is 4GB but leave a large headroom since buffered writer does not preemptively check size + caps.recommended_file_size = int(1024 * 1024 * 1024) + caps.format_datetime_literal = format_bigquery_datetime_literal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (76, 38) + caps.max_identifier_length = 1024 + caps.max_column_identifier_length = 300 + caps.max_query_length = 1024 * 1024 + caps.is_max_query_length_in_bytes = False + caps.max_text_data_type_length = 10 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = False + caps.supports_clone_table = True + 
caps.schema_supports_numeric_precision = False # no precision information in BigQuery + + return caps @property def client_class(self) -> t.Type["BigQueryClient"]: @@ -26,14 +55,38 @@ def __init__( self, credentials: t.Optional[GcpServiceAccountCredentials] = None, location: t.Optional[str] = None, + has_case_sensitive_identifiers: bool = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, ) -> None: + """Configure the BigQuery destination to use in a pipeline. + + All arguments provided here supersede other configuration sources such as environment variables and dlt config files. + + Args: + credentials: Credentials to connect to BigQuery. Can be an instance of `GcpServiceAccountCredentials` or + a dict or string with service accounts credentials as used in the Google Cloud + location: A location where the datasets will be created, eg. "EU". The default is "US" + has_case_sensitive_identifiers: Is the dataset case-sensitive, defaults to True + **kwargs: Additional arguments passed to the destination config + """ super().__init__( credentials=credentials, location=location, + has_case_sensitive_identifiers=has_case_sensitive_identifiers, destination_name=destination_name, environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: BigQueryClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + # modify the caps if case sensitive identifiers are requested + caps.has_case_sensitive_identifiers = config.has_case_sensitive_identifiers + return super().adjust_capabilities(caps, config, naming) diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index 21086a4db6..45e9379af5 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from 
typing import Any, AnyStr, ClassVar, Iterator, List, Optional, Sequence +from typing import Any, AnyStr, ClassVar, Iterator, List, Optional, Sequence, Generator import google.cloud.bigquery as bigquery # noqa: I250 from google.api_core import exceptions as api_core_exceptions @@ -8,6 +8,7 @@ from google.cloud.bigquery.dbapi import Connection as DbApiConnection, Cursor as BQDbApiCursor from google.cloud.bigquery.dbapi import exceptions as dbapi_exceptions +from dlt.common import logger from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.typing import StrAny @@ -16,7 +17,6 @@ DatabaseTransientException, DatabaseUndefinedRelation, ) -from dlt.destinations.impl.bigquery import capabilities from dlt.destinations.sql_client import ( DBApiCursorImpl, SqlClientBase, @@ -44,29 +44,42 @@ class BigQueryDBApiCursorImpl(DBApiCursorImpl): """Use native BigQuery data frame support if available""" native_cursor: BQDbApiCursor # type: ignore + df_iterator: Generator[Any, None, None] - def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: - if chunk_size is not None: - return super().df(chunk_size=chunk_size) + def __init__(self, curr: DBApiCursor) -> None: + super().__init__(curr) + self.df_iterator = None + + def df(self, chunk_size: Optional[int] = None, **kwargs: Any) -> DataFrame: query_job: bigquery.QueryJob = getattr( self.native_cursor, "_query_job", self.native_cursor.query_job ) - + if self.df_iterator: + return next(self.df_iterator, None) try: + if chunk_size is not None: + # create iterator with given page size + self.df_iterator = query_job.result(page_size=chunk_size).to_dataframe_iterable() + return next(self.df_iterator, None) return query_job.to_dataframe(**kwargs) - except ValueError: + except ValueError as ex: # no pyarrow/db-types, fallback to our implementation - return super().df() + logger.warning(f"Native BigQuery pandas reader 
could not be used: {str(ex)}") + return super().df(chunk_size=chunk_size) + + def close(self) -> None: + if self.df_iterator: + self.df_iterator.close() class BigQuerySqlClient(SqlClientBase[bigquery.Client], DBTransaction): dbapi: ClassVar[DBApi] = bq_dbapi - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__( self, dataset_name: str, credentials: GcpServiceAccountCredentialsWithoutDefaults, + capabilities: DestinationCapabilitiesContext, location: str = "US", http_timeout: float = 15.0, retry_deadline: float = 60.0, @@ -75,7 +88,7 @@ def __init__( self.credentials: GcpServiceAccountCredentialsWithoutDefaults = credentials self.location = location self.http_timeout = http_timeout - super().__init__(credentials.project_id, dataset_name) + super().__init__(credentials.project_id, dataset_name, capabilities) self._default_retry = bigquery.DEFAULT_RETRY.with_deadline(retry_deadline) self._default_query = bigquery.QueryJobConfig( @@ -177,8 +190,11 @@ def has_dataset(self) -> bool: return False def create_dataset(self) -> None: + dataset = bigquery.Dataset(self.fully_qualified_dataset_name(escape=False)) + dataset.location = self.location + dataset.is_case_insensitive = not self.capabilities.has_case_sensitive_identifiers self._client.create_dataset( - self.fully_qualified_dataset_name(escape=False), + dataset, retry=self._default_retry, timeout=self.http_timeout, ) @@ -221,14 +237,19 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB # will close all cursors conn.close() - def fully_qualified_dataset_name(self, escape: bool = True) -> str: + def catalog_name(self, escape: bool = True) -> Optional[str]: + project_id = self.capabilities.casefold_identifier(self.credentials.project_id) if escape: - project_id = self.capabilities.escape_identifier(self.credentials.project_id) - dataset_name = self.capabilities.escape_identifier(self.dataset_name) - else: - project_id = self.credentials.project_id - 
dataset_name = self.dataset_name - return f"{project_id}.{dataset_name}" + project_id = self.capabilities.escape_identifier(project_id) + return project_id + + @property + def is_hidden_dataset(self) -> bool: + """Tells if the dataset associated with sql_client is a hidden dataset. + + Hidden datasets are not present in information schema. + """ + return self.dataset_name.startswith("_") @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: diff --git a/dlt/destinations/impl/clickhouse/__init__.py b/dlt/destinations/impl/clickhouse/__init__.py index bead136828..e69de29bb2 100644 --- a/dlt/destinations/impl/clickhouse/__init__.py +++ b/dlt/destinations/impl/clickhouse/__init__.py @@ -1,53 +0,0 @@ -import sys - -from dlt.common.pendulum import pendulum -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.data_writers.escape import ( - escape_clickhouse_identifier, - escape_clickhouse_literal, - format_clickhouse_datetime_literal, -) -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.impl.clickhouse.clickhouse_adapter import clickhouse_adapter - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["parquet", "jsonl"] - caps.preferred_staging_file_format = "jsonl" - caps.supported_staging_file_formats = ["parquet", "jsonl"] - - caps.format_datetime_literal = format_clickhouse_datetime_literal - caps.escape_identifier = escape_clickhouse_identifier - caps.escape_literal = escape_clickhouse_literal - - # https://stackoverflow.com/questions/68358686/what-is-the-maximum-length-of-a-column-in-clickhouse-can-it-be-modified - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - - # ClickHouse has no max `String` type length. 
- caps.max_text_data_type_length = sys.maxsize - - caps.schema_supports_numeric_precision = True - # Use 'Decimal128' with these defaults. - # https://clickhouse.com/docs/en/sql-reference/data-types/decimal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - # Use 'Decimal256' with these defaults. - caps.wei_precision = (76, 0) - caps.timestamp_precision = 6 - - # https://clickhouse.com/docs/en/operations/settings/settings#max_query_size - caps.is_max_query_length_in_bytes = True - caps.max_query_length = 262144 - - # ClickHouse has limited support for transactional semantics, especially for `ReplicatedMergeTree`, - # the default ClickHouse Cloud engine. It does, however, provide atomicity for individual DDL operations like `ALTER TABLE`. - # https://clickhouse-driver.readthedocs.io/en/latest/dbapi.html#clickhouse_driver.dbapi.connection.Connection.commit - # https://clickhouse.com/docs/en/guides/developer/transactional#transactions-commit-and-rollback - caps.supports_transactions = False - caps.supports_ddl_transactions = False - - caps.supports_truncate_command = True - - return caps diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index cf1f1bc857..6dd8fd47ed 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -36,7 +36,6 @@ ) from dlt.common.storages import FileStorage from dlt.destinations.exceptions import LoadJobTerminalException -from dlt.destinations.impl.clickhouse import capabilities from dlt.destinations.impl.clickhouse.clickhouse_adapter import ( TTableEngineType, TABLE_ENGINE_TYPE_HINT, @@ -289,15 +288,14 @@ def requires_temp_table_for_delete(cls) -> bool: class ClickHouseClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__( self, schema: Schema, config: ClickHouseClientConfiguration, + capabilities: 
DestinationCapabilitiesContext, ) -> None: self.sql_client: ClickHouseSqlClient = ClickHouseSqlClient( - config.normalize_dataset_name(schema), config.credentials + config.normalize_dataset_name(schema), config.credentials, capabilities ) super().__init__(schema, config, self.sql_client) self.config: ClickHouseClientConfiguration = config @@ -327,7 +325,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) return ( - f"{self.capabilities.escape_identifier(c['name'])} {type_with_nullability_modifier} {hints_str}" + f"{self.sql_client.escape_column_name(c['name'])} {type_with_nullability_modifier} {hints_str}" .strip() ) @@ -357,7 +355,7 @@ def _get_table_update_sql( sql[0] = f"{sql[0]}\nENGINE = {TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR.get(table_type)}" if primary_key_list := [ - self.capabilities.escape_identifier(c["name"]) + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("primary_key") ]: @@ -367,34 +365,6 @@ def _get_table_update_sql( return sql - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - fields = self._get_storage_table_query_columns() - db_params = self.sql_client.make_qualified_table_name(table_name, escape=False).split( - ".", 3 - ) - query = f'SELECT {",".join(fields)} FROM INFORMATION_SCHEMA.COLUMNS WHERE ' - if len(db_params) == 3: - query += "table_catalog = %s AND " - query += "table_schema = %s AND table_name = %s ORDER BY ordinal_position;" - rows = self.sql_client.execute_sql(query, *db_params) - - # If no rows we assume that table does not exist. 
- schema_table: TTableSchemaColumns = {} - if len(rows) == 0: - return False, schema_table - for c in rows: - numeric_precision = ( - c[3] if self.capabilities.schema_supports_numeric_precision else None - ) - numeric_scale = c[4] if self.capabilities.schema_supports_numeric_precision else None - schema_c: TColumnSchemaBase = { - "name": c[0], - "nullable": bool(c[2]), - **self._from_db_type(c[1], numeric_precision, numeric_scale), - } - schema_table[c[0]] = schema_c # type: ignore - return True, schema_table - @staticmethod def _gen_not_null(v: bool) -> str: # ClickHouse fields are not nullable by default. diff --git a/dlt/destinations/impl/clickhouse/factory.py b/dlt/destinations/impl/clickhouse/factory.py index e5b8fc0e6a..52a1694dee 100644 --- a/dlt/destinations/impl/clickhouse/factory.py +++ b/dlt/destinations/impl/clickhouse/factory.py @@ -1,7 +1,14 @@ +import sys import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext -from dlt.destinations.impl.clickhouse import capabilities +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.data_writers.escape import ( + escape_clickhouse_identifier, + escape_clickhouse_literal, + format_clickhouse_datetime_literal, +) + from dlt.destinations.impl.clickhouse.configuration import ( ClickHouseClientConfiguration, ClickHouseCredentials, @@ -16,8 +23,51 @@ class clickhouse(Destination[ClickHouseClientConfiguration, "ClickHouseClient"]): spec = ClickHouseClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["parquet", "jsonl"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["parquet", "jsonl"] + + caps.format_datetime_literal = 
format_clickhouse_datetime_literal + caps.escape_identifier = escape_clickhouse_identifier + caps.escape_literal = escape_clickhouse_literal + # docs are very unclear https://clickhouse.com/docs/en/sql-reference/syntax + # taking into account other sources: identifiers are case sensitive + caps.has_case_sensitive_identifiers = True + # and store as is in the information schema + caps.casefold_identifier = str + + # https://stackoverflow.com/questions/68358686/what-is-the-maximum-length-of-a-column-in-clickhouse-can-it-be-modified + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + + # ClickHouse has no max `String` type length. + caps.max_text_data_type_length = sys.maxsize + + caps.schema_supports_numeric_precision = True + # Use 'Decimal128' with these defaults. + # https://clickhouse.com/docs/en/sql-reference/data-types/decimal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + # Use 'Decimal256' with these defaults. + caps.wei_precision = (76, 0) + caps.timestamp_precision = 6 + + # https://clickhouse.com/docs/en/operations/settings/settings#max_query_size + caps.is_max_query_length_in_bytes = True + caps.max_query_length = 262144 + + # ClickHouse has limited support for transactional semantics, especially for `ReplicatedMergeTree`, + # the default ClickHouse Cloud engine. It does, however, provide atomicity for individual DDL operations like `ALTER TABLE`. 
+ # https://clickhouse-driver.readthedocs.io/en/latest/dbapi.html#clickhouse_driver.dbapi.connection.Connection.commit + # https://clickhouse.com/docs/en/guides/developer/transactional#transactions-commit-and-rollback + caps.supports_transactions = False + caps.supports_ddl_transactions = False + + caps.supports_truncate_command = True + + return caps @property def client_class(self) -> t.Type["ClickHouseClient"]: diff --git a/dlt/destinations/impl/clickhouse/sql_client.py b/dlt/destinations/impl/clickhouse/sql_client.py index 8fb89c90cd..ee013ea123 100644 --- a/dlt/destinations/impl/clickhouse/sql_client.py +++ b/dlt/destinations/impl/clickhouse/sql_client.py @@ -7,6 +7,7 @@ Optional, Sequence, ClassVar, + Tuple, ) import clickhouse_driver # type: ignore[import-untyped] @@ -20,7 +21,6 @@ DatabaseTransientException, DatabaseTerminalException, ) -from dlt.destinations.impl.clickhouse import capabilities from dlt.destinations.impl.clickhouse.configuration import ClickHouseCredentials from dlt.destinations.sql_client import ( DBApiCursorImpl, @@ -45,15 +45,20 @@ class ClickHouseSqlClient( SqlClientBase[clickhouse_driver.dbapi.connection.Connection], DBTransaction ): dbapi: ClassVar[DBApi] = clickhouse_driver.dbapi - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: ClickHouseCredentials) -> None: - super().__init__(credentials.database, dataset_name) + def __init__( + self, + dataset_name: str, + credentials: ClickHouseCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.database, dataset_name, capabilities) self._conn: clickhouse_driver.dbapi.connection = None self.credentials = credentials self.database_name = credentials.database def has_dataset(self) -> bool: + # we do not need to normalize dataset_sentinel_table_name sentinel_table = self.credentials.dataset_sentinel_table_name return sentinel_table in [ 
t.split(self.credentials.dataset_table_separator)[1] for t in self._list_tables() @@ -110,10 +115,11 @@ def drop_dataset(self) -> None: # This is because the driver incorrectly substitutes the entire query string, causing the "DROP TABLE" keyword to be omitted. # To resolve this, we are forced to provide the full query string here. self.execute_sql( - f"""DROP TABLE {self.capabilities.escape_identifier(self.database_name)}.{self.capabilities.escape_identifier(table)} SYNC""" + f"""DROP TABLE {self.catalog_name()}.{self.capabilities.escape_identifier(table)} SYNC""" ) def _list_tables(self) -> List[str]: + catalog_name, table_name = self.make_qualified_table_name_path("%", escape=False) rows = self.execute_sql( """ SELECT name @@ -121,10 +127,8 @@ def _list_tables(self) -> List[str]: WHERE database = %s AND name LIKE %s """, - ( - self.database_name, - f"{self.dataset_name}{self.credentials.dataset_table_separator}%", - ), + catalog_name, + table_name, ) return [row[0] for row in rows] @@ -151,21 +155,33 @@ def execute_query( yield ClickHouseDBApiCursorImpl(cursor) # type: ignore[abstract] - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - database_name = self.database_name - dataset_name = self.dataset_name - if escape: - database_name = self.capabilities.escape_identifier(database_name) - dataset_name = self.capabilities.escape_identifier(dataset_name) - return f"{database_name}.{dataset_name}" - - def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str: - database_name = self.database_name - table_name = f"{self.dataset_name}{self.credentials.dataset_table_separator}{table_name}" + def catalog_name(self, escape: bool = True) -> Optional[str]: + database_name = self.capabilities.casefold_identifier(self.database_name) if escape: database_name = self.capabilities.escape_identifier(database_name) - table_name = self.capabilities.escape_identifier(table_name) - return f"{database_name}.{table_name}" + return database_name 
+ + def make_qualified_table_name_path( + self, table_name: Optional[str], escape: bool = True + ) -> List[str]: + # get catalog and dataset + path = super().make_qualified_table_name_path(None, escape=escape) + if table_name: + # table name combines dataset name and table name + table_name = self.capabilities.casefold_identifier( + f"{self.dataset_name}{self.credentials.dataset_table_separator}{table_name}" + ) + if escape: + table_name = self.capabilities.escape_identifier(table_name) + # we have only two path components + path[1] = table_name + return path + + def _get_information_schema_components(self, *tables: str) -> Tuple[str, str, List[str]]: + components = super()._get_information_schema_components(*tables) + # clickhouse has a catalogue and no schema but uses catalogue as a schema to query the information schema 🤷 + # so we must disable catalogue search. also note that table name is prefixed with logical "dataset_name" + return (None, components[0], components[2]) @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py index 81884fae4b..e69de29bb2 100644 --- a/dlt/destinations/impl/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -1,30 +0,0 @@ -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_writers.escape import escape_databricks_identifier, escape_databricks_literal -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - -from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = None - caps.supported_loader_file_formats = [] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["jsonl", "parquet"] - caps.escape_identifier = 
escape_databricks_identifier - caps.escape_literal = escape_databricks_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - caps.max_query_length = 2 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 16 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - caps.supports_truncate_command = True - # caps.supports_transactions = False - caps.alter_add_multi_column = True - caps.supports_multiple_statements = False - caps.supports_clone_table = True - return caps diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index cd203e7e4d..62debdedb7 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -1,6 +1,7 @@ from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any, Iterable, Type, cast from urllib.parse import urlparse, urlunparse +from dlt import config from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( FollowupJob, @@ -15,27 +16,22 @@ AzureCredentials, AzureCredentialsWithoutDefaults, ) -from dlt.common.data_types import TDataType from dlt.common.exceptions import TerminalValueError from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat from dlt.common.schema.utils import table_schema_has_type +from dlt.common.storages import FilesystemConfiguration, fsspec_from_config from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.exceptions import LoadJobTerminalException - 
-from dlt.destinations.impl.databricks import capabilities from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient -from dlt.destinations.sql_jobs import SqlMergeJob, SqlJobParams +from dlt.destinations.sql_jobs import SqlMergeJob from dlt.destinations.job_impl import NewReferenceJob -from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper -from dlt.common.storages import FilesystemConfiguration, fsspec_from_config -from dlt import config class DatabricksTypeMapper(TypeMapper): @@ -258,10 +254,15 @@ def gen_delete_from_sql( class DatabricksClient(InsertValuesJobClient, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: DatabricksClientConfiguration) -> None: - sql_client = DatabricksSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: DatabricksClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = DatabricksSqlClient( + config.normalize_dataset_name(schema), config.credentials, capabilities + ) super().__init__(schema, config, sql_client) self.config: DatabricksClientConfiguration = config self.sql_client: DatabricksSqlClient = sql_client # type: ignore[assignment] @@ -303,7 +304,7 @@ def _get_table_update_sql( sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) cluster_list = [ - self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster") + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("cluster") ] if cluster_list: @@ -317,14 +318,14 @@ def _from_db_type( return self.type_mapper.from_db_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: - name = 
self.capabilities.escape_identifier(c["name"]) + name = self.sql_client.escape_column_name(c["name"]) return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) def _get_storage_table_query_columns(self) -> List[str]: fields = super()._get_storage_table_query_columns() - fields[1] = ( # Override because this is the only way to get data type with precision + fields[2] = ( # Override because this is the only way to get data type with precision "full_data_type" ) return fields diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index 7c6c95137d..56462714c1 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -1,12 +1,13 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_databricks_identifier, escape_databricks_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.impl.databricks.configuration import ( DatabricksCredentials, DatabricksClientConfiguration, ) -from dlt.destinations.impl.databricks import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.databricks.databricks import DatabricksClient @@ -15,8 +16,33 @@ class databricks(Destination[DatabricksClientConfiguration, "DatabricksClient"]): spec = DatabricksClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = None + caps.supported_loader_file_formats = [] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["jsonl", "parquet"] + caps.escape_identifier = escape_databricks_identifier + # databricks identifiers are case insensitive and stored in lower case + # 
https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html + caps.escape_literal = escape_databricks_literal + caps.casefold_identifier = str.lower + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + caps.max_query_length = 2 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 16 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = False + caps.supports_truncate_command = True + # caps.supports_transactions = False + caps.alter_add_multi_column = True + caps.supports_multiple_statements = False + caps.supports_clone_table = True + return caps @property def client_class(self) -> t.Type["DatabricksClient"]: diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 530b03715a..da91402803 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Union, Dict +from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Tuple, Union, Dict from databricks import sql as databricks_lib @@ -21,18 +21,37 @@ raise_database_error, raise_open_connection_error, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction +from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame from dlt.destinations.impl.databricks.configuration import DatabricksCredentials -from dlt.destinations.impl.databricks import capabilities -from dlt.common.time import to_py_date, to_py_datetime + + +class DatabricksCursorImpl(DBApiCursorImpl): + """Use native data frame support if available""" + + 
native_cursor: DatabricksSqlCursor # type: ignore[assignment] + vector_size: ClassVar[int] = 2048 + + def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: + if chunk_size is None: + return self.native_cursor.fetchall_arrow().to_pandas() + else: + df = self.native_cursor.fetchmany_arrow(chunk_size).to_pandas() + if df.shape[0] == 0: + return None + else: + return df class DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction): dbapi: ClassVar[DBApi] = databricks_lib - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: DatabricksCredentials) -> None: - super().__init__(credentials.catalog, dataset_name) + def __init__( + self, + dataset_name: str, + credentials: DatabricksCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.catalog, dataset_name, capabilities) self._conn: DatabricksSqlConnection = None self.credentials = credentials @@ -112,16 +131,13 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB db_args = args or kwargs or None with self._conn.cursor() as curr: # type: ignore[assignment] curr.execute(query, db_args) - yield DBApiCursorImpl(curr) # type: ignore[abstract] + yield DatabricksCursorImpl(curr) # type: ignore[abstract] - def fully_qualified_dataset_name(self, escape: bool = True) -> str: + def catalog_name(self, escape: bool = True) -> Optional[str]: + catalog = self.capabilities.casefold_identifier(self.credentials.catalog) if escape: - catalog = self.capabilities.escape_identifier(self.credentials.catalog) - dataset_name = self.capabilities.escape_identifier(self.dataset_name) - else: - catalog = self.credentials.catalog - dataset_name = self.dataset_name - return f"{catalog}.{dataset_name}" + catalog = self.capabilities.escape_identifier(catalog) + return catalog @staticmethod def _make_database_exception(ex: Exception) -> Exception: diff --git 
a/dlt/destinations/impl/destination/__init__.py b/dlt/destinations/impl/destination/__init__.py index 5b076df4c6..e69de29bb2 100644 --- a/dlt/destinations/impl/destination/__init__.py +++ b/dlt/destinations/impl/destination/__init__.py @@ -1,21 +0,0 @@ -from typing import Optional -from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat -from dlt.common.destination.capabilities import TLoaderParallelismStrategy - - -def capabilities( - preferred_loader_file_format: TLoaderFileFormat = "typed-jsonl", - naming_convention: str = "direct", - max_table_nesting: Optional[int] = 0, - max_parallel_load_jobs: Optional[int] = 0, - loader_parallelism_strategy: Optional[TLoaderParallelismStrategy] = None, -) -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) - caps.supported_loader_file_formats = ["typed-jsonl", "parquet"] - caps.supports_ddl_transactions = False - caps.supports_transactions = False - caps.naming_convention = naming_convention - caps.max_table_nesting = max_table_nesting - caps.max_parallel_load_jobs = max_parallel_load_jobs - caps.loader_parallelism_strategy = loader_parallelism_strategy - return caps diff --git a/dlt/destinations/impl/destination/configuration.py b/dlt/destinations/impl/destination/configuration.py index c3b677058c..705f3b0bb5 100644 --- a/dlt/destinations/impl/destination/configuration.py +++ b/dlt/destinations/impl/destination/configuration.py @@ -1,20 +1,23 @@ import dataclasses -from typing import Optional, Final, Callable, Union +from typing import Optional, Final, Callable, Union, Any from typing_extensions import ParamSpec -from dlt.common.configuration import configspec +from dlt.common.configuration import configspec, ConfigurationValueError from dlt.common.destination import TLoaderFileFormat from dlt.common.destination.reference import ( DestinationClientConfiguration, ) from dlt.common.typing import TDataItems from 
dlt.common.schema import TTableSchema -from dlt.common.destination import Destination TDestinationCallable = Callable[[Union[TDataItems, str], TTableSchema], None] TDestinationCallableParams = ParamSpec("TDestinationCallableParams") +def dummy_custom_destination(*args: Any, **kwargs: Any) -> None: + pass + + @configspec class CustomDestinationClientConfiguration(DestinationClientConfiguration): destination_type: Final[str] = dataclasses.field(default="destination", init=False, repr=False, compare=False) # type: ignore @@ -23,3 +26,15 @@ class CustomDestinationClientConfiguration(DestinationClientConfiguration): batch_size: int = 10 skip_dlt_columns_and_tables: bool = True max_table_nesting: Optional[int] = 0 + + def ensure_callable(self) -> None: + """Makes sure that valid callable was provided""" + # TODO: this surely can be done with `on_resolved` + if ( + self.destination_callable is None + or self.destination_callable is dummy_custom_destination + ): + raise ConfigurationValueError( + f"A valid callable was not provided to {self.__class__.__name__}. Did you decorate" + " a function @dlt.destination correctly?" 
+ ) diff --git a/dlt/destinations/impl/destination/destination.py b/dlt/destinations/impl/destination/destination.py index 69d1d1d98a..c44fd3cca1 100644 --- a/dlt/destinations/impl/destination/destination.py +++ b/dlt/destinations/impl/destination/destination.py @@ -15,8 +15,6 @@ DoNothingJob, JobClientBase, ) - -from dlt.destinations.impl.destination import capabilities from dlt.destinations.impl.destination.configuration import CustomDestinationClientConfiguration from dlt.destinations.job_impl import ( DestinationJsonlLoadJob, @@ -27,10 +25,14 @@ class DestinationClient(JobClientBase): """Sink Client""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: CustomDestinationClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: CustomDestinationClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + config.ensure_callable() + super().__init__(schema, config, capabilities) self.config: CustomDestinationClientConfiguration = config # create pre-resolved callable to avoid multiple config resolutions during execution of the jobs self.destination_callable = create_resolved_partial( diff --git a/dlt/destinations/impl/destination/factory.py b/dlt/destinations/impl/destination/factory.py index b3127ab99b..69bb0daa13 100644 --- a/dlt/destinations/impl/destination/factory.py +++ b/dlt/destinations/impl/destination/factory.py @@ -4,18 +4,20 @@ from types import ModuleType from dlt.common import logger +from dlt.common.destination.capabilities import TLoaderParallelismStrategy +from dlt.common.exceptions import TerminalValueError +from dlt.common.normalizers.naming.naming import NamingConvention from dlt.common.typing import AnyFun from dlt.common.destination import Destination, DestinationCapabilitiesContext, TLoaderFileFormat from dlt.common.configuration import known_sections, with_config, get_fun_spec from 
dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.utils import get_callable_name, is_inner_callable -from dlt.destinations.exceptions import DestinationTransientException from dlt.destinations.impl.destination.configuration import ( CustomDestinationClientConfiguration, + dummy_custom_destination, TDestinationCallable, ) -from dlt.destinations.impl.destination import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.destination.destination import DestinationClient @@ -34,16 +36,16 @@ class DestinationInfo(t.NamedTuple): class destination(Destination[CustomDestinationClientConfiguration, "DestinationClient"]): - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities( - preferred_loader_file_format=self.config_params.get( - "loader_file_format", "typed-jsonl" - ), - naming_convention=self.config_params.get("naming_convention", "direct"), - max_table_nesting=self.config_params.get("max_table_nesting", None), - max_parallel_load_jobs=self.config_params.get("max_parallel_load_jobs", None), - loader_parallelism_strategy=self.config_params.get("loader_parallelism_strategy", None), - ) + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext.generic_capabilities("typed-jsonl") + caps.supported_loader_file_formats = ["typed-jsonl", "parquet"] + caps.supports_ddl_transactions = False + caps.supports_transactions = False + caps.naming_convention = "direct" + caps.max_table_nesting = 0 + caps.max_parallel_load_jobs = 0 + caps.loader_parallelism_strategy = None + return caps @property def spec(self) -> t.Type[CustomDestinationClientConfiguration]: @@ -68,7 +70,7 @@ def __init__( **kwargs: t.Any, ) -> None: if spec and not issubclass(spec, CustomDestinationClientConfiguration): - raise ValueError( + raise TerminalValueError( "A SPEC for a sink destination must use CustomDestinationClientConfiguration as a" " base." 
) @@ -97,14 +99,7 @@ def __init__( "No destination callable provided, providing dummy callable which will fail on" " load." ) - - def dummy_callable(*args: t.Any, **kwargs: t.Any) -> None: - raise DestinationTransientException( - "You tried to load to a custom destination without a valid callable." - ) - - destination_callable = dummy_callable - + destination_callable = dummy_custom_destination elif not callable(destination_callable): raise ConfigurationValueError("Resolved Sink destination callable is not a callable.") @@ -138,9 +133,21 @@ def dummy_callable(*args: t.Any, **kwargs: t.Any) -> None: super().__init__( destination_name=destination_name, environment=environment, + # NOTE: `loader_file_format` is not a field in the caps so we had to hack the base class to allow this loader_file_format=loader_file_format, batch_size=batch_size, naming_convention=naming_convention, destination_callable=conf_callable, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: CustomDestinationClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + caps = super().adjust_capabilities(caps, config, naming) + caps.preferred_loader_file_format = config.loader_file_format + return caps diff --git a/dlt/destinations/impl/dremio/__init__.py b/dlt/destinations/impl/dremio/__init__.py index b4bde2fe6d..e69de29bb2 100644 --- a/dlt/destinations/impl/dremio/__init__.py +++ b/dlt/destinations/impl/dremio/__init__.py @@ -1,27 +0,0 @@ -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.data_writers.escape import escape_dremio_identifier -from dlt.common.destination import DestinationCapabilitiesContext - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = None - caps.supported_loader_file_formats = [] - caps.preferred_staging_file_format = "parquet" - 
caps.supported_staging_file_formats = ["jsonl", "parquet"] - caps.escape_identifier = escape_dremio_identifier - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - caps.max_query_length = 2 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 16 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_transactions = False - caps.supports_ddl_transactions = False - caps.alter_add_multi_column = True - caps.supports_clone_table = False - caps.supports_multiple_statements = False - caps.timestamp_precision = 3 - return caps diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index 23bca0ad74..00e51b74a6 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -14,7 +14,6 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import uniq_id from dlt.destinations.exceptions import LoadJobTerminalException -from dlt.destinations.impl.dremio import capabilities from dlt.destinations.impl.dremio.configuration import DremioClientConfiguration from dlt.destinations.impl.dremio.sql_client import DremioSqlClient from dlt.destinations.job_client_impl import SqlJobClientWithStaging @@ -137,10 +136,15 @@ def exception(self) -> str: class DremioClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: DremioClientConfiguration) -> None: - sql_client = DremioSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: DremioClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = DremioSqlClient( + 
config.normalize_dataset_name(schema), config.credentials, capabilities + ) super().__init__(schema, config, sql_client) self.config: DremioClientConfiguration = config self.sql_client: DremioSqlClient = sql_client # type: ignore @@ -172,7 +176,7 @@ def _get_table_update_sql( if not generate_alter: partition_list = [ - self.capabilities.escape_identifier(c["name"]) + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("partition") ] @@ -180,7 +184,7 @@ def _get_table_update_sql( sql[0] += "\nPARTITION BY (" + ",".join(partition_list) + ")" sort_list = [ - self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("sort") + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("sort") ] if sort_list: sql[0] += "\nLOCALSORT BY (" + ",".join(sort_list) + ")" @@ -193,45 +197,11 @@ def _from_db_type( return self.type_mapper.from_db_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: - name = self.capabilities.escape_identifier(c["name"]) + name = self.sql_client.escape_column_name(c["name"]) return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - def _null_to_bool(v: str) -> bool: - if v == "NO": - return False - elif v == "YES": - return True - raise ValueError(v) - - fields = self._get_storage_table_query_columns() - table_schema = self.sql_client.fully_qualified_dataset_name(escape=False) - db_params = (table_schema, table_name) - query = f""" -SELECT {",".join(fields)} - FROM INFORMATION_SCHEMA.COLUMNS -WHERE - table_catalog = 'DREMIO' AND table_schema = %s AND table_name = %s ORDER BY ordinal_position; -""" - rows = self.sql_client.execute_sql(query, *db_params) - - # if no rows we assume that table does not exist - schema_table: TTableSchemaColumns = {} - if len(rows) == 0: - return False, schema_table - 
for c in rows: - numeric_precision = c[3] - numeric_scale = c[4] - schema_c: TColumnSchemaBase = { - "name": c[0], - "nullable": _null_to_bool(c[2]), - **self._from_db_type(c[1], numeric_precision, numeric_scale), - } - schema_table[c[0]] = schema_c # type: ignore - return True, schema_table - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: return [DremioMergeJob.from_table_chain(table_chain, self.sql_client)] diff --git a/dlt/destinations/impl/dremio/factory.py b/dlt/destinations/impl/dremio/factory.py index 61895e4f90..29a4937c69 100644 --- a/dlt/destinations/impl/dremio/factory.py +++ b/dlt/destinations/impl/dremio/factory.py @@ -1,11 +1,13 @@ import typing as t +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.data_writers.escape import escape_dremio_identifier + from dlt.destinations.impl.dremio.configuration import ( DremioCredentials, DremioClientConfiguration, ) -from dlt.destinations.impl.dremio import capabilities -from dlt.common.destination import Destination, DestinationCapabilitiesContext if t.TYPE_CHECKING: from dlt.destinations.impl.dremio.dremio import DremioClient @@ -14,8 +16,31 @@ class dremio(Destination[DremioClientConfiguration, "DremioClient"]): spec = DremioClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = None + caps.supported_loader_file_formats = [] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["jsonl", "parquet"] + caps.escape_identifier = escape_dremio_identifier + # all identifiers are case insensitive but are stored as is + # https://docs.dremio.com/current/sonar/data-sources + caps.has_case_sensitive_identifiers = 
False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + caps.max_query_length = 2 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 16 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_transactions = False + caps.supports_ddl_transactions = False + caps.alter_add_multi_column = True + caps.supports_clone_table = False + caps.supports_multiple_statements = False + caps.timestamp_precision = 3 + return caps @property def client_class(self) -> t.Type["DremioClient"]: diff --git a/dlt/destinations/impl/dremio/sql_client.py b/dlt/destinations/impl/dremio/sql_client.py index 255c8acee0..fac65e7fd0 100644 --- a/dlt/destinations/impl/dremio/sql_client.py +++ b/dlt/destinations/impl/dremio/sql_client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List +from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Tuple import pyarrow @@ -10,7 +10,7 @@ DatabaseUndefinedRelation, DatabaseTransientException, ) -from dlt.destinations.impl.dremio import capabilities, pydremio +from dlt.destinations.impl.dremio import pydremio from dlt.destinations.impl.dremio.configuration import DremioCredentials from dlt.destinations.sql_client import ( DBApiCursorImpl, @@ -32,10 +32,14 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: class DremioSqlClient(SqlClientBase[pydremio.DremioConnection]): dbapi: ClassVar[DBApi] = pydremio - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: DremioCredentials) -> None: - super().__init__(credentials.database, dataset_name) + def __init__( + self, + dataset_name: str, + credentials: DremioCredentials, + 
capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.database, dataset_name, capabilities) self._conn: Optional[pydremio.DremioConnection] = None self.credentials = credentials @@ -99,18 +103,16 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB raise DatabaseTransientException(ex) yield DremioCursorImpl(curr) # type: ignore - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - database_name = self.credentials.database - dataset_name = self.dataset_name + def catalog_name(self, escape: bool = True) -> Optional[str]: + database_name = self.capabilities.casefold_identifier(self.database_name) if escape: database_name = self.capabilities.escape_identifier(database_name) - dataset_name = self.capabilities.escape_identifier(dataset_name) - return f"{database_name}.{dataset_name}" + return database_name - def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str: - if escape: - table_name = self.capabilities.escape_identifier(table_name) - return f"{self.fully_qualified_dataset_name(escape=escape)}.{table_name}" + def _get_information_schema_components(self, *tables: str) -> Tuple[str, str, List[str]]: + components = super()._get_information_schema_components(*tables) + # catalog is always DREMIO but schema contains "database" prefix 🤷 + return ("DREMIO", self.fully_qualified_dataset_name(escape=False), components[2]) @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: @@ -138,10 +140,10 @@ def _get_table_names(self) -> List[str]: query = """ SELECT TABLE_NAME FROM INFORMATION_SCHEMA."TABLES" - WHERE TABLE_CATALOG = 'DREMIO' AND TABLE_SCHEMA = %s + WHERE TABLE_CATALOG = %s AND TABLE_SCHEMA = %s """ - db_params = [self.fully_qualified_dataset_name(escape=False)] - tables = self.execute_sql(query, *db_params) or [] + catalog_name, schema_name, _ = self._get_information_schema_components() + tables = self.execute_sql(query, catalog_name, 
schema_name) or [] return [table[0] for table in tables] def drop_dataset(self) -> None: diff --git a/dlt/destinations/impl/duckdb/__init__.py b/dlt/destinations/impl/duckdb/__init__.py index 5cbc8dea53..e69de29bb2 100644 --- a/dlt/destinations/impl/duckdb/__init__.py +++ b/dlt/destinations/impl/duckdb/__init__.py @@ -1,26 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl"] - caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = [] - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_duckdb_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 65536 - caps.max_column_identifier_length = 65536 - caps.max_query_length = 32 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 1024 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = True - caps.alter_add_multi_column = False - caps.supports_truncate_command = False - - return caps diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 7016e9bfff..b87a2c4780 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -12,7 +12,6 @@ from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.impl.duckdb import capabilities from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient from 
dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration from dlt.destinations.type_mapping import TypeMapper @@ -151,10 +150,15 @@ def exception(self) -> str: class DuckDbClient(InsertValuesJobClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: DuckDbClientConfiguration) -> None: - sql_client = DuckDbSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: DuckDbClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = DuckDbSqlClient( + config.normalize_dataset_name(schema), config.credentials, capabilities + ) super().__init__(schema, config, sql_client) self.config: DuckDbClientConfiguration = config self.sql_client: DuckDbSqlClient = sql_client # type: ignore @@ -173,7 +177,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non for h in self.active_hints.keys() if c.get(h, False) is True ) - column_name = self.capabilities.escape_identifier(c["name"]) + column_name = self.sql_client.escape_column_name(c["name"]) return ( f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) diff --git a/dlt/destinations/impl/duckdb/factory.py b/dlt/destinations/impl/duckdb/factory.py index 55fcd3b339..388f914479 100644 --- a/dlt/destinations/impl/duckdb/factory.py +++ b/dlt/destinations/impl/duckdb/factory.py @@ -1,8 +1,10 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + from dlt.destinations.impl.duckdb.configuration import DuckDbCredentials, DuckDbClientConfiguration -from dlt.destinations.impl.duckdb import capabilities if t.TYPE_CHECKING: from duckdb import 
DuckDBPyConnection @@ -12,8 +14,29 @@ class duckdb(Destination[DuckDbClientConfiguration, "DuckDbClient"]): spec = DuckDbClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + caps.escape_identifier = escape_postgres_identifier + # all identifiers are case insensitive but are stored as is + caps.escape_literal = escape_duckdb_literal + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 65536 + caps.max_column_identifier_length = 65536 + caps.max_query_length = 32 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.alter_add_multi_column = False + caps.supports_truncate_command = False + + return caps @property def client_class(self) -> t.Type["DuckDbClient"]: diff --git a/dlt/destinations/impl/duckdb/sql_client.py b/dlt/destinations/impl/duckdb/sql_client.py index bb85b5825b..95762a1f26 100644 --- a/dlt/destinations/impl/duckdb/sql_client.py +++ b/dlt/destinations/impl/duckdb/sql_client.py @@ -17,12 +17,11 @@ raise_open_connection_error, ) -from dlt.destinations.impl.duckdb import capabilities from dlt.destinations.impl.duckdb.configuration import DuckDbBaseCredentials class DuckDBDBApiCursorImpl(DBApiCursorImpl): - """Use native BigQuery data frame support if available""" + """Use native duckdb data frame support if available""" native_cursor: duckdb.DuckDBPyConnection # type: ignore 
vector_size: ClassVar[int] = 2048 @@ -43,10 +42,14 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: class DuckDbSqlClient(SqlClientBase[duckdb.DuckDBPyConnection], DBTransaction): dbapi: ClassVar[DBApi] = duckdb - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: DuckDbBaseCredentials) -> None: - super().__init__(None, dataset_name) + def __init__( + self, + dataset_name: str, + credentials: DuckDbBaseCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(None, dataset_name, capabilities) self._conn: duckdb.DuckDBPyConnection = None self.credentials = credentials @@ -142,11 +145,6 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB # else: # return None - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: if isinstance(ex, (duckdb.CatalogException)): diff --git a/dlt/destinations/impl/dummy/__init__.py b/dlt/destinations/impl/dummy/__init__.py index e09f7d07a9..e69de29bb2 100644 --- a/dlt/destinations/impl/dummy/__init__.py +++ b/dlt/destinations/impl/dummy/__init__.py @@ -1,39 +0,0 @@ -from typing import List -from dlt.common.configuration import with_config, known_sections -from dlt.common.configuration.accessors import config -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.capabilities import TLoaderFileFormat - -from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration - - -@with_config( - spec=DummyClientConfiguration, - sections=( - known_sections.DESTINATION, - "dummy", - ), -) -def _configure(config: DummyClientConfiguration = config.value) -> DummyClientConfiguration: - return config - - -def capabilities() -> 
DestinationCapabilitiesContext: - config = _configure() - additional_formats: List[TLoaderFileFormat] = ( - ["reference"] if config.create_followup_jobs else [] # type:ignore[list-item] - ) - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = config.loader_file_format - caps.supported_loader_file_formats = additional_formats + [config.loader_file_format] - caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = additional_formats + [config.loader_file_format] - caps.max_identifier_length = 127 - caps.max_column_identifier_length = 127 - caps.max_query_length = 8 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 65536 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - - return caps diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index 3c78493b57..c41b7dca61 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -36,7 +36,6 @@ LoadJobNotExistsException, LoadJobInvalidStateTransitionException, ) -from dlt.destinations.impl.dummy import capabilities from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration from dlt.destinations.job_impl import NewReferenceJob @@ -110,10 +109,13 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: class DummyClient(JobClientBase, SupportsStagingDestination, WithStagingDataset): """dummy client storing jobs in memory""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: DummyClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: DummyClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) self.in_staging_context = False self.config: DummyClientConfiguration = config 
@@ -160,7 +162,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], - table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" return [] diff --git a/dlt/destinations/impl/dummy/factory.py b/dlt/destinations/impl/dummy/factory.py index 1c848cf22d..c68bc36ca9 100644 --- a/dlt/destinations/impl/dummy/factory.py +++ b/dlt/destinations/impl/dummy/factory.py @@ -2,11 +2,12 @@ from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination.capabilities import TLoaderFileFormat +from dlt.common.normalizers.naming.naming import NamingConvention from dlt.destinations.impl.dummy.configuration import ( DummyClientConfiguration, DummyClientCredentials, ) -from dlt.destinations.impl.dummy import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.dummy.dummy import DummyClient @@ -15,8 +16,19 @@ class dummy(Destination[DummyClientConfiguration, "DummyClient"]): spec = DummyClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_staging_file_format = None + caps.has_case_sensitive_identifiers = True + caps.max_identifier_length = 127 + caps.max_column_identifier_length = 127 + caps.max_query_length = 8 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 65536 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = False + + return caps @property def client_class(self) -> t.Type["DummyClient"]: @@ -37,3 +49,19 @@ def __init__( environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + 
cls, + caps: DestinationCapabilitiesContext, + config: DummyClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + caps = super().adjust_capabilities(caps, config, naming) + additional_formats: t.List[TLoaderFileFormat] = ( + ["reference"] if config.create_followup_jobs else [] # type:ignore[list-item] + ) + caps.preferred_loader_file_format = config.loader_file_format + caps.supported_loader_file_formats = additional_formats + [config.loader_file_format] + caps.supported_staging_file_formats = additional_formats + [config.loader_file_format] + return caps diff --git a/dlt/destinations/impl/filesystem/__init__.py b/dlt/destinations/impl/filesystem/__init__.py index 49fabd61d7..e69de29bb2 100644 --- a/dlt/destinations/impl/filesystem/__init__.py +++ b/dlt/destinations/impl/filesystem/__init__.py @@ -1,24 +0,0 @@ -from typing import Sequence, Tuple - -from dlt.common.schema.typing import TTableSchema -from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat - - -def loader_file_format_adapter( - preferred_loader_file_format: TLoaderFileFormat, - supported_loader_file_formats: Sequence[TLoaderFileFormat], - /, - *, - table_schema: TTableSchema, -) -> Tuple[TLoaderFileFormat, Sequence[TLoaderFileFormat]]: - if table_schema.get("table_format") == "delta": - return ("parquet", ["parquet"]) - return (preferred_loader_file_format, supported_loader_file_formats) - - -def capabilities() -> DestinationCapabilitiesContext: - return DestinationCapabilitiesContext.generic_capabilities( - preferred_loader_file_format="jsonl", - loader_file_format_adapter=loader_file_format_adapter, - supported_table_formats=["delta"], - ) diff --git a/dlt/destinations/impl/filesystem/factory.py b/dlt/destinations/impl/filesystem/factory.py index 029a5bdda5..1e6eec5cce 100644 --- a/dlt/destinations/impl/filesystem/factory.py +++ b/dlt/destinations/impl/filesystem/factory.py @@ -1,19 +1,38 @@ import typing as t -from 
dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration -from dlt.destinations.impl.filesystem import capabilities -from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination import Destination, DestinationCapabilitiesContext, TLoaderFileFormat +from dlt.common.destination.reference import DEFAULT_FILE_LAYOUT +from dlt.common.schema.typing import TTableSchema from dlt.common.storages.configuration import FileSystemCredentials +from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration +from dlt.destinations.impl.filesystem.typing import TCurrentDateTime, TExtraPlaceholders + if t.TYPE_CHECKING: from dlt.destinations.impl.filesystem.filesystem import FilesystemClient +def loader_file_format_adapter( + preferred_loader_file_format: TLoaderFileFormat, + supported_loader_file_formats: t.Sequence[TLoaderFileFormat], + /, + *, + table_schema: TTableSchema, +) -> t.Tuple[TLoaderFileFormat, t.Sequence[TLoaderFileFormat]]: + if table_schema.get("table_format") == "delta": + return ("parquet", ["parquet"]) + return (preferred_loader_file_format, supported_loader_file_formats) + + class filesystem(Destination[FilesystemDestinationClientConfiguration, "FilesystemClient"]): spec = FilesystemDestinationClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + return DestinationCapabilitiesContext.generic_capabilities( + preferred_loader_file_format="jsonl", + loader_file_format_adapter=loader_file_format_adapter, + supported_table_formats=["delta"], + ) @property def client_class(self) -> t.Type["FilesystemClient"]: @@ -25,6 +44,9 @@ def __init__( self, bucket_url: str = None, credentials: t.Union[FileSystemCredentials, t.Dict[str, t.Any], t.Any] = None, + layout: str = DEFAULT_FILE_LAYOUT, + extra_placeholders: 
t.Optional[TExtraPlaceholders] = None, + current_datetime: t.Optional[TCurrentDateTime] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -46,11 +68,20 @@ def __init__( credentials: Credentials to connect to the filesystem. The type of credentials should correspond to the bucket protocol. For example, for AWS S3, the credentials should be an instance of `AwsCredentials`. A dictionary with the credentials parameters can also be provided. + layout (str): A layout of the files holding table data in the destination bucket/filesystem. Uses a set of pre-defined + and user-defined (extra) placeholders. Please refer to https://dlthub.com/docs/dlt-ecosystem/destinations/filesystem#files-layout + extra_placeholders (dict(str, str | callable)): A dictionary of extra placeholder names that can be used in the `layout` parameter. Names + are mapped to string values or to callables evaluated at runtime. + current_datetime (DateTime | callable): current datetime used by date/time related placeholders. If not provided, load package creation timestamp + will be used. 
**kwargs: Additional arguments passed to the destination config """ super().__init__( bucket_url=bucket_url, credentials=credentials, + layout=layout, + extra_placeholders=extra_placeholders, + current_datetime=current_datetime, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 9d15ba959e..00b990d4fa 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -12,7 +12,7 @@ from dlt.common.typing import DictStrAny from dlt.common.schema import Schema, TSchemaTables, TTableSchema from dlt.common.storages import FileStorage, fsspec_from_config -from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName +from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName, TPipelineStateDoc from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( NewLoadJob, @@ -29,7 +29,6 @@ ) from dlt.common.destination.exceptions import DestinationUndefinedEntity from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob -from dlt.destinations.impl.filesystem import capabilities from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations import path_utils @@ -153,15 +152,19 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: class FilesystemClient(FSClientBase, JobClientBase, WithStagingDataset, WithStateSync): """filesystem client storing jobs in memory""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() fs_client: AbstractFileSystem # a path (without the scheme) to a location in the bucket where dataset is present bucket_path: str # name of the dataset dataset_name: str - def __init__(self, schema: Schema, config: 
FilesystemDestinationClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: FilesystemDestinationClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) self.fs_client, fs_path = fsspec_from_config(config) self.is_local_filesystem = config.protocol == "file" self.bucket_path = ( @@ -365,7 +368,7 @@ def _write_to_json_file(self, filepath: str, data: DictStrAny) -> None: dirname = self.pathlib.dirname(filepath) if not self.fs_client.isdir(dirname): return - self.fs_client.write_text(filepath, json.dumps(data), "utf-8") + self.fs_client.write_text(filepath, json.dumps(data), encoding="utf-8") def _to_path_safe_string(self, s: str) -> str: """for base64 strings""" @@ -447,8 +450,13 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: # Load compressed state from destination if selected_path: - state_json = json.loads(self.fs_client.read_text(selected_path)) - state_json.pop("version_hash") + state_json: TPipelineStateDoc = json.loads( + self.fs_client.read_text(selected_path, encoding="utf-8") + ) + # we had dlt_load_id stored until version 0.5 and since we do not have any version control + # we always migrate + if load_id := state_json.pop("dlt_load_id", None): # type: ignore[typeddict-item] + state_json["_dlt_load_id"] = load_id return StateInfo(**state_json) return None @@ -491,7 +499,9 @@ def _get_stored_schema_by_hash_or_newest( break if selected_path: - return StorageSchemaInfo(**json.loads(self.fs_client.read_text(selected_path))) + return StorageSchemaInfo( + **json.loads(self.fs_client.read_text(selected_path, encoding="utf-8")) + ) return None @@ -528,19 +538,23 @@ def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchema def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], - table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + 
completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: def get_table_jobs( table_jobs: Sequence[LoadJobInfo], table_name: str ) -> Sequence[LoadJobInfo]: return [job for job in table_jobs if job.job_file_info.table_name == table_name] - assert table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs(table_chain, table_chain_jobs) + assert completed_table_chain_jobs is not None + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) table_format = table_chain[0].get("table_format") if table_format == "delta": delta_jobs = [ - DeltaLoadFilesystemJob(self, table, get_table_jobs(table_chain_jobs, table["name"])) + DeltaLoadFilesystemJob( + self, table, get_table_jobs(completed_table_chain_jobs, table["name"]) + ) for table in table_chain ] jobs.extend(delta_jobs) diff --git a/dlt/destinations/impl/filesystem/typing.py b/dlt/destinations/impl/filesystem/typing.py index 139602198d..6781fe21ac 100644 --- a/dlt/destinations/impl/filesystem/typing.py +++ b/dlt/destinations/impl/filesystem/typing.py @@ -15,5 +15,7 @@ `schema name`, `table name`, `load_id`, `file_id` and an `extension` """ -TExtraPlaceholders: TypeAlias = Dict[str, Union[str, TLayoutPlaceholderCallback]] +TExtraPlaceholders: TypeAlias = Dict[ + str, Union[Union[str, int, DateTime], TLayoutPlaceholderCallback] +] """Extra placeholders for filesystem layout""" diff --git a/dlt/destinations/impl/lancedb/__init__.py b/dlt/destinations/impl/lancedb/__init__.py new file mode 100644 index 0000000000..bc6974b072 --- /dev/null +++ b/dlt/destinations/impl/lancedb/__init__.py @@ -0,0 +1 @@ +from dlt.destinations.impl.lancedb.lancedb_adapter import lancedb_adapter diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py new file mode 100644 index 0000000000..ba3a8b49d9 --- /dev/null +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -0,0 +1,111 @@ 
+import dataclasses +from typing import Optional, Final, Literal, ClassVar, List + +from dlt.common.configuration import configspec +from dlt.common.configuration.specs.base_configuration import ( + BaseConfiguration, + CredentialsConfiguration, +) +from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.typing import TSecretStrValue +from dlt.common.utils import digest128 + + +@configspec +class LanceDBCredentials(CredentialsConfiguration): + uri: Optional[str] = ".lancedb" + """LanceDB database URI. Defaults to local, on-disk instance. + + The available schemas are: + + - `/path/to/database` - local database. + - `db://host:port` - remote database (LanceDB cloud). + """ + api_key: Optional[TSecretStrValue] = None + """API key for the remote connections (LanceDB cloud).""" + embedding_model_provider_api_key: Optional[str] = None + """API key for the embedding model provider.""" + + __config_gen_annotations__: ClassVar[List[str]] = [ + "uri", + "api_key", + "embedding_model_provider_api_key", + ] + + +@configspec +class LanceDBClientOptions(BaseConfiguration): + max_retries: Optional[int] = 3 + """`EmbeddingFunction` class wraps the calls for source and query embedding + generation inside a rate limit handler that retries the requests with exponential + backoff after successive failures. 
+ + You can tune it by setting it to a different number, or disable it by setting it to 0.""" + + __config_gen_annotations__: ClassVar[List[str]] = [ + "max_retries", + ] + + +TEmbeddingProvider = Literal[ + "gemini-text", + "bedrock-text", + "cohere", + "gte-text", + "imagebind", + "instructor", + "open-clip", + "openai", + "sentence-transformers", + "huggingface", + "colbert", +] + + +@configspec +class LanceDBClientConfiguration(DestinationClientDwhConfiguration): + destination_type: Final[str] = dataclasses.field( # type: ignore + default="LanceDB", init=False, repr=False, compare=False + ) + credentials: LanceDBCredentials = None + dataset_separator: str = "___" + """Character for the dataset separator.""" + dataset_name: Final[Optional[str]] = dataclasses.field( # type: ignore + default=None, init=False, repr=False, compare=False + ) + + options: Optional[LanceDBClientOptions] = None + """LanceDB client options.""" + + embedding_model_provider: TEmbeddingProvider = "cohere" + """Embedding provider used for generating embeddings. Default is "cohere". You can find the full list of + providers at https://github.com/lancedb/lancedb/tree/main/python/python/lancedb/embeddings as well as + https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/.""" + embedding_model: str = "embed-english-v3.0" + """The model used by the embedding provider for generating embeddings. + Check with the embedding provider which options are available. + Reference https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/.""" + embedding_model_dimensions: Optional[int] = None + """The dimensions of the embeddings generated. In most cases it will be automatically inferred, by LanceDB, + but it is configurable in rare cases. 
+ + Make sure it corresponds with the associated embedding model's dimensionality.""" + vector_field_name: str = "vector__" + """Name of the special field to store the vector embeddings.""" + id_field_name: str = "id__" + """Name of the special field to manage deduplication.""" + sentinel_table_name: str = "dltSentinelTable" + """Name of the sentinel table that encapsulates datasets. Since LanceDB has no + concept of schemas, this table serves as a proxy to group related dlt tables together.""" + + __config_gen_annotations__: ClassVar[List[str]] = [ + "embedding_model", + "embedding_model_provider", + ] + + def fingerprint(self) -> str: + """Returns a fingerprint of a connection string.""" + + if self.credentials and self.credentials.uri: + return digest128(self.credentials.uri) + return "" diff --git a/dlt/destinations/impl/lancedb/exceptions.py b/dlt/destinations/impl/lancedb/exceptions.py new file mode 100644 index 0000000000..35b86ce76c --- /dev/null +++ b/dlt/destinations/impl/lancedb/exceptions.py @@ -0,0 +1,30 @@ +from functools import wraps +from typing import ( + Any, +) + +from lancedb.exceptions import MissingValueError, MissingColumnError # type: ignore + +from dlt.common.destination.exceptions import ( + DestinationUndefinedEntity, + DestinationTerminalException, +) +from dlt.common.destination.reference import JobClientBase +from dlt.common.typing import TFun + + +def lancedb_error(f: TFun) -> TFun: + @wraps(f) + def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any: + try: + return f(self, *args, **kwargs) + except ( + FileNotFoundError, + MissingValueError, + MissingColumnError, + ) as status_ex: + raise DestinationUndefinedEntity(status_ex) from status_ex + except Exception as e: + raise DestinationTerminalException(e) from e + + return _wrap # type: ignore[return-value] diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py new file mode 100644 index 0000000000..f2e17168b9 --- /dev/null +++ 
b/dlt/destinations/impl/lancedb/factory.py @@ -0,0 +1,53 @@ +import typing as t + +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.destinations.impl.lancedb.configuration import ( + LanceDBCredentials, + LanceDBClientConfiguration, +) + + +if t.TYPE_CHECKING: + from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient + + +class lancedb(Destination[LanceDBClientConfiguration, "LanceDBClient"]): + spec = LanceDBClientConfiguration + + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl"] + + caps.max_identifier_length = 200 + caps.max_column_identifier_length = 1024 + caps.max_query_length = 8 * 1024 * 1024 + caps.is_max_query_length_in_bytes = False + caps.max_text_data_type_length = 8 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = False + caps.supports_ddl_transactions = False + + caps.decimal_precision = (38, 18) + caps.timestamp_precision = 6 + + return caps + + @property + def client_class(self) -> t.Type["LanceDBClient"]: + from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient + + return LanceDBClient + + def __init__( + self, + credentials: t.Union[LanceDBCredentials, t.Dict[str, t.Any]] = None, + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + **kwargs: t.Any, + ) -> None: + super().__init__( + credentials=credentials, + destination_name=destination_name, + environment=environment, + **kwargs, + ) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py new file mode 100644 index 0000000000..bb33632b48 --- /dev/null +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -0,0 +1,58 @@ +from typing import Any + +from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns +from dlt.destinations.utils import ensure_resource +from 
dlt.extract import DltResource + + +VECTORIZE_HINT = "x-lancedb-embed" + + +def lancedb_adapter( + data: Any, + embed: TColumnNames = None, +) -> DltResource: + """Prepares data for the LanceDB destination by specifying which columns should be embedded. + + Args: + data (Any): The data to be transformed. It can be raw data or an instance + of DltResource. If raw data, the function wraps it into a DltResource + object. + embed (TColumnNames, optional): Specify columns to generate embeddings for. + It can be a single column name as a string, or a list of column names. + + Returns: + DltResource: A resource with applied LanceDB-specific hints. + + Raises: + ValueError: If input for `embed` is invalid or empty. + + Examples: + >>> data = [{"name": "Marcel", "description": "Moonbase Engineer"}] + >>> lancedb_adapter(data, embed="description") + [DltResource with hints applied] + """ + resource = ensure_resource(data) + + column_hints: TTableSchemaColumns = {} + + if embed: + if isinstance(embed, str): + embed = [embed] + if not isinstance(embed, list): + raise ValueError( + "'embed' must be a list of column names or a single column name as a string."
+ ) + + for column_name in embed: + column_hints[column_name] = { + "name": column_name, + VECTORIZE_HINT: True, # type: ignore[misc] + } + + if not column_hints: + raise ValueError("A value for 'embed' must be specified.") + else: + resource.apply_hints(columns=column_hints) + + return resource diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py new file mode 100644 index 0000000000..128e2c7e7e --- /dev/null +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -0,0 +1,767 @@ +import uuid +from types import TracebackType +from typing import ( + ClassVar, + List, + Any, + cast, + Union, + Tuple, + Iterable, + Type, + Optional, + Dict, + Sequence, +) + +import lancedb # type: ignore +import pyarrow as pa +from lancedb import DBConnection +from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore +from lancedb.query import LanceQueryBuilder # type: ignore +from lancedb.table import Table # type: ignore +from numpy import ndarray +from pyarrow import Array, ChunkedArray, ArrowInvalid + +from dlt.common import json, pendulum, logger +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.exceptions import ( + DestinationUndefinedEntity, + DestinationTransientException, + DestinationTerminalException, +) +from dlt.common.destination.reference import ( + JobClientBase, + WithStateSync, + LoadJob, + StorageSchemaInfo, + StateInfo, + TLoadJobState, +) +from dlt.common.pendulum import timedelta +from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.schema.typing import ( + TColumnType, + TTableFormat, + TTableSchemaColumns, + TWriteDisposition, +) +from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.storages import FileStorage +from dlt.common.typing import DictStrAny +from dlt.destinations.impl.lancedb.configuration import ( + LanceDBClientConfiguration, +) +from 
dlt.destinations.impl.lancedb.exceptions import ( + lancedb_error, +) +from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT +from dlt.destinations.impl.lancedb.schema import ( + make_arrow_field_schema, + make_arrow_table_schema, + TArrowSchema, + NULL_SCHEMA, + TArrowField, +) +from dlt.destinations.impl.lancedb.utils import ( + list_merge_identifiers, + generate_uuid, + set_non_standard_providers_environment_variables, +) +from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.type_mapping import TypeMapper + + +TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} + + +class LanceDBTypeMapper(TypeMapper): + sct_to_unbound_dbt = { + "text": pa.string(), + "double": pa.float64(), + "bool": pa.bool_(), + "bigint": pa.int64(), + "binary": pa.binary(), + "date": pa.date32(), + "complex": pa.string(), + } + + sct_to_dbt = {} + + dbt_to_sct = { + pa.string(): "text", + pa.float64(): "double", + pa.bool_(): "bool", + pa.int64(): "bigint", + pa.binary(): "binary", + pa.date32(): "date", + } + + def to_db_decimal_type( + self, precision: Optional[int], scale: Optional[int] + ) -> pa.Decimal128Type: + precision, scale = self.decimal_precision(precision, scale) + return pa.decimal128(precision, scale) + + def to_db_datetime_type( + self, precision: Optional[int], table_format: TTableFormat = None + ) -> pa.TimestampType: + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.timestamp(unit, "UTC") + + def to_db_time_type( + self, precision: Optional[int], table_format: TTableFormat = None + ) -> pa.Time64Type: + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.time64(unit) + + def from_db_type( + self, + db_type: pa.DataType, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> TColumnType: + if 
isinstance(db_type, pa.TimestampType): + return dict( + data_type="timestamp", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Time64Type): + return dict( + data_type="time", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Decimal128Type): + precision, scale = db_type.precision, db_type.scale + if (precision, scale) == self.capabilities.wei_precision: + return cast(TColumnType, dict(data_type="wei")) + return dict(data_type="decimal", precision=precision, scale=scale) + return super().from_db_type(db_type, precision, scale) + + +def upload_batch( + records: List[DictStrAny], + /, + *, + db_client: DBConnection, + table_name: str, + write_disposition: TWriteDisposition, + id_field_name: Optional[str] = None, +) -> None: + """Inserts records into a LanceDB table with automatic embedding computation. + + Args: + records: The data to be inserted as payload. + db_client: The LanceDB client connection. + table_name: The name of the table to insert into. + id_field_name: The name of the ID field for update/merge operations. + write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. + + Raises: + ValueError: If the write disposition is unsupported, or `id_field_name` is not + provided for update/merge operations. + """ + + try: + tbl = db_client.open_table(table_name) + tbl.checkout_latest() + except FileNotFoundError as e: + raise DestinationTransientException( + "Couldn't open lancedb database. 
Batch WILL BE RETRIED" + ) from e + + try: + if write_disposition in ("append", "skip"): + tbl.add(records) + elif write_disposition == "replace": + tbl.add(records, mode="overwrite") + elif write_disposition == "merge": + if not id_field_name: + raise ValueError("To perform a merge update, 'id_field_name' must be specified.") + tbl.merge_insert( + id_field_name + ).when_matched_update_all().when_not_matched_insert_all().execute(records) + else: + raise DestinationTerminalException( + f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" + " failed AND WILL **NOT** BE RETRIED." + ) + except ArrowInvalid as e: + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." + ) from e + + +class LanceDBClient(JobClientBase, WithStateSync): + """LanceDB destination handler.""" + + model_func: TextEmbeddingFunction + + def __init__( + self, + schema: Schema, + config: LanceDBClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) + self.config: LanceDBClientConfiguration = config + self.db_client: DBConnection = lancedb.connect( + uri=self.config.credentials.uri, + api_key=self.config.credentials.api_key, + read_consistency_interval=timedelta(0), + ) + self.registry = EmbeddingFunctionRegistry.get_instance() + self.type_mapper = LanceDBTypeMapper(self.capabilities) + self.sentinel_table_name = config.sentinel_table_name + + embedding_model_provider = self.config.embedding_model_provider + + # LanceDB doesn't provide a standardized way to set API keys across providers. + # Some use ENV variables and others allow passing api key as an argument. + # To account for this, we set provider environment variable as well. 
+ set_non_standard_providers_environment_variables( + embedding_model_provider, + self.config.credentials.embedding_model_provider_api_key, + ) + # Use the monkey-patched implementation if openai was chosen. + if embedding_model_provider == "openai": + from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings + + self.model_func = PatchedOpenAIEmbeddings( + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) + else: + self.model_func = self.registry.get(embedding_model_provider).create( + name=self.config.embedding_model, + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) + + self.vector_field_name = self.config.vector_field_name + self.id_field_name = self.config.id_field_name + + @property + def dataset_name(self) -> str: + return self.config.normalize_dataset_name(self.schema) + + @property + def sentinel_table(self) -> str: + return self.make_qualified_table_name(self.sentinel_table_name) + + def make_qualified_table_name(self, table_name: str) -> str: + return ( + f"{self.dataset_name}{self.config.dataset_separator}{table_name}" + if self.dataset_name + else table_name + ) + + def get_table_schema(self, table_name: str) -> TArrowSchema: + schema_table: Table = self.db_client.open_table(table_name) + schema_table.checkout_latest() + schema = schema_table.schema + return cast( + TArrowSchema, + schema, + ) + + @lancedb_error + def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + """Create a LanceDB Table from the provided LanceModel or PyArrow schema. + + Args: + schema: The table schema to create. + table_name: The name of the table to create. + mode (): The mode to use when creating the table. Can be either "create" or "overwrite". + By default, if the table already exists, an exception is raised. + If you want to overwrite the table, use mode="overwrite". 
+ """ + return self.db_client.create_table(table_name, schema=schema, mode=mode) + + def delete_table(self, table_name: str) -> None: + """Delete a LanceDB table. + + Args: + table_name: The name of the table to delete. + """ + self.db_client.drop_table(table_name) + + def query_table( + self, + table_name: str, + query: Union[ + List[Any], ndarray[Any, Any], Array, ChunkedArray, str, Tuple[Any], None + ] = None, + ) -> LanceQueryBuilder: + """Query a LanceDB table. + + Args: + table_name: The name of the table to query. + query: The targeted vector to search for. + + Returns: + A LanceDB query builder. + """ + query_table: Table = self.db_client.open_table(table_name) + query_table.checkout_latest() + return query_table.search(query=query) + + @lancedb_error + def _get_table_names(self) -> List[str]: + """Return all tables in the dataset, excluding the sentinel table.""" + if self.dataset_name: + prefix = f"{self.dataset_name}{self.config.dataset_separator}" + table_names = [ + table_name + for table_name in self.db_client.table_names() + if table_name.startswith(prefix) + ] + else: + table_names = self.db_client.table_names() + + return [table_name for table_name in table_names if table_name != self.sentinel_table] + + @lancedb_error + def drop_storage(self) -> None: + """Drop the dataset from the LanceDB instance. + + Deletes all tables in the dataset and all data, as well as sentinel table associated with them. + + If the dataset name was not provided, it deletes all the tables in the current schema. 
+ """ + for table_name in self._get_table_names(): + self.db_client.drop_table(table_name) + + self._delete_sentinel_table() + + @lancedb_error + def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: + if not self.is_storage_initialized(): + self._create_sentinel_table() + elif truncate_tables: + for table_name in truncate_tables: + fq_table_name = self.make_qualified_table_name(table_name) + if not self.table_exists(fq_table_name): + continue + schema = self.get_table_schema(fq_table_name) + self.db_client.drop_table(fq_table_name) + self.create_table( + table_name=fq_table_name, + schema=schema, + ) + + @lancedb_error + def is_storage_initialized(self) -> bool: + return self.table_exists(self.sentinel_table) + + def _create_sentinel_table(self) -> Table: + """Create an empty table to indicate that the storage is initialized.""" + return self.create_table(schema=NULL_SCHEMA, table_name=self.sentinel_table) + + def _delete_sentinel_table(self) -> None: + """Delete the sentinel table.""" + self.db_client.drop_table(self.sentinel_table) + + @lancedb_error + def update_stored_schema( + self, + only_tables: Iterable[str] = None, + expected_update: TSchemaTables = None, + ) -> Optional[TSchemaTables]: + super().update_stored_schema(only_tables, expected_update) + applied_update: TSchemaTables = {} + + try: + schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) + except DestinationUndefinedEntity: + schema_info = None + + if schema_info is None: + logger.info( + f"Schema with hash {self.schema.stored_version_hash} " + "not found in the storage. 
upgrading" + ) + self._execute_schema_update(only_tables) + else: + logger.info( + f"Schema with hash {self.schema.stored_version_hash} " + f"inserted at {schema_info.inserted_at} found " + "in storage, no upgrade required" + ) + return applied_update + + def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: + table_schema: TTableSchemaColumns = {} + + try: + fq_table_name = self.make_qualified_table_name(table_name) + + table: Table = self.db_client.open_table(fq_table_name) + table.checkout_latest() + arrow_schema: TArrowSchema = table.schema + except FileNotFoundError: + return False, table_schema + + field: TArrowField + for field in arrow_schema: + name = self.schema.naming.normalize_identifier(field.name) + print(field.type) + print(field.name) + table_schema[name] = { + "name": name, + **self.type_mapper.from_db_type(field.type), + } + return True, table_schema + + @lancedb_error + def add_table_fields( + self, table_name: str, field_schemas: List[TArrowField] + ) -> Optional[Table]: + """Add multiple fields to the LanceDB table at once. + + Args: + table_name: The name of the table to create the fields on. + field_schemas: The list of fields to create. + """ + table: Table = self.db_client.open_table(table_name) + table.checkout_latest() + arrow_table = table.to_arrow() + + # Check if any of the new fields already exist in the table. + existing_fields = set(arrow_table.schema.names) + new_fields = [field for field in field_schemas if field.name not in existing_fields] + + if not new_fields: + # All fields already present, skip. + return None + + null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + + for field, null_array in zip(new_fields, null_arrays): + arrow_table = arrow_table.append_column(field, null_array) + + try: + return self.db_client.create_table(table_name, arrow_table, mode="overwrite") + except OSError: + # Error occurred while creating the table, skip. 
+ return None + + def _execute_schema_update(self, only_tables: Iterable[str]) -> None: + for table_name in only_tables or self.schema.tables: + exists, existing_columns = self.get_storage_table(table_name) + new_columns = self.schema.get_new_table_columns(table_name, existing_columns) + print(table_name) + print(new_columns) + embedding_fields: List[str] = get_columns_names_with_prop( + self.schema.get_table(table_name), VECTORIZE_HINT + ) + logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") + if len(new_columns) > 0: + if exists: + field_schemas: List[TArrowField] = [ + make_arrow_field_schema(column["name"], column, self.type_mapper) + for column in new_columns + ] + fq_table_name = self.make_qualified_table_name(table_name) + self.add_table_fields(fq_table_name, field_schemas) + else: + if table_name not in self.schema.dlt_table_names(): + embedding_fields = get_columns_names_with_prop( + self.schema.get_table(table_name=table_name), VECTORIZE_HINT + ) + vector_field_name = self.vector_field_name + id_field_name = self.id_field_name + embedding_model_func = self.model_func + embedding_model_dimensions = self.config.embedding_model_dimensions + else: + embedding_fields = None + vector_field_name = None + id_field_name = None + embedding_model_func = None + embedding_model_dimensions = None + + table_schema: TArrowSchema = make_arrow_table_schema( + table_name, + schema=self.schema, + type_mapper=self.type_mapper, + embedding_fields=embedding_fields, + embedding_model_func=embedding_model_func, + embedding_model_dimensions=embedding_model_dimensions, + vector_field_name=vector_field_name, + id_field_name=id_field_name, + ) + fq_table_name = self.make_qualified_table_name(table_name) + self.create_table(fq_table_name, table_schema) + + self.update_schema_in_storage() + + @lancedb_error + def update_schema_in_storage(self) -> None: + records = [ + { + self.schema.naming.normalize_identifier("version"): self.schema.version, + 
self.schema.naming.normalize_identifier( + "engine_version" + ): self.schema.ENGINE_VERSION, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier( + "version_hash" + ): self.schema.stored_version_hash, + self.schema.naming.normalize_identifier("schema"): json.dumps( + self.schema.to_dict() + ), + } + ] + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + write_disposition = self.schema.get_table(self.schema.version_table_name).get( + "write_disposition" + ) + print("UPLOAD") + upload_batch( + records, + db_client=self.db_client, + table_name=fq_version_table_name, + write_disposition=write_disposition, + ) + + @lancedb_error + def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: + """Retrieves the latest completed state for a pipeline.""" + fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + + state_table_: Table = self.db_client.open_table(fq_state_table_name) + state_table_.checkout_latest() + + loads_table_: Table = self.db_client.open_table(fq_loads_table_name) + loads_table_.checkout_latest() + + # normalize property names + p_load_id = self.schema.naming.normalize_identifier("load_id") + p_dlt_load_id = self.schema.naming.normalize_identifier("_dlt_load_id") + p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") + p_status = self.schema.naming.normalize_identifier("status") + p_version = self.schema.naming.normalize_identifier("version") + p_engine_version = self.schema.naming.normalize_identifier("engine_version") + p_state = self.schema.naming.normalize_identifier("state") + p_created_at = self.schema.naming.normalize_identifier("created_at") + p_version_hash = self.schema.naming.normalize_identifier("version_hash") + + # 
Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as less + # data into memory as possible. + state_table = ( + state_table_.search() + .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) + .to_arrow() + ) + loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() + + # Join arrow tables in-memory. + joined_table: pa.Table = state_table.join( + loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" + ).sort_by([(p_dlt_load_id, "descending")]) + + if joined_table.num_rows == 0: + return None + + state = joined_table.take([0]).to_pylist()[0] + return StateInfo( + version=state[p_version], + engine_version=state[p_engine_version], + pipeline_name=state[p_pipeline_name], + state=state[p_state], + created_at=pendulum.instance(state[p_created_at]), + version_hash=state[p_version_hash], + _dlt_load_id=state[p_dlt_load_id], + ) + + @lancedb_error + def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + + version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table.checkout_latest() + p_version_hash = self.schema.naming.normalize_identifier("version_hash") + p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") + p_schema_name = self.schema.naming.normalize_identifier("schema_name") + p_version = self.schema.naming.normalize_identifier("version") + p_engine_version = self.schema.naming.normalize_identifier("engine_version") + p_schema = self.schema.naming.normalize_identifier("schema") + + try: + schemas = ( + version_table.search().where( + f'`{p_version_hash}` = "{schema_hash}"', prefilter=True + ) + ).to_list() + + # LanceDB's ORDER BY clause doesn't seem to work. 
+ # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + return StorageSchemaInfo( + version_hash=most_recent_schema[p_version_hash], + schema_name=most_recent_schema[p_schema_name], + version=most_recent_schema[p_version], + engine_version=most_recent_schema[p_engine_version], + inserted_at=most_recent_schema[p_inserted_at], + schema=most_recent_schema[p_schema], + ) + except IndexError: + return None + + @lancedb_error + def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + """Retrieves newest schema from destination storage.""" + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + + version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table.checkout_latest() + p_version_hash = self.schema.naming.normalize_identifier("version_hash") + p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") + p_schema_name = self.schema.naming.normalize_identifier("schema_name") + p_version = self.schema.naming.normalize_identifier("version") + p_engine_version = self.schema.naming.normalize_identifier("engine_version") + p_schema = self.schema.naming.normalize_identifier("schema") + + try: + schemas = ( + version_table.search().where( + f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True + ) + ).to_list() + + # LanceDB's ORDER BY clause doesn't seem to work. 
+ # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + return StorageSchemaInfo( + version_hash=most_recent_schema[p_version_hash], + schema_name=most_recent_schema[p_schema_name], + version=most_recent_schema[p_version], + engine_version=most_recent_schema[p_engine_version], + inserted_at=most_recent_schema[p_inserted_at], + schema=most_recent_schema[p_schema], + ) + except IndexError: + return None + + def __exit__( + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: + pass + + def __enter__(self) -> "LanceDBClient": + return self + + @lancedb_error + def complete_load(self, load_id: str) -> None: + records = [ + { + self.schema.naming.normalize_identifier("load_id"): load_id, + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("status"): 0, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier( + "schema_version_hash" + ): None, # Payload schema must match the target schema. 
+ } + ] + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + write_disposition = self.schema.get_table(self.schema.loads_table_name).get( + "write_disposition" + ) + upload_batch( + records, + db_client=self.db_client, + table_name=fq_loads_table_name, + write_disposition=write_disposition, + ) + + def restore_file_load(self, file_path: str) -> LoadJob: + return EmptyLoadJob.from_file_path(file_path, "completed") + + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + return LoadLanceDBJob( + self.schema, + table, + file_path, + type_mapper=self.type_mapper, + db_client=self.db_client, + client_config=self.config, + model_func=self.model_func, + fq_table_name=self.make_qualified_table_name(table["name"]), + ) + + def table_exists(self, table_name: str) -> bool: + return table_name in self.db_client.table_names() + + +class LoadLanceDBJob(LoadJob): + arrow_schema: TArrowSchema + + def __init__( + self, + schema: Schema, + table_schema: TTableSchema, + local_path: str, + type_mapper: LanceDBTypeMapper, + db_client: DBConnection, + client_config: LanceDBClientConfiguration, + model_func: TextEmbeddingFunction, + fq_table_name: str, + ) -> None: + file_name = FileStorage.get_file_name_from_file_path(local_path) + super().__init__(file_name) + self.schema: Schema = schema + self.table_schema: TTableSchema = table_schema + self.db_client: DBConnection = db_client + self.type_mapper: TypeMapper = type_mapper + self.table_name: str = table_schema["name"] + self.fq_table_name: str = fq_table_name + self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) + self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) + self.embedding_model_func: TextEmbeddingFunction = model_func + self.embedding_model_dimensions: int = client_config.embedding_model_dimensions + self.id_field_name: str = client_config.id_field_name + self.write_disposition: 
TWriteDisposition = cast( + TWriteDisposition, self.table_schema.get("write_disposition", "append") + ) + + with FileStorage.open_zipsafe_ro(local_path) as f: + records: List[DictStrAny] = [json.loads(line) for line in f] + + if self.table_schema not in self.schema.dlt_tables(): + for record in records: + # Add reserved ID fields. + uuid_id = ( + generate_uuid(record, self.unique_identifiers, self.fq_table_name) + if self.unique_identifiers + else str(uuid.uuid4()) + ) + record.update({self.id_field_name: uuid_id}) + + # LanceDB expects all fields in the target arrow table to be present in the data payload. + # We add and set these missing fields, that are fields not present in the target schema, to NULL. + missing_fields = set(self.table_schema["columns"]) - set(record) + for field in missing_fields: + record[field] = None + + upload_batch( + records, + db_client=db_client, + table_name=self.fq_table_name, + write_disposition=self.write_disposition, + id_field_name=self.id_field_name, + ) + + def state(self) -> TLoadJobState: + return "completed" + + def exception(self) -> str: + raise NotImplementedError() diff --git a/dlt/destinations/impl/lancedb/models.py b/dlt/destinations/impl/lancedb/models.py new file mode 100644 index 0000000000..d90adb62bd --- /dev/null +++ b/dlt/destinations/impl/lancedb/models.py @@ -0,0 +1,34 @@ +from typing import Union, List + +import numpy as np +from lancedb.embeddings import OpenAIEmbeddings # type: ignore +from lancedb.embeddings.registry import register # type: ignore +from lancedb.embeddings.utils import TEXT # type: ignore + + +@register("openai_patched") +class PatchedOpenAIEmbeddings(OpenAIEmbeddings): + EMPTY_STRING_PLACEHOLDER: str = "___EMPTY___" + + def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: # type: ignore[type-arg] + """ + Replace empty strings with a placeholder value. 
+ """ + + sanitized_texts = super().sanitize_input(texts) + return [self.EMPTY_STRING_PLACEHOLDER if item == "" else item for item in sanitized_texts] + + def generate_embeddings( + self, + texts: Union[List[str], np.ndarray], # type: ignore[type-arg] + ) -> List[np.array]: # type: ignore[valid-type] + """ + Generate embeddings, treating the placeholder as an empty result. + """ + embeddings: List[np.array] = super().generate_embeddings(texts) # type: ignore[valid-type] + + for i, text in enumerate(texts): + if text == self.EMPTY_STRING_PLACEHOLDER: + embeddings[i] = np.zeros(self.ndims()) + + return embeddings diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py new file mode 100644 index 0000000000..c7cceec274 --- /dev/null +++ b/dlt/destinations/impl/lancedb/schema.py @@ -0,0 +1,84 @@ +"""Utilities for creating arrow schemas from table schemas.""" + +from dlt.common.json import json +from typing import ( + List, + cast, + Optional, +) + +import pyarrow as pa +from lancedb.embeddings import TextEmbeddingFunction # type: ignore +from typing_extensions import TypeAlias + +from dlt.common.schema import Schema, TColumnSchema +from dlt.common.typing import DictStrAny +from dlt.destinations.type_mapping import TypeMapper + + +TArrowSchema: TypeAlias = pa.Schema +TArrowDataType: TypeAlias = pa.DataType +TArrowField: TypeAlias = pa.Field +NULL_SCHEMA: TArrowSchema = pa.schema([]) +"""Empty pyarrow Schema with no fields.""" + + +def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: + return {field.name: field.type for field in schema} + + +def make_arrow_field_schema( + column_name: str, + column: TColumnSchema, + type_mapper: TypeMapper, +) -> TArrowField: + """Creates a PyArrow field from a dlt column schema.""" + dtype = cast(TArrowDataType, type_mapper.to_db_type(column)) + return pa.field(column_name, dtype) + + +def make_arrow_table_schema( + table_name: str, + schema: Schema, + type_mapper: TypeMapper, + 
id_field_name: Optional[str] = None, + vector_field_name: Optional[str] = None, + embedding_fields: Optional[List[str]] = None, + embedding_model_func: Optional[TextEmbeddingFunction] = None, + embedding_model_dimensions: Optional[int] = None, +) -> TArrowSchema: + """Creates a PyArrow schema from a dlt schema.""" + arrow_schema: List[TArrowField] = [] + + if id_field_name: + arrow_schema.append(pa.field(id_field_name, pa.string())) + + if embedding_fields: + # User's provided dimension config, if provided, takes precedence. + vec_size = embedding_model_dimensions or embedding_model_func.ndims() + arrow_schema.append(pa.field(vector_field_name, pa.list_(pa.float32(), vec_size))) + + for column_name, column in schema.get_table_columns(table_name).items(): + field = make_arrow_field_schema(column_name, column, type_mapper) + arrow_schema.append(field) + + metadata = {} + if embedding_model_func: + # Get the registered alias if it exists, otherwise use the class name. + name = getattr( + embedding_model_func, + "__embedding_function_registry_alias__", + embedding_model_func.__class__.__name__, + ) + embedding_functions = [ + { + "source_column": source_column, + "vector_column": vector_field_name, + "name": name, + "model": embedding_model_func.safe_model_dump(), + } + for source_column in embedding_fields + ] + metadata["embedding_functions"] = json.dumps(embedding_functions).encode("utf-8") + + return pa.schema(arrow_schema, metadata=metadata) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py new file mode 100644 index 0000000000..aeacd4d34b --- /dev/null +++ b/dlt/destinations/impl/lancedb/utils.py @@ -0,0 +1,55 @@ +import os +import uuid +from typing import Sequence, Union, Dict + +from dlt.common.schema import TTableSchema +from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.typing import DictStrAny +from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider + + 
+PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { + "cohere": "COHERE_API_KEY", + "gemini-text": "GOOGLE_API_KEY", + "openai": "OPENAI_API_KEY", + "huggingface": "HUGGINGFACE_API_KEY", +} + + +def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str: + """Generates deterministic UUID - used for deduplication. + + Args: + data (Dict[str, Any]): Arbitrary data to generate UUID for. + unique_identifiers (Sequence[str]): A list of unique identifiers. + table_name (str): LanceDB table name. + + Returns: + str: A string representation of the generated UUID. + """ + data_id = "_".join(str(data[key]) for key in unique_identifiers) + return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) + + +def list_merge_identifiers(table_schema: TTableSchema) -> Sequence[str]: + """Returns a list of merge keys for a table used for either merging or deduplication. + + Args: + table_schema (TTableSchema): a dlt table schema. + + Returns: + Sequence[str]: A list of unique column identifiers. 
+ """ + if table_schema.get("write_disposition") == "merge": + primary_keys = get_columns_names_with_prop(table_schema, "primary_key") + merge_keys = get_columns_names_with_prop(table_schema, "merge_key") + if join_keys := list(set(primary_keys + merge_keys)): + return join_keys + return get_columns_names_with_prop(table_schema, "unique") + + +def set_non_standard_providers_environment_variables( + embedding_model_provider: TEmbeddingProvider, api_key: Union[str, None] +) -> None: + if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: + os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" diff --git a/dlt/destinations/impl/motherduck/__init__.py b/dlt/destinations/impl/motherduck/__init__.py index 74c0e36ef3..e69de29bb2 100644 --- a/dlt/destinations/impl/motherduck/__init__.py +++ b/dlt/destinations/impl/motherduck/__init__.py @@ -1,24 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "parquet" - caps.supported_loader_file_formats = ["parquet", "insert_values", "jsonl"] - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_duckdb_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 65536 - caps.max_column_identifier_length = 65536 - caps.max_query_length = 512 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 1024 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - caps.alter_add_multi_column = False - caps.supports_truncate_command = False - 
- return caps diff --git a/dlt/destinations/impl/motherduck/factory.py b/dlt/destinations/impl/motherduck/factory.py index 5e35f69d75..df7418b9db 100644 --- a/dlt/destinations/impl/motherduck/factory.py +++ b/dlt/destinations/impl/motherduck/factory.py @@ -1,11 +1,13 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + from dlt.destinations.impl.motherduck.configuration import ( MotherDuckCredentials, MotherDuckClientConfiguration, ) -from dlt.destinations.impl.motherduck import capabilities if t.TYPE_CHECKING: from duckdb import DuckDBPyConnection @@ -15,8 +17,27 @@ class motherduck(Destination[MotherDuckClientConfiguration, "MotherDuckClient"]): spec = MotherDuckClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "parquet" + caps.supported_loader_file_formats = ["parquet", "insert_values", "jsonl"] + caps.escape_identifier = escape_postgres_identifier + # all identifiers are case insensitive but are stored as is + caps.escape_literal = escape_duckdb_literal + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 65536 + caps.max_column_identifier_length = 65536 + caps.max_query_length = 512 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = False + caps.alter_add_multi_column = False + caps.supports_truncate_command = False + + return caps @property def 
client_class(self) -> t.Type["MotherDuckClient"]: diff --git a/dlt/destinations/impl/motherduck/motherduck.py b/dlt/destinations/impl/motherduck/motherduck.py index c695d9715e..3a5f172864 100644 --- a/dlt/destinations/impl/motherduck/motherduck.py +++ b/dlt/destinations/impl/motherduck/motherduck.py @@ -1,20 +1,22 @@ -from typing import ClassVar - from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.schema import Schema from dlt.destinations.impl.duckdb.duck import DuckDbClient -from dlt.destinations.impl.motherduck import capabilities from dlt.destinations.impl.motherduck.sql_client import MotherDuckSqlClient from dlt.destinations.impl.motherduck.configuration import MotherDuckClientConfiguration class MotherDuckClient(DuckDbClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: MotherDuckClientConfiguration) -> None: - super().__init__(schema, config) # type: ignore - sql_client = MotherDuckSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: MotherDuckClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) # type: ignore + sql_client = MotherDuckSqlClient( + config.normalize_dataset_name(schema), config.credentials, capabilities + ) self.config: MotherDuckClientConfiguration = config # type: ignore self.sql_client: MotherDuckSqlClient = sql_client diff --git a/dlt/destinations/impl/motherduck/sql_client.py b/dlt/destinations/impl/motherduck/sql_client.py index 7990f90947..40157406ab 100644 --- a/dlt/destinations/impl/motherduck/sql_client.py +++ b/dlt/destinations/impl/motherduck/sql_client.py @@ -1,41 +1,22 @@ -import duckdb +from typing import Optional -from contextlib import contextmanager -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence -from dlt.common.destination import 
DestinationCapabilitiesContext - -from dlt.destinations.exceptions import ( - DatabaseTerminalException, - DatabaseTransientException, - DatabaseUndefinedRelation, -) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame -from dlt.destinations.sql_client import ( - SqlClientBase, - DBApiCursorImpl, - raise_database_error, - raise_open_connection_error, -) - -from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient, DuckDBDBApiCursorImpl -from dlt.destinations.impl.motherduck import capabilities +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient from dlt.destinations.impl.motherduck.configuration import MotherDuckCredentials class MotherDuckSqlClient(DuckDbSqlClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, dataset_name: str, credentials: MotherDuckCredentials) -> None: - super().__init__(dataset_name, credentials) + def __init__( + self, + dataset_name: str, + credentials: MotherDuckCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(dataset_name, credentials, capabilities) self.database_name = credentials.database - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - database_name = ( - self.capabilities.escape_identifier(self.database_name) - if escape - else self.database_name - ) - dataset_name = ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - return f"{database_name}.{dataset_name}" + def catalog_name(self, escape: bool = True) -> Optional[str]: + database_name = self.database_name + if escape: + database_name = self.capabilities.escape_identifier(database_name) + return database_name diff --git a/dlt/destinations/impl/mssql/__init__.py b/dlt/destinations/impl/mssql/__init__.py index f7768d9238..e69de29bb2 100644 --- a/dlt/destinations/impl/mssql/__init__.py +++ 
b/dlt/destinations/impl/mssql/__init__.py @@ -1,29 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.wei import EVM_DECIMAL_PRECISION - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values"] - caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = [] - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_mssql_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - # https://learn.microsoft.com/en-us/sql/sql-server/maximum-capacity-specifications-for-sql-server?view=sql-server-ver16&redirectedfrom=MSDN - caps.max_identifier_length = 128 - caps.max_column_identifier_length = 128 - # A SQL Query can be a varchar(max) but is shown as limited to 65,536 * Network Packet - caps.max_query_length = 65536 * 10 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 2**30 - 1 - caps.is_max_text_data_type_length_in_bytes = False - caps.supports_ddl_transactions = True - caps.max_rows_per_insert = 1000 - caps.timestamp_precision = 7 - - return caps diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index 8a50ecc6d2..64d87065f3 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -93,6 +93,7 @@ class MsSqlClientConfiguration(DestinationClientDwhWithStagingConfiguration): credentials: MsSqlCredentials = None create_indexes: bool = False + has_case_sensitive_identifiers: bool = False def fingerprint(self) -> str: """Returns a fingerprint 
of host part of a connection string""" diff --git a/dlt/destinations/impl/mssql/factory.py b/dlt/destinations/impl/mssql/factory.py index 2e19d7c2a8..6912510995 100644 --- a/dlt/destinations/impl/mssql/factory.py +++ b/dlt/destinations/impl/mssql/factory.py @@ -1,30 +1,58 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.normalizers.naming.naming import NamingConvention +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, MsSqlClientConfiguration -from dlt.destinations.impl.mssql import capabilities if t.TYPE_CHECKING: - from dlt.destinations.impl.mssql.mssql import MsSqlClient + from dlt.destinations.impl.mssql.mssql import MsSqlJobClient -class mssql(Destination[MsSqlClientConfiguration, "MsSqlClient"]): +class mssql(Destination[MsSqlClientConfiguration, "MsSqlJobClient"]): spec = MsSqlClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + # mssql is by default case insensitive and stores identifiers as is + # case sensitivity can be changed by database collation so we allow to reconfigure + # capabilities in the mssql factory + caps.escape_identifier = escape_postgres_identifier + caps.escape_literal = escape_mssql_literal + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + # 
https://learn.microsoft.com/en-us/sql/sql-server/maximum-capacity-specifications-for-sql-server?view=sql-server-ver16&redirectedfrom=MSDN + caps.max_identifier_length = 128 + caps.max_column_identifier_length = 128 + # A SQL Query can be a varchar(max) but is shown as limited to 65,536 * Network Packet + caps.max_query_length = 65536 * 10 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 2**30 - 1 + caps.is_max_text_data_type_length_in_bytes = False + caps.supports_ddl_transactions = True + caps.max_rows_per_insert = 1000 + caps.timestamp_precision = 7 + + return caps @property - def client_class(self) -> t.Type["MsSqlClient"]: - from dlt.destinations.impl.mssql.mssql import MsSqlClient + def client_class(self) -> t.Type["MsSqlJobClient"]: + from dlt.destinations.impl.mssql.mssql import MsSqlJobClient - return MsSqlClient + return MsSqlJobClient def __init__( self, credentials: t.Union[MsSqlCredentials, t.Dict[str, t.Any], str] = None, - create_indexes: bool = True, + create_indexes: bool = False, + has_case_sensitive_identifiers: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -37,12 +65,27 @@ def __init__( credentials: Credentials to connect to the mssql database. 
Can be an instance of `MsSqlCredentials` or a connection string in the format `mssql://user:password@host:port/database` create_indexes: Should unique indexes be created + has_case_sensitive_identifiers: Are identifiers used by mssql database case sensitive (following the collation) **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, create_indexes=create_indexes, + has_case_sensitive_identifiers=has_case_sensitive_identifiers, destination_name=destination_name, environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: MsSqlClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + # modify the caps if case sensitive identifiers are requested + if config.has_case_sensitive_identifiers: + caps.has_case_sensitive_identifiers = True + caps.casefold_identifier = str + return super().adjust_capabilities(caps, config, naming) diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index 6f364c8af1..25aab5c52a 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -1,19 +1,15 @@ -from typing import ClassVar, Dict, Optional, Sequence, List, Any, Tuple +from typing import Dict, Optional, Sequence, List, Any from dlt.common.exceptions import TerminalValueError -from dlt.common.wei import EVM_DECIMAL_PRECISION from dlt.common.destination.reference import NewLoadJob from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat -from dlt.common.utils import uniq_id from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient -from 
dlt.destinations.impl.mssql import capabilities from dlt.destinations.impl.mssql.sql_client import PyOdbcMsSqlClient from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration from dlt.destinations.sql_client import SqlClientBase @@ -145,11 +141,16 @@ def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) return "#" + name -class MsSqlClient(InsertValuesJobClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: MsSqlClientConfiguration) -> None: - sql_client = PyOdbcMsSqlClient(config.normalize_dataset_name(schema), config.credentials) +class MsSqlJobClient(InsertValuesJobClient): + def __init__( + self, + schema: Schema, + config: MsSqlClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = PyOdbcMsSqlClient( + config.normalize_dataset_name(schema), config.credentials, capabilities + ) super().__init__(schema, config, sql_client) self.config: MsSqlClientConfiguration = config self.sql_client = sql_client @@ -180,7 +181,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non for h in self.active_hints.keys() if c.get(h, False) is True ) - column_name = self.capabilities.escape_identifier(c["name"]) + column_name = self.sql_client.escape_column_name(c["name"]) return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}" def _create_replace_followup_jobs( diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index db043bae25..a360670e77 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -1,4 +1,3 @@ -import platform import struct from datetime import datetime, timedelta, timezone # noqa: I251 @@ -23,7 +22,6 @@ ) from dlt.destinations.impl.mssql.configuration import MsSqlCredentials -from dlt.destinations.impl.mssql import capabilities def 
handle_datetimeoffset(dto_value: bytes) -> datetime: @@ -43,10 +41,14 @@ def handle_datetimeoffset(dto_value: bytes) -> datetime: class PyOdbcMsSqlClient(SqlClientBase[pyodbc.Connection], DBTransaction): dbapi: ClassVar[DBApi] = pyodbc - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: MsSqlCredentials) -> None: - super().__init__(credentials.database, dataset_name) + def __init__( + self, + dataset_name: str, + credentials: MsSqlCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.database, dataset_name, capabilities) self._conn: pyodbc.Connection = None self.credentials = credentials @@ -104,14 +106,14 @@ def drop_dataset(self) -> None: # Drop all views rows = self.execute_sql( "SELECT table_name FROM information_schema.views WHERE table_schema = %s;", - self.dataset_name, + self.capabilities.casefold_identifier(self.dataset_name), ) view_names = [row[0] for row in rows] self._drop_views(*view_names) # Drop all tables rows = self.execute_sql( "SELECT table_name FROM information_schema.tables WHERE table_schema = %s;", - self.dataset_name, + self.capabilities.casefold_identifier(self.dataset_name), ) table_names = [row[0] for row in rows] self.drop_tables(*table_names) @@ -158,11 +160,6 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB except pyodbc.Error as outer: raise outer - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: if isinstance(ex, pyodbc.ProgrammingError): diff --git a/dlt/destinations/impl/postgres/__init__.py b/dlt/destinations/impl/postgres/__init__.py index bdb9297210..e69de29bb2 100644 --- a/dlt/destinations/impl/postgres/__init__.py +++ b/dlt/destinations/impl/postgres/__init__.py @@ 
-1,27 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.wei import EVM_DECIMAL_PRECISION - - -def capabilities() -> DestinationCapabilitiesContext: - # https://www.postgresql.org/docs/current/limits.html - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values", "csv"] - caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = [] - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_postgres_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (2 * EVM_DECIMAL_PRECISION, EVM_DECIMAL_PRECISION) - caps.max_identifier_length = 63 - caps.max_column_identifier_length = 63 - caps.max_query_length = 32 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 1024 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = True - - return caps diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index ae0b5200b2..13bdc7f6b2 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -1,6 +1,7 @@ import dataclasses -from typing import Dict, Final, ClassVar, Any, List, TYPE_CHECKING, Union +from typing import Dict, Final, ClassVar, Any, List, Optional +from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import 
ConnectionStringCredentials @@ -37,6 +38,9 @@ class PostgresClientConfiguration(DestinationClientDwhWithStagingConfiguration): create_indexes: bool = True + csv_format: Optional[CsvFormatConfiguration] = None + """Optional csv format configuration""" + def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" if self.credentials and self.credentials.host: diff --git a/dlt/destinations/impl/postgres/factory.py b/dlt/destinations/impl/postgres/factory.py index 68d72f890a..b873bf97d5 100644 --- a/dlt/destinations/impl/postgres/factory.py +++ b/dlt/destinations/impl/postgres/factory.py @@ -1,12 +1,15 @@ import typing as t +from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.wei import EVM_DECIMAL_PRECISION from dlt.destinations.impl.postgres.configuration import ( PostgresCredentials, PostgresClientConfiguration, ) -from dlt.destinations.impl.postgres import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.postgres.postgres import PostgresClient @@ -15,8 +18,32 @@ class postgres(Destination[PostgresClientConfiguration, "PostgresClient"]): spec = PostgresClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + # https://www.postgresql.org/docs/current/limits.html + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values", "csv"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + caps.escape_identifier = escape_postgres_identifier + # postgres has case sensitive identifiers but by default + # 
it folds them to lower case which makes them case insensitive + # https://stackoverflow.com/questions/20878932/are-postgresql-column-names-case-sensitive + caps.casefold_identifier = str.lower + caps.has_case_sensitive_identifiers = True + caps.escape_literal = escape_postgres_literal + caps.has_case_sensitive_identifiers = True + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (2 * EVM_DECIMAL_PRECISION, EVM_DECIMAL_PRECISION) + caps.max_identifier_length = 63 + caps.max_column_identifier_length = 63 + caps.max_query_length = 32 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + + return caps @property def client_class(self) -> t.Type["PostgresClient"]: @@ -28,6 +55,7 @@ def __init__( self, credentials: t.Union[PostgresCredentials, t.Dict[str, t.Any], str] = None, create_indexes: bool = True, + csv_format: t.Optional[CsvFormatConfiguration] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -40,11 +68,13 @@ def __init__( credentials: Credentials to connect to the postgres database. 
Can be an instance of `PostgresCredentials` or a connection string in the format `postgres://user:password@host:port/database` create_indexes: Should unique indexes be created + csv_format: Formatting options for csv file format **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, create_indexes=create_indexes, + csv_format=csv_format, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 11cee208b1..7b173a7711 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -1,5 +1,11 @@ -from typing import ClassVar, Dict, Optional, Sequence, List, Any - +from typing import Dict, Optional, Sequence, List, Any + +from dlt.common import logger +from dlt.common.data_writers.configuration import CsvFormatConfiguration +from dlt.common.destination.exceptions import ( + DestinationInvalidFileFormat, + DestinationTerminalException, +) from dlt.common.destination.reference import FollowupJob, LoadJob, NewLoadJob, TLoadJobState from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.exceptions import TerminalValueError @@ -9,7 +15,6 @@ from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.impl.postgres import capabilities from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.postgres.configuration import PostgresClientConfiguration from dlt.destinations.sql_client import SqlClientBase @@ -106,21 +111,85 @@ def generate_sql( class PostgresCsvCopyJob(LoadJob, FollowupJob): - def __init__(self, table_name: str, file_path: str, sql_client: Psycopg2SqlClient) -> None: + def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient") -> None: 
super().__init__(FileStorage.get_file_name_from_file_path(file_path)) + config = client.config + sql_client = client.sql_client + csv_format = config.csv_format or CsvFormatConfiguration() + table_name = table["name"] + sep = csv_format.delimiter + if csv_format.on_error_continue: + logger.warning( + f"When processing {file_path} on table {table_name} Postgres csv reader does not" + " support on_error_continue" + ) with FileStorage.open_zipsafe_ro(file_path, "rb") as f: - # all headers in first line - headers = f.readline().decode("utf-8").strip() - # quote headers if not quoted - all special keywords like "binary" must be quoted - headers = ",".join(h if h.startswith('"') else f'"{h}"' for h in headers.split(",")) + if csv_format.include_header: + # all headers in first line + headers_row: str = f.readline().decode(csv_format.encoding).strip() + split_headers = headers_row.split(sep) + else: + # read first row to figure out the headers + split_first_row: str = f.readline().decode(csv_format.encoding).strip().split(sep) + split_headers = list(client.schema.get_table_columns(table_name).keys()) + if len(split_first_row) > len(split_headers): + raise DestinationInvalidFileFormat( + "postgres", + "csv", + file_path, + f"First row {split_first_row} has more rows than columns {split_headers} in" + f" table {table_name}", + ) + if len(split_first_row) < len(split_headers): + logger.warning( + f"First row {split_first_row} has less rows than columns {split_headers} in" + f" table {table_name}. We will not load data to superfluous columns." 
+ ) + split_headers = split_headers[: len(split_first_row)] + # stream the first row again + f.seek(0) + + # normalized and quoted headers + split_headers = [ + sql_client.escape_column_name(h.strip('"'), escape=True) for h in split_headers + ] + split_null_headers = [] + split_columns = [] + # detect columns with NULL to use in FORCE NULL + # detect headers that are not in columns + for col in client.schema.get_table_columns(table_name).values(): + norm_col = sql_client.escape_column_name(col["name"], escape=True) + split_columns.append(norm_col) + if norm_col in split_headers and col.get("nullable", True): + split_null_headers.append(norm_col) + split_unknown_headers = set(split_headers).difference(split_columns) + if split_unknown_headers: + raise DestinationInvalidFileFormat( + "postgres", + "csv", + file_path, + f"Following headers {split_unknown_headers} cannot be matched to columns" + f" {split_columns} of table {table_name}.", + ) + + # use comma to join + headers = ",".join(split_headers) + if split_null_headers: + null_headers = f"FORCE_NULL({','.join(split_null_headers)})," + else: + null_headers = "" + qualified_table_name = sql_client.make_qualified_table_name(table_name) copy_sql = ( - "COPY %s (%s) FROM STDIN WITH (FORMAT CSV, DELIMITER ',', NULL '', FORCE_NULL(%s))" + "COPY %s (%s) FROM STDIN WITH (FORMAT CSV, DELIMITER '%s', NULL ''," + " %s ENCODING '%s')" % ( qualified_table_name, headers, - headers, + sep, + null_headers, + csv_format.encoding, ) ) with sql_client.begin_transaction(): @@ -135,10 +204,15 @@ def exception(self) -> str: class PostgresClient(InsertValuesJobClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None: - sql_client = Psycopg2SqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: PostgresClientConfiguration, + capabilities: 
DestinationCapabilitiesContext, + ) -> None: + sql_client = Psycopg2SqlClient( + config.normalize_dataset_name(schema), config.credentials, capabilities + ) super().__init__(schema, config, sql_client) self.config: PostgresClientConfiguration = config self.sql_client: Psycopg2SqlClient = sql_client @@ -148,7 +222,7 @@ def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None: def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: job = super().start_file_load(table, file_path, load_id) if not job and file_path.endswith("csv"): - job = PostgresCsvCopyJob(table["name"], file_path, self.sql_client) + job = PostgresCsvCopyJob(table, file_path, self) return job def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: @@ -157,7 +231,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non for h in self.active_hints.keys() if c.get(h, False) is True ) - column_name = self.capabilities.escape_identifier(c["name"]) + column_name = self.sql_client.escape_column_name(c["name"]) return ( f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) diff --git a/dlt/destinations/impl/postgres/sql_client.py b/dlt/destinations/impl/postgres/sql_client.py index 366ed243ef..38bfc212d5 100644 --- a/dlt/destinations/impl/postgres/sql_client.py +++ b/dlt/destinations/impl/postgres/sql_client.py @@ -26,15 +26,18 @@ ) from dlt.destinations.impl.postgres.configuration import PostgresCredentials -from dlt.destinations.impl.postgres import capabilities class Psycopg2SqlClient(SqlClientBase["psycopg2.connection"], DBTransaction): dbapi: ClassVar[DBApi] = psycopg2 - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: PostgresCredentials) -> None: - super().__init__(credentials.database, dataset_name) + def __init__( + self, + dataset_name: str, + 
credentials: PostgresCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.database, dataset_name, capabilities) self._conn: psycopg2.connection = None self.credentials = credentials @@ -112,11 +115,6 @@ def execute_fragments( composed = Composed(sql if isinstance(sql, Composable) else SQL(sql) for sql in fragments) return self.execute_sql(composed, *args, **kwargs) - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - def _reset_connection(self) -> None: # self._conn.autocommit = True self._conn.reset() diff --git a/dlt/destinations/impl/qdrant/__init__.py b/dlt/destinations/impl/qdrant/__init__.py index 1a2c466b14..e69de29bb2 100644 --- a/dlt/destinations/impl/qdrant/__init__.py +++ b/dlt/destinations/impl/qdrant/__init__.py @@ -1,18 +0,0 @@ -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.impl.qdrant.qdrant_adapter import qdrant_adapter - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl"] - - caps.max_identifier_length = 200 - caps.max_column_identifier_length = 1024 - caps.max_query_length = 8 * 1024 * 1024 - caps.is_max_query_length_in_bytes = False - caps.max_text_data_type_length = 8 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = False - caps.supports_ddl_transactions = False - - return caps diff --git a/dlt/destinations/impl/qdrant/configuration.py b/dlt/destinations/impl/qdrant/configuration.py index fd11cc7dcb..4d1ed1234d 100644 --- a/dlt/destinations/impl/qdrant/configuration.py +++ b/dlt/destinations/impl/qdrant/configuration.py @@ -18,6 +18,8 @@ class QdrantCredentials(CredentialsConfiguration): location: Optional[str] = None # API key for authentication in Qdrant Cloud. 
Default: `None` api_key: Optional[str] = None + # Persistence path for QdrantLocal. Default: `None` + path: Optional[str] = None def __str__(self) -> str: return self.location or "localhost" @@ -44,7 +46,7 @@ class QdrantClientOptions(BaseConfiguration): # Default: `None` host: Optional[str] = None # Persistence path for QdrantLocal. Default: `None` - path: Optional[str] = None + # path: Optional[str] = None @configspec diff --git a/dlt/destinations/impl/qdrant/factory.py b/dlt/destinations/impl/qdrant/factory.py index df9cd64871..defd29a03a 100644 --- a/dlt/destinations/impl/qdrant/factory.py +++ b/dlt/destinations/impl/qdrant/factory.py @@ -3,7 +3,6 @@ from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.destinations.impl.qdrant.configuration import QdrantCredentials, QdrantClientConfiguration -from dlt.destinations.impl.qdrant import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient @@ -12,8 +11,20 @@ class qdrant(Destination[QdrantClientConfiguration, "QdrantClient"]): spec = QdrantClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl"] + caps.has_case_sensitive_identifiers = True + caps.max_identifier_length = 200 + caps.max_column_identifier_length = 1024 + caps.max_query_length = 8 * 1024 * 1024 + caps.is_max_query_length_in_bytes = False + caps.max_text_data_type_length = 8 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = False + caps.supports_ddl_transactions = False + + return caps @property def client_class(self) -> t.Type["QdrantClient"]: diff --git a/dlt/destinations/impl/qdrant/qdrant_client.py b/dlt/destinations/impl/qdrant/qdrant_client.py index 9898b28c86..51915c5536 100644 --- 
a/dlt/destinations/impl/qdrant/qdrant_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_client.py @@ -1,19 +1,25 @@ from types import TracebackType -from typing import ClassVar, Optional, Sequence, List, Dict, Type, Iterable, Any, IO +from typing import Optional, Sequence, List, Dict, Type, Iterable, Any from dlt.common import logger from dlt.common.json import json from dlt.common.pendulum import pendulum from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.schema.utils import ( + get_columns_names_with_prop, + loads_table, + normalize_table_identifiers, + version_table, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync from dlt.common.storages import FileStorage +from dlt.common.time import precise_time from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo -from dlt.destinations.impl.qdrant import capabilities +from dlt.destinations.utils import get_pipeline_state_query_columns from dlt.destinations.impl.qdrant.configuration import QdrantClientConfiguration from dlt.destinations.impl.qdrant.qdrant_adapter import VECTORIZE_HINT @@ -49,21 +55,24 @@ def __init__( if self.unique_identifiers else uuid.uuid4() ) - embedding_doc = self._get_embedding_doc(data) payloads.append(data) ids.append(point_id) - docs.append(embedding_doc) - - embedding_model = db_client._get_or_init_model(db_client.embedding_model_name) - embeddings = list( - embedding_model.embed( - docs, - batch_size=self.config.embedding_batch_size, - parallel=self.config.embedding_parallelism, + if len(self.embedding_fields) > 0: + docs.append(self._get_embedding_doc(data)) + + if len(self.embedding_fields) > 0: + embedding_model = db_client._get_or_init_model(db_client.embedding_model_name) + embeddings = list( + 
embedding_model.embed( + docs, + batch_size=self.config.embedding_batch_size, + parallel=self.config.embedding_parallelism, + ) ) - ) - vector_name = db_client.get_vector_field_name() - embeddings = [{vector_name: embedding.tolist()} for embedding in embeddings] + vector_name = db_client.get_vector_field_name() + embeddings = [{vector_name: embedding.tolist()} for embedding in embeddings] + else: + embeddings = [{}] * len(ids) assert len(embeddings) == len(payloads) == len(ids) self._upload_data(vectors=embeddings, ids=ids, payloads=payloads) @@ -126,7 +135,7 @@ def _generate_uuid( collection_name (str): Qdrant collection name. Returns: - str: A string representation of the genrated UUID + str: A string representation of the generated UUID """ data_id = "_".join(str(data[key]) for key in unique_identifiers) return str(uuid.uuid5(uuid.NAMESPACE_DNS, collection_name + data_id)) @@ -141,20 +150,25 @@ def exception(self) -> str: class QdrantClient(JobClientBase, WithStateSync): """Qdrant Destination Handler""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - state_properties: ClassVar[List[str]] = [ - "version", - "engine_version", - "pipeline_name", - "state", - "created_at", - "_dlt_load_id", - ] - - def __init__(self, schema: Schema, config: QdrantClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: QdrantClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) + # get definitions of the dlt tables, normalize column names and keep for later use + version_table_ = normalize_table_identifiers(version_table(), schema.naming) + self.version_collection_properties = list(version_table_["columns"].keys()) + loads_table_ = normalize_table_identifiers(loads_table(), schema.naming) + self.loads_collection_properties = list(loads_table_["columns"].keys()) + state_table_ = normalize_table_identifiers( + 
get_pipeline_state_query_columns(), schema.naming + ) + self.pipeline_state_properties = list(state_table_["columns"].keys()) + self.config: QdrantClientConfiguration = config - self.db_client: QC = QdrantClient._create_db_client(config) + self.db_client: QC = None self.model = config.model @property @@ -216,19 +230,24 @@ def _create_collection(self, full_collection_name: str) -> None: self.db_client.create_collection( collection_name=full_collection_name, vectors_config=vectors_config ) + # TODO: we can use index hints to create indexes on properties or full text + # self.db_client.create_payload_index(full_collection_name, "_dlt_load_id", field_type="float") - def _create_point(self, obj: Dict[str, Any], collection_name: str) -> None: + def _create_point_no_vector(self, obj: Dict[str, Any], collection_name: str) -> None: """Inserts a point into a Qdrant collection without a vector. Args: obj (Dict[str, Any]): The arbitrary data to be inserted as payload. collection_name (str): The name of the collection to insert the point into. 
""" + # we want decreased ids because the point scroll functions orders by id ASC + # so we want newest first + id_ = 2**64 - int(precise_time() * 10**6) self.db_client.upsert( collection_name, points=[ models.PointStruct( - id=str(uuid.uuid4()), + id=id_, payload=obj, vector={}, ) @@ -308,7 +327,13 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Loads compressed state from destination storage By finding a load id that was completed """ - limit = 10 + # normalize property names + p_load_id = self.schema.naming.normalize_identifier("load_id") + p_dlt_load_id = self.schema.naming.normalize_identifier("_dlt_load_id") + p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") + # p_created_at = self.schema.naming.normalize_identifier("created_at") + + limit = 100 offset = None while True: try: @@ -317,22 +342,28 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) state_records, offset = self.db_client.scroll( scroll_table_name, - with_payload=self.state_properties, + with_payload=self.pipeline_state_properties, scroll_filter=models.Filter( must=[ models.FieldCondition( - key="pipeline_name", match=models.MatchValue(value=pipeline_name) + key=p_pipeline_name, match=models.MatchValue(value=pipeline_name) ) ] ), + # search by package load id which is guaranteed to increase over time + # order_by=models.OrderBy( + # key=p_created_at, + # # direction=models.Direction.DESC, + # ), limit=limit, offset=offset, ) + # print("state_r", state_records) if len(state_records) == 0: return None for state_record in state_records: state = state_record.payload - load_id = state["_dlt_load_id"] + load_id = state[p_dlt_load_id] scroll_table_name = self._make_qualified_collection_name( self.schema.loads_table_name ) @@ -342,13 +373,12 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: count_filter=models.Filter( must=[ models.FieldCondition( - key="load_id", match=models.MatchValue(value=load_id) + 
key=p_load_id, match=models.MatchValue(value=load_id) + ) ] ), ) if load_records.count > 0: - state["dlt_load_id"] = state.pop("_dlt_load_id") return StateInfo(**state) except Exception: return None @@ -357,18 +387,28 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" try: scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) + p_schema_name = self.schema.naming.normalize_identifier("schema_name") + # this works only because we create points that have no vectors + # with decreasing ids. so newest (lowest ids) go first + # we do not use order_by because it requires an index to be created + # and this behavior is different for local and cloud qdrant + # p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") response = self.db_client.scroll( scroll_table_name, with_payload=True, scroll_filter=models.Filter( must=[ models.FieldCondition( - key="schema_name", + key=p_schema_name, match=models.MatchValue(value=self.schema.name), ) ] ), limit=1, + # order_by=models.OrderBy( + # key=p_inserted_at, + # direction=models.Direction.DESC, + # ) ) record = response[0][0].payload return StorageSchemaInfo(**record) @@ -378,13 +418,14 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: try: scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) + p_version_hash = self.schema.naming.normalize_identifier("version_hash") response = self.db_client.scroll( scroll_table_name, with_payload=True, scroll_filter=models.Filter( must=[ models.FieldCondition( - key="version_hash", match=models.MatchValue(value=schema_hash) + key=p_version_hash, match=models.MatchValue(value=schema_hash) ) ] ), @@ -408,16 +449,14 @@ def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") def complete_load(self,
load_id: str) -> None: - properties = { - "load_id": load_id, - "schema_name": self.schema.name, - "status": 0, - "inserted_at": str(pendulum.now()), - } + values = [load_id, self.schema.name, 0, str(pendulum.now()), self.schema.version_hash] + assert len(values) == len(self.loads_collection_properties) + properties = {k: v for k, v in zip(self.loads_collection_properties, values)} loads_table_name = self._make_qualified_collection_name(self.schema.loads_table_name) - self._create_point(properties, loads_table_name) + self._create_point_no_vector(properties, loads_table_name) def __enter__(self) -> "QdrantClient": + self.db_client = QdrantClient._create_db_client(self.config) return self def __exit__( @@ -426,20 +465,24 @@ def __exit__( exc_val: BaseException, exc_tb: TracebackType, ) -> None: - pass + if self.db_client: + self.db_client.close() + self.db_client = None def _update_schema_in_storage(self, schema: Schema) -> None: schema_str = json.dumps(schema.to_dict()) - properties = { - "version_hash": schema.stored_version_hash, - "schema_name": schema.name, - "version": schema.version, - "engine_version": schema.ENGINE_VERSION, - "inserted_at": str(pendulum.now()), - "schema": schema_str, - } + values = [ + schema.version, + schema.ENGINE_VERSION, + str(pendulum.now().isoformat()), + schema.name, + schema.stored_version_hash, + schema_str, + ] + assert len(values) == len(self.version_collection_properties) + properties = {k: v for k, v in zip(self.version_collection_properties, values)} version_table_name = self._make_qualified_collection_name(self.schema.version_table_name) - self._create_point(properties, version_table_name) + self._create_point_no_vector(properties, version_table_name) def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: @@ -460,6 +503,10 @@ def _collection_exists(self, table_name: str, qualify_table_name: bool = True) - ) self.db_client.get_collection(table_name) return 
True + except ValueError as e: + if "not found" in str(e): + return False + raise e except UnexpectedResponse as e: if e.status_code == 404: return False diff --git a/dlt/destinations/impl/redshift/__init__.py b/dlt/destinations/impl/redshift/__init__.py index 8a8cae84b4..e69de29bb2 100644 --- a/dlt/destinations/impl/redshift/__init__.py +++ b/dlt/destinations/impl/redshift/__init__.py @@ -1,25 +0,0 @@ -from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values"] - caps.preferred_staging_file_format = "jsonl" - caps.supported_staging_file_formats = ["jsonl", "parquet"] - caps.escape_identifier = escape_redshift_identifier - caps.escape_literal = escape_redshift_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 127 - caps.max_column_identifier_length = 127 - caps.max_query_length = 16 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 65535 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = True - caps.alter_add_multi_column = False - - return caps diff --git a/dlt/destinations/impl/redshift/configuration.py b/dlt/destinations/impl/redshift/configuration.py index 72d7f70a9f..3b84c8663e 100644 --- a/dlt/destinations/impl/redshift/configuration.py +++ b/dlt/destinations/impl/redshift/configuration.py @@ -23,7 +23,9 @@ class RedshiftCredentials(PostgresCredentials): class RedshiftClientConfiguration(PostgresClientConfiguration): destination_type: Final[str] = 
dataclasses.field(default="redshift", init=False, repr=False, compare=False) # type: ignore credentials: RedshiftCredentials = None + staging_iam_role: Optional[str] = None + has_case_sensitive_identifiers: bool = False def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" diff --git a/dlt/destinations/impl/redshift/factory.py b/dlt/destinations/impl/redshift/factory.py index d80ef9dcad..7e6638be1e 100644 --- a/dlt/destinations/impl/redshift/factory.py +++ b/dlt/destinations/impl/redshift/factory.py @@ -1,12 +1,14 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.normalizers.naming import NamingConvention from dlt.destinations.impl.redshift.configuration import ( RedshiftCredentials, RedshiftClientConfiguration, ) -from dlt.destinations.impl.redshift import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.redshift.redshift import RedshiftClient @@ -15,8 +17,31 @@ class redshift(Destination[RedshiftClientConfiguration, "RedshiftClient"]): spec = RedshiftClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["jsonl", "parquet"] + # redshift is case insensitive and will lower case identifiers when stored + # you can enable case sensitivity https://docs.aws.amazon.com/redshift/latest/dg/r_enable_case_sensitive_identifier.html + # then redshift behaves like postgres + caps.escape_identifier = escape_redshift_identifier + 
caps.escape_literal = escape_redshift_literal + caps.casefold_identifier = str.lower + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 127 + caps.max_column_identifier_length = 127 + caps.max_query_length = 16 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 65535 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.alter_add_multi_column = False + + return caps @property def client_class(self) -> t.Type["RedshiftClient"]: @@ -27,8 +52,8 @@ def client_class(self) -> t.Type["RedshiftClient"]: def __init__( self, credentials: t.Union[RedshiftCredentials, t.Dict[str, t.Any], str] = None, - create_indexes: bool = True, staging_iam_role: t.Optional[str] = None, + has_case_sensitive_identifiers: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -40,15 +65,28 @@ def __init__( Args: credentials: Credentials to connect to the redshift database. 
Can be an instance of `RedshiftCredentials` or a connection string in the format `redshift://user:password@host:port/database` - create_indexes: Should unique indexes be created staging_iam_role: IAM role to use for staging data in S3 + has_case_sensitive_identifiers: Are case sensitive identifiers enabled for a database **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, - create_indexes=create_indexes, staging_iam_role=staging_iam_role, + has_case_sensitive_identifiers=has_case_sensitive_identifiers, destination_name=destination_name, environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: RedshiftClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + # modify the caps if case sensitive identifiers are requested + if config.has_case_sensitive_identifiers: + caps.has_case_sensitive_identifiers = True + caps.casefold_identifier = str + return super().adjust_capabilities(caps, config, naming) diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 672fceb7b2..faa037078a 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -1,11 +1,6 @@ import platform import os -from dlt.common.exceptions import TerminalValueError -from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient - -from dlt.common.schema.utils import table_schema_has_type, table_schema_has_type_with_precision - if platform.python_implementation() == "PyPy": import psycopg2cffi as psycopg2 @@ -15,25 +10,27 @@ # from psycopg2.sql import SQL, Composed -from typing import ClassVar, Dict, List, Optional, Sequence, Any +from typing import Dict, List, Optional, Sequence, Any, Tuple + -from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( NewLoadJob, 
CredentialsConfiguration, SupportsStagingDestination, ) from dlt.common.data_types import TDataType +from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.exceptions import TerminalValueError +from dlt.common.schema.utils import table_schema_has_type, table_schema_has_type_with_precision +from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat, TTableSchemaColumns from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.sql_jobs import SqlMergeJob from dlt.destinations.exceptions import DatabaseTerminalException, LoadJobTerminalException from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob, LoadJob - -from dlt.destinations.impl.redshift import capabilities +from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.redshift.configuration import RedshiftClientConfiguration from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase @@ -109,8 +106,6 @@ def from_db_type( class RedshiftSqlClient(Psycopg2SqlClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - @staticmethod def _maybe_make_terminal_exception_from_data_error( pg_ex: psycopg2.DataError, @@ -151,7 +146,6 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: "CREDENTIALS" f" 'aws_access_key_id={aws_access_key};aws_secret_access_key={aws_secret_key}'" ) - table_name = table["name"] # get format ext = os.path.splitext(bucket_path)[1][1:] @@ -191,10 +185,9 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: raise ValueError(f"Unsupported file type {ext} for Redshift.") with self._sql_client.begin_transaction(): - dataset_name = 
self._sql_client.dataset_name # TODO: if we ever support csv here remember to add column names to COPY self._sql_client.execute_sql(f""" - COPY {dataset_name}.{table_name} + COPY {self._sql_client.make_qualified_table_name(table['name'])} FROM '{bucket_path}' {file_type} {dateformat} @@ -231,10 +224,15 @@ def gen_key_table_clauses( class RedshiftClient(InsertValuesJobClient, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None: - sql_client = RedshiftSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: RedshiftClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = RedshiftSqlClient( + config.normalize_dataset_name(schema), config.credentials, capabilities + ) super().__init__(schema, config, sql_client) self.sql_client = sql_client self.config: RedshiftClientConfiguration = config @@ -249,7 +247,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non for h in HINT_TO_REDSHIFT_ATTR.keys() if c.get(h, False) is True ) - column_name = self.capabilities.escape_identifier(c["name"]) + column_name = self.sql_client.escape_column_name(c["name"]) return ( f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) diff --git a/dlt/destinations/impl/snowflake/__init__.py b/dlt/destinations/impl/snowflake/__init__.py index dde4d5a382..e69de29bb2 100644 --- a/dlt/destinations/impl/snowflake/__init__.py +++ b/dlt/destinations/impl/snowflake/__init__.py @@ -1,25 +0,0 @@ -from dlt.common.data_writers.escape import escape_bigquery_identifier -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_writers.escape import escape_snowflake_identifier -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE 
- - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl", "parquet"] - caps.preferred_staging_file_format = "jsonl" - caps.supported_staging_file_formats = ["jsonl", "parquet"] - caps.escape_identifier = escape_snowflake_identifier - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - caps.max_query_length = 2 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 16 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = True - caps.alter_add_multi_column = True - caps.supports_clone_table = True - return caps diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 8529fbe5c8..1211b78672 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -1,8 +1,9 @@ import dataclasses import base64 -from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING, Union +from typing import Final, Optional, Any, Dict, ClassVar, List from dlt import version +from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.libs.sql_alchemy import URL from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue @@ -135,6 +136,9 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration) keep_staged_files: bool = True """Whether to keep or delete the staged files after COPY INTO succeeds""" + csv_format: Optional[CsvFormatConfiguration] = None + """Optional csv format configuration""" + def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" if 
self.credentials and self.credentials.host: diff --git a/dlt/destinations/impl/snowflake/factory.py b/dlt/destinations/impl/snowflake/factory.py index c4459232b7..f531b8704e 100644 --- a/dlt/destinations/impl/snowflake/factory.py +++ b/dlt/destinations/impl/snowflake/factory.py @@ -1,11 +1,14 @@ import typing as t +from dlt.common.data_writers.configuration import CsvFormatConfiguration +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_snowflake_identifier +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + from dlt.destinations.impl.snowflake.configuration import ( SnowflakeCredentials, SnowflakeClientConfiguration, ) -from dlt.destinations.impl.snowflake import capabilities -from dlt.common.destination import Destination, DestinationCapabilitiesContext if t.TYPE_CHECKING: from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient @@ -14,8 +17,31 @@ class snowflake(Destination[SnowflakeClientConfiguration, "SnowflakeClient"]): spec = SnowflakeClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl", "parquet", "csv"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["jsonl", "parquet", "csv"] + # snowflake is case sensitive but all unquoted identifiers are upper cased + # so upper case identifiers are considered case insensitive + caps.escape_identifier = escape_snowflake_identifier + # dlt is configured to create case insensitive identifiers + # note that case sensitive naming conventions will change this setting to "str" (case sensitive) + caps.casefold_identifier = str.upper + caps.has_case_sensitive_identifiers = True + caps.decimal_precision = 
(DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + caps.max_query_length = 2 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 16 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.alter_add_multi_column = True + caps.supports_clone_table = True + return caps @property def client_class(self) -> t.Type["SnowflakeClient"]: @@ -28,6 +54,7 @@ def __init__( credentials: t.Union[SnowflakeCredentials, t.Dict[str, t.Any], str] = None, stage_name: t.Optional[str] = None, keep_staged_files: bool = True, + csv_format: t.Optional[CsvFormatConfiguration] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -46,6 +73,7 @@ def __init__( credentials=credentials, stage_name=stage_name, keep_staged_files=keep_staged_files, + csv_format=csv_format, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 70377de709..2a5671b7e7 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -1,6 +1,7 @@ -from typing import ClassVar, Optional, Sequence, Tuple, List, Any +from typing import Optional, Sequence, List from urllib.parse import urlparse, urlunparse +from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( FollowupJob, @@ -14,7 +15,6 @@ AwsCredentialsWithoutDefaults, AzureCredentialsWithoutDefaults, ) -from dlt.common.data_types import TDataType from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from 
dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat @@ -24,13 +24,10 @@ from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.exceptions import LoadJobTerminalException -from dlt.destinations.impl.snowflake import capabilities from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.sql_jobs import SqlJobParams from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import NewReferenceJob -from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper @@ -86,6 +83,7 @@ def __init__( table_name: str, load_id: str, client: SnowflakeSqlClient, + config: SnowflakeClientConfiguration, stage_name: Optional[str] = None, keep_staged_files: bool = True, staging_credentials: Optional[CredentialsConfiguration] = None, @@ -108,6 +106,14 @@ def __init__( credentials_clause = "" files_clause = "" stage_file_path = "" + on_error_clause = "" + + case_folding = ( + "CASE_SENSITIVE" + if client.capabilities.casefold_identifier is str + else "CASE_INSENSITIVE" + ) + column_match_clause = f"MATCH_BY_COLUMN_NAME='{case_folding}'" if bucket_path: bucket_url = urlparse(bucket_path) @@ -164,9 +170,28 @@ def __init__( from_clause = f"FROM {stage_file_path}" # decide on source format, stage_file_path will either be a local file or a bucket path - source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" + if file_name.endswith("jsonl"): + source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" if file_name.endswith("parquet"): - source_format = "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE)" + source_format = ( + "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE)" + # TODO: USE_VECTORIZED_SCANNER inserts null strings into VARIANT JSON + # " USE_VECTORIZED_SCANNER = TRUE)" + ) + if 
file_name.endswith("csv"): + # empty strings are NULL, no data is NULL, missing columns (ERROR_ON_COLUMN_COUNT_MISMATCH) are NULL + csv_format = config.csv_format or CsvFormatConfiguration() + source_format = ( + "(TYPE = 'CSV', BINARY_FORMAT = 'UTF-8', PARSE_HEADER =" + f" {csv_format.include_header}, FIELD_OPTIONALLY_ENCLOSED_BY = '\"', NULL_IF =" + " (''), ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE," + f" FIELD_DELIMITER='{csv_format.delimiter}', ENCODING='{csv_format.encoding}')" + ) + # disable column match if headers are not provided + if not csv_format.include_header: + column_match_clause = "" + if csv_format.on_error_continue: + on_error_clause = "ON_ERROR = CONTINUE" with client.begin_transaction(): # PUT and COPY in one tx if local file, otherwise only copy @@ -180,7 +205,8 @@ def __init__( {files_clause} {credentials_clause} FILE_FORMAT = {source_format} - MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE' + {column_match_clause} + {on_error_clause} """) if stage_file_path and not keep_staged_files: client.execute_sql(f"REMOVE {stage_file_path}") @@ -193,10 +219,15 @@ def exception(self) -> str: class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: SnowflakeClientConfiguration) -> None: - sql_client = SnowflakeSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: SnowflakeClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = SnowflakeSqlClient( + config.normalize_dataset_name(schema), config.credentials, capabilities + ) super().__init__(schema, config, sql_client) self.config: SnowflakeClientConfiguration = config self.sql_client: SnowflakeSqlClient = sql_client # type: ignore @@ -211,6 +242,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> table["name"], load_id, self.sql_client, + 
self.config, stage_name=self.config.stage_name, keep_staged_files=self.config.keep_staged_files, staging_credentials=( @@ -241,7 +273,7 @@ def _get_table_update_sql( sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) cluster_list = [ - self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster") + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("cluster") ] if cluster_list: @@ -255,17 +287,7 @@ def _from_db_type( return self.type_mapper.from_db_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: - name = self.capabilities.escape_identifier(c["name"]) + name = self.sql_client.escape_column_name(c["name"]) return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) - - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - table_name = table_name.upper() # All snowflake tables are uppercased in information schema - exists, table = super().get_storage_table(table_name) - if not exists: - return exists, table - # Snowflake converts all unquoted columns to UPPER CASE - # Convert back to lower case to enable comparison with dlt schema - table = {col_name.lower(): dict(col, name=col_name.lower()) for col_name, col in table.items()} # type: ignore - return exists, table diff --git a/dlt/destinations/impl/snowflake/sql_client.py b/dlt/destinations/impl/snowflake/sql_client.py index 4a602ce0e8..e033a9f455 100644 --- a/dlt/destinations/impl/snowflake/sql_client.py +++ b/dlt/destinations/impl/snowflake/sql_client.py @@ -17,7 +17,6 @@ ) from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame from dlt.destinations.impl.snowflake.configuration import SnowflakeCredentials -from dlt.destinations.impl.snowflake import capabilities class SnowflakeCursorImpl(DBApiCursorImpl): @@ -31,10 +30,14 @@ def df(self, chunk_size: int = None, **kwargs: Any) 
-> Optional[DataFrame]: class SnowflakeSqlClient(SqlClientBase[snowflake_lib.SnowflakeConnection], DBTransaction): dbapi: ClassVar[DBApi] = snowflake_lib - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: SnowflakeCredentials) -> None: - super().__init__(credentials.database, dataset_name) + def __init__( + self, + dataset_name: str, + credentials: SnowflakeCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.database, dataset_name, capabilities) self._conn: snowflake_lib.SnowflakeConnection = None self.credentials = credentials @@ -112,12 +115,6 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB self.open_connection() raise outer - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - # Always escape for uppercase - if escape: - return self.capabilities.escape_identifier(self.dataset_name) - return self.dataset_name.upper() - def _reset_connection(self) -> None: self._conn.rollback() self._conn.autocommit(True) diff --git a/dlt/destinations/impl/synapse/__init__.py b/dlt/destinations/impl/synapse/__init__.py index f6ad7369c1..e69de29bb2 100644 --- a/dlt/destinations/impl/synapse/__init__.py +++ b/dlt/destinations/impl/synapse/__init__.py @@ -1,54 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.wei import EVM_DECIMAL_PRECISION - -from dlt.destinations.impl.synapse.synapse_adapter import synapse_adapter - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values"] - caps.preferred_staging_file_format = "parquet" - 
caps.supported_staging_file_formats = ["parquet"] - - caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299 - - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_mssql_literal - - # Synapse has a max precision of 38 - # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#DataTypes - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - - # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#LimitationsRestrictions - caps.max_identifier_length = 128 - caps.max_column_identifier_length = 128 - - # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-service-capacity-limits#queries - caps.max_query_length = 65536 * 4096 - caps.is_max_query_length_in_bytes = True - - # nvarchar(max) can store 2 GB - # https://learn.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver16#nvarchar---n--max-- - caps.max_text_data_type_length = 2 * 1024 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - - # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-develop-transactions - caps.supports_transactions = True - caps.supports_ddl_transactions = False - - # Synapse throws "Some part of your SQL statement is nested too deeply. Rewrite the query or break it up into smaller queries." - # if number of records exceeds a certain number. Which exact number that is seems not deterministic: - # in tests, I've seen a query with 12230 records run succesfully on one run, but fail on a subsequent run, while the query remained exactly the same. - # 10.000 records is a "safe" amount that always seems to work. 
- caps.max_rows_per_insert = 10000 - - # datetimeoffset can store 7 digits for fractional seconds - # https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16 - caps.timestamp_precision = 7 - - return caps diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index 100878ae05..4820056e66 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -1,8 +1,10 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.normalizers.naming import NamingConvention +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.configuration import ( SynapseCredentials, SynapseClientConfiguration, @@ -21,8 +23,57 @@ class synapse(Destination[SynapseClientConfiguration, "SynapseClient"]): # def spec(self) -> t.Type[SynapseClientConfiguration]: # return SynapseClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values"] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet"] + + caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299 + + # similarly to mssql case sensitivity depends on database collation + # https://learn.microsoft.com/en-us/sql/relational-databases/collations/collation-and-unicode-support?view=sql-server-ver16#collations-in-azure-sql-database + # note that special option CATALOG_COLLATION is used to change it + caps.escape_identifier 
= escape_postgres_identifier + caps.escape_literal = escape_mssql_literal + # we allow to reconfigure capabilities in the mssql factory + caps.has_case_sensitive_identifiers = False + + # Synapse has a max precision of 38 + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#DataTypes + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#LimitationsRestrictions + caps.max_identifier_length = 128 + caps.max_column_identifier_length = 128 + + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-service-capacity-limits#queries + caps.max_query_length = 65536 * 4096 + caps.is_max_query_length_in_bytes = True + + # nvarchar(max) can store 2 GB + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver16#nvarchar---n--max-- + caps.max_text_data_type_length = 2 * 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-develop-transactions + caps.supports_transactions = True + caps.supports_ddl_transactions = False + + # Synapse throws "Some part of your SQL statement is nested too deeply. Rewrite the query or break it up into smaller queries." + # if number of records exceeds a certain number. Which exact number that is seems not deterministic: + # in tests, I've seen a query with 12230 records run succesfully on one run, but fail on a subsequent run, while the query remained exactly the same. + # 10.000 records is a "safe" amount that always seems to work. 
+ caps.max_rows_per_insert = 10000 + + # datetimeoffset can store 7 digits for fractional seconds + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16 + caps.timestamp_precision = 7 + + return caps @property def client_class(self) -> t.Type["SynapseClient"]: @@ -36,6 +87,7 @@ def __init__( default_table_index_type: t.Optional[TTableIndexType] = "heap", create_indexes: bool = False, staging_use_msi: bool = False, + has_case_sensitive_identifiers: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -50,6 +102,7 @@ def __init__( default_table_index_type: Maps directly to the default_table_index_type attribute of the SynapseClientConfiguration object. create_indexes: Maps directly to the create_indexes attribute of the SynapseClientConfiguration object. staging_use_msi: Maps directly to the staging_use_msi attribute of the SynapseClientConfiguration object. + has_case_sensitive_identifiers: Are identifiers used by synapse database case sensitive (following the catalog collation) **kwargs: Additional arguments passed to the destination config """ super().__init__( @@ -57,7 +110,21 @@ def __init__( default_table_index_type=default_table_index_type, create_indexes=create_indexes, staging_use_msi=staging_use_msi, + has_case_sensitive_identifiers=has_case_sensitive_identifiers, destination_name=destination_name, environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: SynapseClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + # modify the caps if case sensitive identifiers are requested + if config.has_case_sensitive_identifiers: + caps.has_case_sensitive_identifiers = True + caps.casefold_identifier = str + return super().adjust_capabilities(caps, config, naming) diff --git a/dlt/destinations/impl/synapse/sql_client.py 
b/dlt/destinations/impl/synapse/sql_client.py index 089c58e57c..db1b3e7cf6 100644 --- a/dlt/destinations/impl/synapse/sql_client.py +++ b/dlt/destinations/impl/synapse/sql_client.py @@ -5,15 +5,12 @@ from dlt.destinations.impl.mssql.sql_client import PyOdbcMsSqlClient from dlt.destinations.impl.mssql.configuration import MsSqlCredentials -from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.configuration import SynapseCredentials from dlt.destinations.exceptions import DatabaseUndefinedRelation class SynapseSqlClient(PyOdbcMsSqlClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def drop_tables(self, *tables: str) -> None: if not tables: return diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 48171ace4c..de2f9d4472 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -1,5 +1,5 @@ import os -from typing import ClassVar, Sequence, List, Dict, Any, Optional, cast, Union +from typing import Sequence, List, Dict, Any, Optional, cast, Union from copy import deepcopy from textwrap import dedent from urllib.parse import urlparse, urlunparse @@ -29,12 +29,11 @@ from dlt.destinations.impl.mssql.mssql import ( MsSqlTypeMapper, - MsSqlClient, + MsSqlJobClient, VARCHAR_MAX_N, VARBINARY_MAX_N, ) -from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.sql_client import SynapseSqlClient from dlt.destinations.impl.synapse.configuration import SynapseClientConfiguration from dlt.destinations.impl.synapse.synapse_adapter import ( @@ -53,14 +52,17 @@ } -class SynapseClient(MsSqlClient, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: - super().__init__(schema, config) +class SynapseClient(MsSqlJobClient, SupportsStagingDestination): + def 
__init__( + self, + schema: Schema, + config: SynapseClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) self.config: SynapseClientConfiguration = config self.sql_client = SynapseSqlClient( - config.normalize_dataset_name(schema), config.credentials + config.normalize_dataset_name(schema), config.credentials, capabilities ) self.active_hints = deepcopy(HINT_TO_SYNAPSE_ATTR) diff --git a/dlt/destinations/impl/weaviate/__init__.py b/dlt/destinations/impl/weaviate/__init__.py index 143e0260d2..e69de29bb2 100644 --- a/dlt/destinations/impl/weaviate/__init__.py +++ b/dlt/destinations/impl/weaviate/__init__.py @@ -1,19 +0,0 @@ -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.impl.weaviate.weaviate_adapter import weaviate_adapter - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl"] - - caps.max_identifier_length = 200 - caps.max_column_identifier_length = 1024 - caps.max_query_length = 8 * 1024 * 1024 - caps.is_max_query_length_in_bytes = False - caps.max_text_data_type_length = 8 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = False - caps.supports_ddl_transactions = False - caps.naming_convention = "dlt.destinations.impl.weaviate.naming" - - return caps diff --git a/dlt/destinations/impl/weaviate/ci_naming.py b/dlt/destinations/impl/weaviate/ci_naming.py index cc8936f42d..6e1b0c129e 100644 --- a/dlt/destinations/impl/weaviate/ci_naming.py +++ b/dlt/destinations/impl/weaviate/ci_naming.py @@ -2,6 +2,12 @@ class NamingConvention(WeaviateNamingConvention): + """Case insensitive naming convention for Weaviate. 
Lower cases all identifiers""" + + @property + def is_case_sensitive(self) -> bool: + return False + def _lowercase_property(self, identifier: str) -> str: """Lowercase the whole property to become case insensitive""" return identifier.lower() diff --git a/dlt/destinations/impl/weaviate/exceptions.py b/dlt/destinations/impl/weaviate/exceptions.py index ee798e4e76..11e440a811 100644 --- a/dlt/destinations/impl/weaviate/exceptions.py +++ b/dlt/destinations/impl/weaviate/exceptions.py @@ -1,16 +1,16 @@ from dlt.common.destination.exceptions import DestinationException, DestinationTerminalException -class WeaviateBatchError(DestinationException): +class WeaviateGrpcError(DestinationException): pass class PropertyNameConflict(DestinationTerminalException): - def __init__(self) -> None: + def __init__(self, error: str) -> None: super().__init__( "Your data contains items with identical property names when compared case insensitive." " Weaviate cannot handle such data. Please clean up your data before loading or change" " to case insensitive naming convention. See" " https://dlthub.com/docs/dlt-ecosystem/destinations/weaviate#names-normalization for" - " details." + f" details. 
[{error}]" ) diff --git a/dlt/destinations/impl/weaviate/factory.py b/dlt/destinations/impl/weaviate/factory.py index 0449e6cdd5..3d78c9582a 100644 --- a/dlt/destinations/impl/weaviate/factory.py +++ b/dlt/destinations/impl/weaviate/factory.py @@ -6,7 +6,6 @@ WeaviateCredentials, WeaviateClientConfiguration, ) -from dlt.destinations.impl.weaviate import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.weaviate.weaviate_client import WeaviateClient @@ -15,8 +14,26 @@ class weaviate(Destination[WeaviateClientConfiguration, "WeaviateClient"]): spec = WeaviateClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl"] + # weaviate names are case sensitive following GraphQL naming convention + # https://weaviate.io/developers/weaviate/config-refs/schema + caps.has_case_sensitive_identifiers = False + # weaviate will upper case first letter of class name and lower case first letter of a property + # we assume that naming convention will do that + caps.casefold_identifier = str + caps.max_identifier_length = 200 + caps.max_column_identifier_length = 1024 + caps.max_query_length = 8 * 1024 * 1024 + caps.is_max_query_length_in_bytes = False + caps.max_text_data_type_length = 8 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = False + caps.supports_ddl_transactions = False + caps.naming_convention = "dlt.destinations.impl.weaviate.naming" + + return caps @property def client_class(self) -> t.Type["WeaviateClient"]: diff --git a/dlt/destinations/impl/weaviate/naming.py b/dlt/destinations/impl/weaviate/naming.py index f5c94c872f..81a53dafd3 100644 --- a/dlt/destinations/impl/weaviate/naming.py +++ b/dlt/destinations/impl/weaviate/naming.py @@ -1,14 +1,20 @@ import re +from typing import ClassVar from 
dlt.common.normalizers.naming import NamingConvention as BaseNamingConvention from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention +from dlt.common.typing import REPattern class NamingConvention(SnakeCaseNamingConvention): """Normalizes identifiers according to Weaviate documentation: https://weaviate.io/developers/weaviate/config-refs/schema#class""" + @property + def is_case_sensitive(self) -> bool: + return True + RESERVED_PROPERTIES = {"id": "__id", "_id": "___id", "_additional": "__additional"} - _RE_UNDERSCORES = re.compile("([^_])__+") + RE_UNDERSCORES: ClassVar[REPattern] = re.compile("([^_])__+") _STARTS_DIGIT = re.compile("^[0-9]") _STARTS_NON_LETTER = re.compile("^[0-9_]") _SPLIT_UNDERSCORE_NON_CAP = re.compile("(_[^A-Z])") @@ -51,11 +57,11 @@ def _lowercase_property(self, identifier: str) -> str: def _base_normalize(self, identifier: str) -> str: # all characters that are not letters digits or a few special chars are replaced with underscore normalized_ident = identifier.translate(self._TR_REDUCE_ALPHABET) - normalized_ident = self._RE_NON_ALPHANUMERIC.sub("_", normalized_ident) + normalized_ident = self.RE_NON_ALPHANUMERIC.sub("_", normalized_ident) # replace trailing _ with x stripped_ident = normalized_ident.rstrip("_") strip_count = len(normalized_ident) - len(stripped_ident) stripped_ident += "x" * strip_count # replace consecutive underscores with single one to prevent name clashes with PATH_SEPARATOR - return self._RE_UNDERSCORES.sub(r"\1_", stripped_ident) + return self.RE_UNDERSCORES.sub(r"\1_", stripped_ident) diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index 2d75ca0809..71f2f13e76 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -31,20 +31,23 @@ from dlt.common.time import ensure_pendulum_datetime from dlt.common.schema import Schema, TTableSchema, 
TSchemaTables, TTableSchemaColumns from dlt.common.schema.typing import TColumnSchema, TColumnType -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.schema.utils import ( + get_columns_names_with_prop, + loads_table, + normalize_table_identifiers, + version_table, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync -from dlt.common.data_types import TDataType from dlt.common.storages import FileStorage from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT - from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo -from dlt.destinations.impl.weaviate import capabilities from dlt.destinations.impl.weaviate.configuration import WeaviateClientConfiguration -from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict, WeaviateBatchError +from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict, WeaviateGrpcError from dlt.destinations.type_mapping import TypeMapper +from dlt.destinations.utils import get_pipeline_state_query_columns NON_VECTORIZED_CLASS = { @@ -104,7 +107,7 @@ def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any: if "conflict for property" in str(status_ex) or "none vectorizer module" in str( status_ex ): - raise PropertyNameConflict() + raise PropertyNameConflict(str(status_ex)) raise DestinationTerminalException(status_ex) # looks like there are no more terminal exception raise DestinationTransientException(status_ex) @@ -115,23 +118,25 @@ def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any: return _wrap # type: ignore -def wrap_batch_error(f: TFun) -> TFun: +def wrap_grpc_error(f: TFun) -> TFun: @wraps(f) def _wrap(*args: Any, **kwargs: Any) -> Any: try: return f(*args, **kwargs) # those look like terminal exceptions - except WeaviateBatchError 
as batch_ex: + except WeaviateGrpcError as batch_ex: errors = batch_ex.args[0] message = errors["error"][0]["message"] # TODO: actually put the job in failed/retry state and prepare exception message with full info on failing item if "invalid" in message and "property" in message and "on class" in message: raise DestinationTerminalException( - f"Batch failed {errors} AND WILL **NOT** BE RETRIED" + f"Grpc (batch, query) failed {errors} AND WILL **NOT** BE RETRIED" ) if "conflict for property" in message: - raise PropertyNameConflict() - raise DestinationTransientException(f"Batch failed {errors} AND WILL BE RETRIED") + raise PropertyNameConflict(message) + raise DestinationTransientException( + f"Grpc (batch, query) failed {errors} AND WILL BE RETRIED" + ) except Exception: raise DestinationTransientException("Batch failed AND WILL BE RETRIED") @@ -174,14 +179,14 @@ def load_batch(self, f: IO[str]) -> None: Weaviate batch supports retries so we do not need to do that. """ - @wrap_batch_error + @wrap_grpc_error def check_batch_result(results: List[StrAny]) -> None: """This kills batch on first error reported""" if results is not None: for result in results: if "result" in result and "errors" in result["result"]: if "error" in result["result"]["errors"]: - raise WeaviateBatchError(result["result"]["errors"]) + raise WeaviateGrpcError(result["result"]["errors"]) with self.db_client.batch( batch_size=self.client_config.batch_size, @@ -233,20 +238,25 @@ def exception(self) -> str: class WeaviateClient(JobClientBase, WithStateSync): """Weaviate client implementation.""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - state_properties: ClassVar[List[str]] = [ - "version", - "engine_version", - "pipeline_name", - "state", - "created_at", - "_dlt_load_id", - ] - - def __init__(self, schema: Schema, config: WeaviateClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: 
WeaviateClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) + # get definitions of the dlt tables, normalize column names and keep for later use + version_table_ = normalize_table_identifiers(version_table(), schema.naming) + self.version_collection_properties = list(version_table_["columns"].keys()) + loads_table_ = normalize_table_identifiers(loads_table(), schema.naming) + self.loads_collection_properties = list(loads_table_["columns"].keys()) + state_table_ = normalize_table_identifiers( + get_pipeline_state_query_columns(), schema.naming + ) + self.pipeline_state_properties = list(state_table_["columns"].keys()) + self.config: WeaviateClientConfiguration = config - self.db_client = self.create_db_client(config) + self.db_client: weaviate.Client = None self._vectorizer_config = { "vectorizer": config.vectorizer, @@ -451,15 +461,23 @@ def update_stored_schema( return applied_update def _execute_schema_update(self, only_tables: Iterable[str]) -> None: - for table_name in only_tables or self.schema.tables: + for table_name in only_tables or self.schema.tables.keys(): exists, existing_columns = self.get_storage_table(table_name) # TODO: detect columns where vectorization was added or removed and modify it. 
currently we ignore change of hints - new_columns = self.schema.get_new_table_columns(table_name, existing_columns) + new_columns = self.schema.get_new_table_columns( + table_name, + existing_columns, + case_sensitive=self.capabilities.has_case_sensitive_identifiers + and self.capabilities.casefold_identifier is str, + ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: + is_collection_vectorized = self._is_collection_vectorized(table_name) for column in new_columns: - prop = self._make_property_schema(column["name"], column) + prop = self._make_property_schema( + column["name"], column, is_collection_vectorized + ) self.create_class_property(table_name, prop) else: class_schema = self.make_weaviate_class_schema(table_name) @@ -487,6 +505,11 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Loads compressed state from destination storage""" + # normalize properties + p_load_id = self.schema.naming.normalize_identifier("load_id") + p_dlt_load_id = self.schema.naming.normalize_identifier("_dlt_load_id") + p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") + p_status = self.schema.naming.normalize_identifier("status") # we need to find a stored state that matches a load id that was completed # we retrieve the state in blocks of 10 for this @@ -496,44 +519,45 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: state_records = self.get_records( self.schema.state_table_name, # search by package load id which is guaranteed to increase over time - sort={"path": ["_dlt_load_id"], "order": "desc"}, + sort={"path": [p_dlt_load_id], "order": "desc"}, where={ - "path": ["pipeline_name"], + "path": [p_pipeline_name], "operator": "Equal", "valueString": pipeline_name, }, limit=stepsize, offset=offset, - properties=self.state_properties, + 
properties=self.pipeline_state_properties, ) offset += stepsize if len(state_records) == 0: return None for state in state_records: - load_id = state["_dlt_load_id"] + load_id = state[p_dlt_load_id] load_records = self.get_records( self.schema.loads_table_name, where={ - "path": ["load_id"], + "path": [p_load_id], "operator": "Equal", "valueString": load_id, }, limit=1, - properties=["load_id", "status"], + properties=[p_load_id, p_status], ) # if there is a load for this state which was successful, return the state if len(load_records): - state["dlt_load_id"] = state.pop("_dlt_load_id") return StateInfo(**state) def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" + p_schema_name = self.schema.naming.normalize_identifier("schema_name") + p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") try: record = self.get_records( self.schema.version_table_name, - sort={"path": ["inserted_at"], "order": "desc"}, + sort={"path": [p_inserted_at], "order": "desc"}, where={ - "path": ["schema_name"], + "path": [p_schema_name], "operator": "Equal", "valueString": self.schema.name, }, @@ -544,11 +568,12 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: return None def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + p_version_hash = self.schema.naming.normalize_identifier("version_hash") try: record = self.get_records( self.schema.version_table_name, where={ - "path": ["version_hash"], + "path": [p_version_hash], "operator": "Equal", "valueString": schema_hash, }, @@ -585,8 +610,13 @@ def get_records( query = query.with_offset(offset) response = query.do() + # if json rpc is used, weaviate does not raise exceptions + if "errors" in response: + raise WeaviateGrpcError(response["errors"]) full_class_name = self.make_qualified_class_name(table_name) records = response["data"]["Get"][full_class_name] + if records is None: + raise 
DestinationTransientException(f"Could not obtain records for {full_class_name}") return cast(List[Dict[str, Any]], records) def make_weaviate_class_schema(self, table_name: str) -> Dict[str, Any]: @@ -597,31 +627,39 @@ def make_weaviate_class_schema(self, table_name: str) -> Dict[str, Any]: } # check if any column requires vectorization - if get_columns_names_with_prop(self.schema.get_table(table_name), VECTORIZE_HINT): + if self._is_collection_vectorized(table_name): class_schema.update(self._vectorizer_config) else: class_schema.update(NON_VECTORIZED_CLASS) return class_schema + def _is_collection_vectorized(self, table_name: str) -> bool: + """Tells is any of the columns has vectorize hint set""" + return ( + len(get_columns_names_with_prop(self.schema.get_table(table_name), VECTORIZE_HINT)) > 0 + ) + def _make_properties(self, table_name: str) -> List[Dict[str, Any]]: """Creates a Weaviate properties schema from a table schema. Args: table: The table name for which columns should be converted to properties """ - + is_collection_vectorized = self._is_collection_vectorized(table_name) return [ - self._make_property_schema(column_name, column) + self._make_property_schema(column_name, column, is_collection_vectorized) for column_name, column in self.schema.get_table_columns(table_name).items() ] - def _make_property_schema(self, column_name: str, column: TColumnSchema) -> Dict[str, Any]: + def _make_property_schema( + self, column_name: str, column: TColumnSchema, is_collection_vectorized: bool + ) -> Dict[str, Any]: extra_kv = {} vectorizer_name = self._vectorizer_config["vectorizer"] # x-weaviate-vectorize: (bool) means that this field should be vectorized - if not column.get(VECTORIZE_HINT, False): + if is_collection_vectorized and not column.get(VECTORIZE_HINT, False): # tell weaviate explicitly to not vectorize when column has no vectorize hint extra_kv["moduleConfig"] = { vectorizer_name: { @@ -655,15 +693,20 @@ def restore_file_load(self, file_path: str) -> 
LoadJob: @wrap_weaviate_error def complete_load(self, load_id: str) -> None: - properties = { - "load_id": load_id, - "schema_name": self.schema.name, - "status": 0, - "inserted_at": pendulum.now().isoformat(), - } + # corresponds to order of the columns in loads_table() + values = [ + load_id, + self.schema.name, + 0, + pendulum.now().isoformat(), + self.schema.version_hash, + ] + assert len(values) == len(self.loads_collection_properties) + properties = {k: v for k, v in zip(self.loads_collection_properties, values)} self.create_object(properties, self.schema.loads_table_name) def __enter__(self) -> "WeaviateClient": + self.db_client = self.create_db_client(self.config) return self def __exit__( @@ -672,18 +715,22 @@ def __exit__( exc_val: BaseException, exc_tb: TracebackType, ) -> None: - pass + if self.db_client: + self.db_client = None def _update_schema_in_storage(self, schema: Schema) -> None: schema_str = json.dumps(schema.to_dict()) - properties = { - "version_hash": schema.stored_version_hash, - "schema_name": schema.name, - "version": schema.version, - "engine_version": schema.ENGINE_VERSION, - "inserted_at": pendulum.now().isoformat(), - "schema": schema_str, - } + # corresponds to order of the columns in version_table() + values = [ + schema.version, + schema.ENGINE_VERSION, + str(pendulum.now().isoformat()), + schema.name, + schema.stored_version_hash, + schema_str, + ] + assert len(values) == len(self.version_collection_properties) + properties = {k: v for k, v in zip(self.version_collection_properties, values)} self.create_object(properties, self.schema.version_table_name) def _from_db_type( diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 74e14f0221..652d13f556 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -36,6 +36,10 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # the procedure below will split the inserts into 
max_query_length // 2 packs with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f: header = f.readline() + # format and casefold header + header = self._sql_client.capabilities.casefold_identifier(header).format( + qualified_table_name + ) writer_type = self._sql_client.capabilities.insert_values_writer_type if writer_type == "default": sep = "," @@ -70,7 +74,7 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # Chunk by max_rows - 1 for simplicity because one more row may be added for chunk in chunks(values_rows, max_rows - 1): processed += len(chunk) - insert_sql.append(header.format(qualified_table_name)) + insert_sql.append(header) if writer_type == "default": insert_sql.append(values_mark) if processed == len_rows: @@ -82,11 +86,9 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st else: # otherwise write all content in a single INSERT INTO if writer_type == "default": - insert_sql.extend( - [header.format(qualified_table_name), values_mark, content + until_nl] - ) + insert_sql.extend([header, values_mark, content + until_nl]) elif writer_type == "select_union": - insert_sql.extend([header.format(qualified_table_name), content + until_nl]) + insert_sql.extend([header, content + until_nl]) # actually this may be empty if we were able to read a full file into content if not is_eof: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index ac3636db2b..0a627bbdfb 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -1,44 +1,39 @@ import os from abc import abstractmethod import base64 -import binascii import contextlib from copy import copy -import datetime # noqa: 251 from types import TracebackType from typing import ( Any, - ClassVar, List, - NamedTuple, Optional, Sequence, Tuple, Type, Iterable, Iterator, - ContextManager, - cast, ) import zlib import re -from dlt.common import logger +from dlt.common 
import pendulum, logger from dlt.common.json import json -from dlt.common.pendulum import pendulum -from dlt.common.data_types import TDataType from dlt.common.schema.typing import ( COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, - TWriteDisposition, TTableFormat, ) +from dlt.common.schema.utils import ( + loads_table, + normalize_table_identifiers, + version_table, +) from dlt.common.storages import FileStorage from dlt.common.storages.load_package import LoadJobInfo from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables -from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME from dlt.common.destination.reference import ( StateInfo, StorageSchemaInfo, @@ -59,6 +54,11 @@ from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob from dlt.destinations.typing import TNativeConn from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.utils import ( + get_pipeline_state_query_columns, + info_schema_null_to_bool, + verify_sql_job_client_schema, +) # this should suffice for now DDL_COMMANDS = ["ALTER", "CREATE", "DROP"] @@ -78,7 +78,7 @@ def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: sql_client.execute_many(self._split_fragments(sql)) # if we detect ddl transactions, only execute transaction if supported by client elif ( - not self._string_containts_ddl_queries(sql) + not self._string_contains_ddl_queries(sql) or sql_client.capabilities.supports_ddl_transactions ): # with sql_client.begin_transaction(): @@ -95,7 +95,7 @@ def exception(self) -> str: # this part of code should be never reached raise NotImplementedError() - def _string_containts_ddl_queries(self, sql: str) -> bool: + def _string_contains_ddl_queries(self, sql: str) -> bool: for cmd in DDL_COMMANDS: if re.search(cmd, sql, re.IGNORECASE): return True @@ -133,37 +133,28 @@ def state(self) -> TLoadJobState: class SqlJobClientBase(JobClientBase, WithStateSync): - 
_VERSION_TABLE_SCHEMA_COLUMNS: ClassVar[Tuple[str, ...]] = ( - "version_hash", - "schema_name", - "version", - "engine_version", - "inserted_at", - "schema", - ) - _STATE_TABLE_COLUMNS: ClassVar[Tuple[str, ...]] = ( - "version", - "engine_version", - "pipeline_name", - "state", - "created_at", - "_dlt_load_id", - ) - def __init__( self, schema: Schema, config: DestinationClientConfiguration, sql_client: SqlClientBase[TNativeConn], ) -> None: + # get definitions of the dlt tables, normalize column names and keep for later use + version_table_ = normalize_table_identifiers(version_table(), schema.naming) self.version_table_schema_columns = ", ".join( - sql_client.escape_column_name(col) for col in self._VERSION_TABLE_SCHEMA_COLUMNS + sql_client.escape_column_name(col) for col in version_table_["columns"] + ) + loads_table_ = normalize_table_identifiers(loads_table(), schema.naming) + self.loads_table_schema_columns = ", ".join( + sql_client.escape_column_name(col) for col in loads_table_["columns"] + ) + state_table_ = normalize_table_identifiers( + get_pipeline_state_query_columns(), schema.naming ) self.state_table_columns = ", ".join( - sql_client.escape_column_name(col) for col in self._STATE_TABLE_COLUMNS + sql_client.escape_column_name(col) for col in state_table_["columns"] ) - - super().__init__(schema, config) + super().__init__(schema, config, sql_client.capabilities) self.sql_client = sql_client assert isinstance(config, DestinationClientDwhConfiguration) self.config: DestinationClientDwhConfiguration = config @@ -250,10 +241,12 @@ def _create_replace_followup_jobs( def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], - table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: """Creates a list of followup jobs for merge write disposition and staging replace strategies""" - jobs = 
super().create_table_chain_completed_followup_jobs(table_chain, table_chain_jobs) + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) write_disposition = table_chain[0]["write_disposition"] if write_disposition == "append": jobs.extend(self._create_append_followup_jobs(table_chain)) @@ -290,8 +283,7 @@ def complete_load(self, load_id: str) -> None: name = self.sql_client.make_qualified_table_name(self.schema.loads_table_name) now_ts = pendulum.now() self.sql_client.execute_sql( - f"INSERT INTO {name}(load_id, schema_name, status, inserted_at, schema_version_hash)" - " VALUES(%s, %s, %s, %s, %s);", + f"INSERT INTO {name}({self.loads_table_schema_columns}) VALUES(%s, %s, %s, %s, %s);", load_id, self.schema.name, 0, @@ -308,54 +300,84 @@ def __exit__( ) -> None: self.sql_client.close_connection() - def _get_storage_table_query_columns(self) -> List[str]: - """Column names used when querying table from information schema. - Override for databases that use different namings. - """ - fields = ["column_name", "data_type", "is_nullable"] - if self.capabilities.schema_supports_numeric_precision: - fields += ["numeric_precision", "numeric_scale"] - return fields + def get_storage_tables( + self, table_names: Iterable[str] + ) -> Iterable[Tuple[str, TTableSchemaColumns]]: + """Uses INFORMATION_SCHEMA to retrieve table and column information for tables in `table_names` iterator. + Table names should be normalized according to naming convention and will be further converted to desired casing + in order to (in most cases) create case-insensitive name suitable for search in information schema. - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - def _null_to_bool(v: str) -> bool: - if v == "NO": - return False - elif v == "YES": - return True - raise ValueError(v) + The column names are returned as in information schema. 
To match those with columns in existing table, you'll need to use + `schema.get_new_table_columns` method and pass the correct casing. Most of the casing function are irreversible so it is not + possible to convert identifiers into INFORMATION SCHEMA back into case sensitive dlt schema. + """ + table_names = list(table_names) + if len(table_names) == 0: + # empty generator + return + # get schema search components + catalog_name, schema_name, folded_table_names = ( + self.sql_client._get_information_schema_components(*table_names) + ) + # create table name conversion lookup table + name_lookup = { + folded_name: name for folded_name, name in zip(folded_table_names, table_names) + } + # this should never happen: we verify schema for name collisions before loading + assert len(name_lookup) == len(table_names), ( + f"One or more of tables in {table_names} after applying" + f" {self.capabilities.casefold_identifier} produced a name collision." + ) - fields = self._get_storage_table_query_columns() - db_params = self.sql_client.make_qualified_table_name(table_name, escape=False).split( - ".", 3 + # rows = self.sql_client.execute_sql(query, *db_params) + query, db_params = self._get_info_schema_columns_query( + catalog_name, schema_name, folded_table_names ) - query = f""" -SELECT {",".join(fields)} - FROM INFORMATION_SCHEMA.COLUMNS -WHERE """ - if len(db_params) == 3: - query += "table_catalog = %s AND " - query += "table_schema = %s AND table_name = %s ORDER BY ordinal_position;" rows = self.sql_client.execute_sql(query, *db_params) - - # if no rows we assume that table does not exist - schema_table: TTableSchemaColumns = {} - if len(rows) == 0: - # TODO: additionally check if table exists - return False, schema_table - # TODO: pull more data to infer indexes, PK and uniques attributes/constraints + prev_table: str = None + storage_columns: TTableSchemaColumns = None for c in rows: + # make sure that new table is known + assert ( + c[0] in name_lookup + ), f"Table name 
{c[0]} not in expected tables {name_lookup.keys()}" + table_name = name_lookup[c[0]] + if prev_table != table_name: + # yield what we have + if storage_columns: + yield (prev_table, storage_columns) + # we have new table + storage_columns = {} + prev_table = table_name + # remove from table_names + table_names.remove(prev_table) + # add columns + col_name = c[1] numeric_precision = ( - c[3] if self.capabilities.schema_supports_numeric_precision else None + c[4] if self.capabilities.schema_supports_numeric_precision else None ) - numeric_scale = c[4] if self.capabilities.schema_supports_numeric_precision else None + numeric_scale = c[5] if self.capabilities.schema_supports_numeric_precision else None + schema_c: TColumnSchemaBase = { - "name": c[0], - "nullable": _null_to_bool(c[2]), - **self._from_db_type(c[1], numeric_precision, numeric_scale), + "name": col_name, + "nullable": info_schema_null_to_bool(c[3]), + **self._from_db_type(c[2], numeric_precision, numeric_scale), } - schema_table[c[0]] = schema_c # type: ignore - return True, schema_table + storage_columns[col_name] = schema_c # type: ignore + # yield last table, it must have at least one column or we had no rows + if storage_columns: + yield (prev_table, storage_columns) + # if no columns we assume that table does not exist + for table_name in table_names: + yield (table_name, {}) + + def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: + """Uses get_storage_tables to get single `table_name` schema. + + Returns (True, ...) 
if table exists and (False, {}) when not + """ + storage_table = list(self.get_storage_tables([table_name]))[0] + return len(storage_table[1]) > 0, storage_table[1] @abstractmethod def _from_db_type( @@ -365,31 +387,90 @@ def _from_db_type( def get_stored_schema(self) -> StorageSchemaInfo: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) + c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at") query = ( - f"SELECT {self.version_table_schema_columns} FROM {name} WHERE schema_name = %s ORDER" - " BY inserted_at DESC;" + f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s" + f" ORDER BY {c_inserted_at} DESC;" ) return self._row_to_schema_info(query, self.schema.name) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name) loads_table = self.sql_client.make_qualified_table_name(self.schema.loads_table_name) + c_load_id, c_dlt_load_id, c_pipeline_name, c_status = self._norm_and_escape_columns( + "load_id", "_dlt_load_id", "pipeline_name", "status" + ) query = ( f"SELECT {self.state_table_columns} FROM {state_table} AS s JOIN {loads_table} AS l ON" - " l.load_id = s._dlt_load_id WHERE pipeline_name = %s AND l.status = 0 ORDER BY" - " l.load_id DESC" + f" l.{c_load_id} = s.{c_dlt_load_id} WHERE {c_pipeline_name} = %s AND l.{c_status} = 0" + f" ORDER BY {c_load_id} DESC" ) with self.sql_client.execute_query(query, pipeline_name) as cur: row = cur.fetchone() if not row: return None - return StateInfo(row[0], row[1], row[2], row[3], pendulum.instance(row[4])) + # NOTE: we request order of columns in SELECT statement which corresponds to StateInfo + return StateInfo( + version=row[0], + engine_version=row[1], + pipeline_name=row[2], + state=row[3], + created_at=pendulum.instance(row[4]), + _dlt_load_id=row[5], + ) + + def _norm_and_escape_columns(self, *columns: str) -> Iterator[str]: + 
return map( + self.sql_client.escape_column_name, map(self.schema.naming.normalize_path, columns) + ) def get_stored_schema_by_hash(self, version_hash: str) -> StorageSchemaInfo: - name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) - query = f"SELECT {self.version_table_schema_columns} FROM {name} WHERE version_hash = %s;" + table_name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) + (c_version_hash,) = self._norm_and_escape_columns("version_hash") + query = ( + f"SELECT {self.version_table_schema_columns} FROM {table_name} WHERE" + f" {c_version_hash} = %s;" + ) return self._row_to_schema_info(query, version_hash) + def _get_info_schema_columns_query( + self, catalog_name: Optional[str], schema_name: str, folded_table_names: List[str] + ) -> Tuple[str, List[Any]]: + """Generates SQL to query INFORMATION_SCHEMA.COLUMNS for a set of tables in `folded_table_names`. Input identifiers must be already + in a form that can be passed to a query via db_params. `catalogue_name` is optional and when None, the part of query selecting it + is skipped. + + Returns: query and list of db_params tuple + """ + query = f""" +SELECT {",".join(self._get_storage_table_query_columns())} + FROM INFORMATION_SCHEMA.COLUMNS +WHERE """ + + db_params = [] + if catalog_name: + db_params.append(catalog_name) + query += "table_catalog = %s AND " + db_params.append(schema_name) + db_params = db_params + folded_table_names + # placeholder for each table + table_placeholders = ",".join(["%s"] * len(folded_table_names)) + query += ( + f"table_schema = %s AND table_name IN ({table_placeholders}) ORDER BY table_name," + " ordinal_position;" + ) + + return query, db_params + + def _get_storage_table_query_columns(self) -> List[str]: + """Column names used when querying table from information schema. + Override for databases that use different namings. 
+ """ + fields = ["table_name", "column_name", "data_type", "is_nullable"] + if self.capabilities.schema_supports_numeric_precision: + fields += ["numeric_precision", "numeric_scale"] + return fields + def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTables: sql_scripts, schema_update = self._build_schema_update_sql(only_tables) # Stay within max query size when doing DDL. @@ -416,12 +497,16 @@ def _build_schema_update_sql( """ sql_updates = [] schema_update: TSchemaTables = {} - for table_name in only_tables or self.schema.tables: - exists, storage_table = self.get_storage_table(table_name) - new_columns = self._create_table_update(table_name, storage_table) + for table_name, storage_columns in self.get_storage_tables( + only_tables or self.schema.tables.keys() + ): + # this will skip incomplete columns + new_columns = self._create_table_update(table_name, storage_columns) if len(new_columns) > 0: # build and add sql to execute - sql_statements = self._get_table_update_sql(table_name, new_columns, exists) + sql_statements = self._get_table_update_sql( + table_name, new_columns, len(storage_columns) > 0 + ) for sql in sql_statements: if not sql.endswith(";"): sql += ";" @@ -472,7 +557,7 @@ def _get_table_update_sql( for hint in COLUMN_HINTS: if any(c.get(hint, False) is True for c in new_columns): hint_columns = [ - self.capabilities.escape_identifier(c["name"]) + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get(hint, False) ] @@ -501,8 +586,13 @@ def _gen_not_null(v: bool) -> str: def _create_table_update( self, table_name: str, storage_columns: TTableSchemaColumns ) -> Sequence[TColumnSchema]: - # compare table with stored schema and produce delta - updates = self.schema.get_new_table_columns(table_name, storage_columns) + """Compares storage columns with schema table and produce delta columns difference""" + updates = self.schema.get_new_table_columns( + table_name, + storage_columns, + 
case_sensitive=self.capabilities.has_case_sensitive_identifiers + and self.capabilities.casefold_identifier is str, + ) logger.info(f"Found {len(updates)} updates for {table_name} in {self.schema.name}") return updates @@ -526,16 +616,17 @@ def _row_to_schema_info(self, query: str, *args: Any) -> StorageSchemaInfo: pass # make utc datetime - inserted_at = pendulum.instance(row[4]) + inserted_at = pendulum.instance(row[2]) - return StorageSchemaInfo(row[0], row[1], row[2], row[3], inserted_at, schema_str) + return StorageSchemaInfo(row[4], row[3], row[0], row[1], inserted_at, schema_str) def _delete_schema_in_storage(self, schema: Schema) -> None: """ Delete all stored versions with the same name as given schema """ name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) - self.sql_client.execute_sql(f"DELETE FROM {name} WHERE schema_name = %s;", schema.name) + (c_schema_name,) = self._norm_and_escape_columns("schema_name") + self.sql_client.execute_sql(f"DELETE FROM {name} WHERE {c_schema_name} = %s;", schema.name) def _update_schema_in_storage(self, schema: Schema) -> None: # get schema string or zip @@ -554,14 +645,21 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None: self.sql_client.execute_sql( f"INSERT INTO {name}({self.version_table_schema_columns}) VALUES (%s, %s, %s, %s, %s," " %s);", - schema.stored_version_hash, - schema.name, schema.version, schema.ENGINE_VERSION, now_ts, + schema.name, + schema.stored_version_hash, schema_str, ) + def _verify_schema(self) -> None: + super()._verify_schema() + if exceptions := verify_sql_job_client_schema(self.schema, warnings=True): + for exception in exceptions: + logger.error(str(exception)) + raise exceptions[0] + class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset): in_staging_mode: bool = False diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 9b73d7d28c..7912ac4561 100644 --- a/dlt/destinations/sql_client.py +++ 
b/dlt/destinations/sql_client.py @@ -30,13 +30,15 @@ class SqlClientBase(ABC, Generic[TNativeConn]): dbapi: ClassVar[DBApi] = None - capabilities: ClassVar[DestinationCapabilitiesContext] = None - def __init__(self, database_name: str, dataset_name: str) -> None: + def __init__( + self, database_name: str, dataset_name: str, capabilities: DestinationCapabilitiesContext + ) -> None: if not dataset_name: raise ValueError(dataset_name) self.dataset_name = dataset_name self.database_name = database_name + self.capabilities = capabilities @abstractmethod def open_connection(self) -> TNativeConn: @@ -75,9 +77,12 @@ def has_dataset(self) -> bool: SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE """ - db_params = self.fully_qualified_dataset_name(escape=False).split(".", 2) - if len(db_params) == 2: + catalog_name, schema_name, _ = self._get_information_schema_components() + db_params: List[str] = [] + if catalog_name is not None: query += " catalog_name = %s AND " + db_params.append(catalog_name) + db_params.append(schema_name) query += "schema_name = %s" rows = self.execute_sql(query, *db_params) return len(rows) > 0 @@ -137,16 +142,39 @@ def execute_many( ret.append(result) return ret - @abstractmethod + def catalog_name(self, escape: bool = True) -> Optional[str]: + # default is no catalogue component of the name, which typically means that + # connection is scoped to a current database + return None + def fully_qualified_dataset_name(self, escape: bool = True) -> str: - pass + return ".".join(self.make_qualified_table_name_path(None, escape=escape)) def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str: + return ".".join(self.make_qualified_table_name_path(table_name, escape=escape)) + + def make_qualified_table_name_path( + self, table_name: Optional[str], escape: bool = True + ) -> List[str]: + """Returns a list with path components leading from catalog to table_name. + Used to construct fully qualified names. `table_name` is optional. 
+ """ + path: List[str] = [] + if catalog_name := self.catalog_name(escape=escape): + path.append(catalog_name) + dataset_name = self.capabilities.casefold_identifier(self.dataset_name) if escape: - table_name = self.capabilities.escape_identifier(table_name) - return f"{self.fully_qualified_dataset_name(escape=escape)}.{table_name}" + dataset_name = self.capabilities.escape_identifier(dataset_name) + path.append(dataset_name) + if table_name: + table_name = self.capabilities.casefold_identifier(table_name) + if escape: + table_name = self.capabilities.escape_identifier(table_name) + path.append(table_name) + return path def escape_column_name(self, column_name: str, escape: bool = True) -> str: + column_name = self.capabilities.casefold_identifier(column_name) if escape: return self.capabilities.escape_identifier(column_name) return column_name @@ -191,6 +219,18 @@ def is_dbapi_exception(ex: Exception) -> bool: def make_staging_dataset_name(dataset_name: str) -> str: return dataset_name + "_staging" + def _get_information_schema_components(self, *tables: str) -> Tuple[str, str, List[str]]: + """Gets catalog name, schema name and name of the tables in format that can be directly + used to query INFORMATION_SCHEMA. catalog name is optional: in that case None is + returned in the first element of the tuple. + """ + schema_path = self.make_qualified_table_name_path(None, escape=False) + return ( + self.catalog_name(escape=False), + schema_path[-1], + [self.make_qualified_table_name_path(table, escape=False)[-1] for table in tables], + ) + # # generate sql statements # @@ -220,6 +260,11 @@ def _get_columns(self) -> List[str]: return [c[0] for c in self.native_cursor.description] def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: + """Fetches results as data frame in full or in specified chunks. + + May use native pandas/arrow reader if available. Depending on + the native implementation chunk size may vary. 
+ """ from dlt.common.libs.pandas_sql import _wrap_result columns = self._get_columns() diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 4f8e29ae0d..b9539fe114 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -117,7 +117,7 @@ def _generate_insert_sql( table_name = sql_client.make_qualified_table_name(table["name"]) columns = ", ".join( map( - sql_client.capabilities.escape_identifier, + sql_client.escape_column_name, get_columns_names_with_prop(table, "name"), ) ) @@ -361,10 +361,8 @@ def gen_merge_sql( sql: List[str] = [] root_table = table_chain[0] - escape_id = sql_client.capabilities.escape_identifier + escape_column_id = sql_client.escape_column_name escape_lit = sql_client.capabilities.escape_literal - if escape_id is None: - escape_id = DestinationCapabilitiesContext.generic_capabilities().escape_identifier if escape_lit is None: escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal @@ -376,13 +374,13 @@ def gen_merge_sql( # get merge and primary keys from top level primary_keys = list( map( - escape_id, + escape_column_id, get_columns_names_with_prop(root_table, "primary_key"), ) ) merge_keys = list( map( - escape_id, + escape_column_id, get_columns_names_with_prop(root_table, "merge_key"), ) ) @@ -419,7 +417,7 @@ def gen_merge_sql( f" {root_table['name']} so it is not possible to link child tables to it.", ) # get first unique column - unique_column = escape_id(unique_columns[0]) + unique_column = escape_column_id(unique_columns[0]) # create temp table with unique identifier create_delete_temp_table_sql, delete_temp_table_name = ( cls.gen_delete_temp_table_sql( @@ -442,14 +440,14 @@ def gen_merge_sql( f" {table['name']} so it is not possible to refer to top level table" f" {root_table['name']} unique column {unique_column}", ) - root_key_column = escape_id(root_key_columns[0]) + root_key_column = escape_column_id(root_key_columns[0]) sql.append( 
cls.gen_delete_from_sql( table_name, root_key_column, delete_temp_table_name, unique_column ) ) - # delete from top table now that child tables have been prcessed + # delete from top table now that child tables have been processed sql.append( cls.gen_delete_from_sql( root_table_name, unique_column, delete_temp_table_name, unique_column @@ -461,10 +459,10 @@ def gen_merge_sql( hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete") if hard_delete_col is not None: # any value indicates a delete for non-boolean columns - not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL" + not_deleted_cond = f"{escape_column_id(hard_delete_col)} IS NULL" if root_table["columns"][hard_delete_col]["data_type"] == "bool": # only True values indicate a delete for boolean columns - not_deleted_cond += f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}" + not_deleted_cond += f" OR {escape_column_id(hard_delete_col)} = {escape_lit(False)}" # get dedup sort information dedup_sort = get_dedup_sort_tuple(root_table) @@ -503,7 +501,7 @@ def gen_merge_sql( uniq_column = unique_column if table.get("parent") is None else root_key_column insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})" - columns = list(map(escape_id, get_columns_names_with_prop(table, "name"))) + columns = list(map(escape_column_id, get_columns_names_with_prop(table, "name"))) col_str = ", ".join(columns) select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}" if len(primary_keys) > 0 and len(table_chain) == 1: @@ -534,9 +532,11 @@ def gen_scd2_sql( # get column names caps = sql_client.capabilities - escape_id = caps.escape_identifier - from_, to = list(map(escape_id, get_validity_column_names(root_table))) # validity columns - hash_ = escape_id( + escape_column_id = sql_client.escape_column_name + from_, to = list( + map(escape_column_id, get_validity_column_names(root_table)) + ) # validity columns + hash_ = escape_column_id( 
get_first_column_name_with_prop(root_table, "x-row-version") ) # row hash column @@ -568,7 +568,7 @@ def gen_scd2_sql( """) # insert new active records in root table - columns = map(escape_id, list(root_table["columns"].keys())) + columns = map(escape_column_id, list(root_table["columns"].keys())) col_str = ", ".join([c for c in columns if c not in (from_, to)]) sql.append(f""" INSERT INTO {root_table_name} ({col_str}, {from_}, {to}) @@ -592,7 +592,7 @@ def gen_scd2_sql( " it is not possible to link child tables to it.", ) # get first unique column - unique_column = escape_id(unique_columns[0]) + unique_column = escape_column_id(unique_columns[0]) # TODO: - based on deterministic child hashes (OK) # - if row hash changes all is right # - if it does not we only capture new records, while we should replace existing with those in stage diff --git a/dlt/destinations/utils.py b/dlt/destinations/utils.py index c02460fe58..d24ad7c5a7 100644 --- a/dlt/destinations/utils.py +++ b/dlt/destinations/utils.py @@ -1,9 +1,23 @@ import re +from typing import Any, List, Optional, Tuple + +from dlt.common import logger +from dlt.common.schema import Schema +from dlt.common.schema.exceptions import SchemaCorruptedException +from dlt.common.schema.typing import MERGE_STRATEGIES, TTableSchema +from dlt.common.schema.utils import ( + get_columns_names_with_prop, + get_first_column_name_with_prop, + has_column_with_prop, + pipeline_state_table, +) from typing import Any, cast, Tuple, Dict, Type from dlt.destinations.exceptions import DatabaseTransientException from dlt.extract import DltResource, resource as make_resource +RE_DATA_TYPE = re.compile(r"([A-Z]+)\((\d+)(?:,\s?(\d+))?\)") + def ensure_resource(data: Any) -> DltResource: """Wraps `data` in a DltResource if it's not a DltResource already.""" @@ -13,6 +27,119 @@ def ensure_resource(data: Any) -> DltResource: return cast(DltResource, make_resource(data, name=resource_name)) +def info_schema_null_to_bool(v: str) -> bool: + 
"""Converts INFORMATION SCHEMA truth values to Python bool""" + if v in ("NO", "0"): + return False + elif v in ("YES", "1"): + return True + raise ValueError(v) + + +def parse_db_data_type_str_with_precision(db_type: str) -> Tuple[str, Optional[int], Optional[int]]: + """Parses a db data type with optional precision or precision and scale information""" + # Search for matches using the regular expression + match = RE_DATA_TYPE.match(db_type) + + # If the pattern matches, extract the type, precision, and scale + if match: + db_type = match.group(1) + precision = int(match.group(2)) + scale = int(match.group(3)) if match.group(3) else None + return db_type, precision, scale + + # If the pattern does not match, return the original type without precision and scale + return db_type, None, None + + +def get_pipeline_state_query_columns() -> TTableSchema: + """We get definition of pipeline state table without columns we do not need for the query""" + state_table = pipeline_state_table() + # we do not need version_hash to be backward compatible as long as we can + state_table["columns"].pop("version_hash") + return state_table + + +def verify_sql_job_client_schema(schema: Schema, warnings: bool = True) -> List[Exception]: + log = logger.warning if warnings else logger.info + # collect all exceptions to show all problems in the schema + exception_log: List[Exception] = [] + + # verifies schema settings specific to sql job client + for table in schema.data_tables(): + table_name = table["name"] + if table.get("write_disposition") == "merge": + if "x-merge-strategy" in table and table["x-merge-strategy"] not in MERGE_STRATEGIES: # type: ignore[typeddict-item] + exception_log.append( + SchemaCorruptedException( + schema.name, + f'"{table["x-merge-strategy"]}" is not a valid merge strategy. 
' # type: ignore[typeddict-item] + f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""", + ) + ) + if ( + table.get("x-merge-strategy") == "delete-insert" + and not has_column_with_prop(table, "primary_key") + and not has_column_with_prop(table, "merge_key") + ): + log( + f"Table {table_name} has `write_disposition` set to `merge`" + " and `merge_strategy` set to `delete-insert`, but no primary or" + " merge keys defined." + " dlt will fall back to `append` for this table." + ) + if has_column_with_prop(table, "hard_delete"): + if len(get_columns_names_with_prop(table, "hard_delete")) > 1: + exception_log.append( + SchemaCorruptedException( + schema.name, + f'Found multiple "hard_delete" column hints for table "{table_name}" in' + f' schema "{schema.name}" while only one is allowed:' + f' {", ".join(get_columns_names_with_prop(table, "hard_delete"))}.', + ) + ) + if table.get("write_disposition") in ("replace", "append"): + log( + f"""The "hard_delete" column hint for column "{get_first_column_name_with_prop(table, 'hard_delete')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{schema.name}" will be ignored.' + ' The "hard_delete" column hint is only applied when using' + ' the "merge" write disposition.' 
+ ) + if has_column_with_prop(table, "dedup_sort"): + if len(get_columns_names_with_prop(table, "dedup_sort")) > 1: + exception_log.append( + SchemaCorruptedException( + schema.name, + f'Found multiple "dedup_sort" column hints for table "{table_name}" in' + f' schema "{schema.name}" while only one is allowed:' + f' {", ".join(get_columns_names_with_prop(table, "dedup_sort"))}.', + ) + ) + if table.get("write_disposition") in ("replace", "append"): + log( + f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{schema.name}" will be ignored.' + ' The "dedup_sort" column hint is only applied when using' + ' the "merge" write disposition.' + ) + if table.get("write_disposition") == "merge" and not has_column_with_prop( + table, "primary_key" + ): + log( + f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{schema.name}" will be ignored.' + ' The "dedup_sort" column hint is only applied when a' + " primary key has been specified." 
+ ) + return exception_log + + def _convert_to_old_pyformat( new_style_string: str, args: Tuple[Any, ...], operational_error_cls: Type[Exception] ) -> Tuple[str, Dict[str, Any]]: diff --git a/dlt/extract/__init__.py b/dlt/extract/__init__.py index 03b2e59539..4029241634 100644 --- a/dlt/extract/__init__.py +++ b/dlt/extract/__init__.py @@ -4,13 +4,14 @@ from dlt.extract.decorators import source, resource, transformer, defer from dlt.extract.incremental import Incremental from dlt.extract.wrappers import wrap_additional_type -from dlt.extract.extractors import materialize_schema_item +from dlt.extract.extractors import materialize_schema_item, with_file_import __all__ = [ "DltResource", "DltSource", "with_table_name", "with_hints", + "with_file_import", "make_hints", "source", "resource", diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 2bb4a3ce87..ad10ef3ad3 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -35,22 +35,20 @@ from dlt.common.schema.schema import Schema from dlt.common.schema.typing import ( TColumnNames, + TFileFormat, TWriteDisposition, TWriteDispositionConfig, TAnySchemaColumns, TSchemaContract, TTableFormat, ) -from dlt.extract.hints import make_hints -from dlt.extract.utils import ( - simulate_func_call, - wrap_compat_transformer, - wrap_resource_gen, -) from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.typing import AnyFun, ParamSpec, Concatenate, TDataItem, TDataItems from dlt.common.utils import get_callable_name, get_module_name, is_inner_callable + +from dlt.extract.hints import make_hints +from dlt.extract.utils import simulate_func_call from dlt.extract.exceptions import ( CurrentSourceNotAvailable, DynamicNameNotStandaloneResource, @@ -64,8 +62,6 @@ SourceNotAFunction, CurrentSourceSchemaNotAvailable, ) -from dlt.extract.incremental import IncrementalResourceWrapper - from dlt.extract.items import 
TTableHintTemplate from dlt.extract.source import DltSource from dlt.extract.resource import DltResource, TUnboundDltResource, TDltResourceImpl @@ -210,16 +206,16 @@ def decorator( source_sections = (known_sections.SOURCES, source_section, effective_name) conf_f = with_config(f, spec=spec, sections=source_sections) - def _eval_rv(_rv: Any) -> TDltSourceImpl: + def _eval_rv(_rv: Any, schema_copy: Schema) -> TDltSourceImpl: """Evaluates return value from the source function or coroutine""" if _rv is None: - raise SourceDataIsNone(schema.name) + raise SourceDataIsNone(schema_copy.name) # if generator, consume it immediately if inspect.isgenerator(_rv): _rv = list(_rv) # convert to source - s = _impl_cls.from_data(schema.clone(update_normalizers=True), source_section, _rv) + s = _impl_cls.from_data(schema_copy, source_section, _rv) # apply hints if max_table_nesting is not None: s.max_table_nesting = max_table_nesting @@ -231,7 +227,10 @@ def _eval_rv(_rv: Any) -> TDltSourceImpl: @wraps(conf_f) def _wrap(*args: Any, **kwargs: Any) -> TDltSourceImpl: """Wrap a regular function, injection context must be a part of the wrap""" - with Container().injectable_context(SourceSchemaInjectableContext(schema)): + # clone the schema passed to decorator, update normalizers, remove processing hints + # NOTE: source may be called several times in many different settings + schema_copy = schema.clone(update_normalizers=True, remove_processing_hints=True) + with Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): # configurations will be accessed in this section in the source proxy = Container()[PipelineContext] pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name @@ -239,18 +238,21 @@ def _wrap(*args: Any, **kwargs: Any) -> TDltSourceImpl: ConfigSectionContext( pipeline_name=pipeline_name, sections=source_sections, - source_state_key=schema.name, + source_state_key=schema_copy.name, ) ): rv = conf_f(*args, **kwargs) - return 
_eval_rv(rv) + return _eval_rv(rv, schema_copy) @wraps(conf_f) async def _wrap_coro(*args: Any, **kwargs: Any) -> TDltSourceImpl: """In case of co-routine we must wrap the whole injection context in awaitable, there's no easy way to avoid some code duplication """ - with Container().injectable_context(SourceSchemaInjectableContext(schema)): + # clone the schema passed to decorator, update normalizers, remove processing hints + # NOTE: source may be called several times in many different settings + schema_copy = schema.clone(update_normalizers=True, remove_processing_hints=True) + with Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): # configurations will be accessed in this section in the source proxy = Container()[PipelineContext] pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name @@ -258,11 +260,11 @@ async def _wrap_coro(*args: Any, **kwargs: Any) -> TDltSourceImpl: ConfigSectionContext( pipeline_name=pipeline_name, sections=source_sections, - source_state_key=schema.name, + source_state_key=schema_copy.name, ) ): rv = await conf_f(*args, **kwargs) - return _eval_rv(rv) + return _eval_rv(rv, schema_copy) # get spec for wrapped function SPEC = get_fun_spec(conf_f) @@ -296,6 +298,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -316,6 +319,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -336,6 +340,7 @@ def resource( merge_key: 
TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -359,6 +364,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -378,6 +384,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -413,9 +420,10 @@ def resource( If not present, the name of the decorated function will be used. table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. - max_table_nesting (int, optional): A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + max_table_nesting (int, optional): A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON. + write_disposition (TTableHintTemplate[TWriteDispositionConfig], optional): Controls how to write data to a table. Accepts a shorthand string literal or configuration dictionary. Allowed shorthand string literals: `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. 
"merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". Write behaviour can be further customized through a configuration dictionary. For example, to obtain an SCD2 table provide `write_disposition={"disposition": "merge", "strategy": "scd2"}`. @@ -433,7 +441,12 @@ def resource( This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. schema_contract (TSchemaContract, optional): Schema contract settings that will be applied to all resources of this source (if not overridden in the resource itself) - table_format (Literal["iceberg"], optional): Defines the storage format of the table. Currently only "iceberg" is supported on Athena, other destinations ignore this hint. + + table_format (Literal["iceberg", "delta"], optional): Defines the storage format of the table. Currently only "iceberg" is supported on Athena, and "delta" on the filesystem. + Other destinations ignore this hint. + + file_format (Literal["preferred", ...], optional): Format of the file in which resource data is stored. Useful when importing external files. Use `preferred` to force + a file format that is preferred by the destination used. This setting superseded the `load_file_format` passed to pipeline `run` method. selected (bool, optional): When `True` `dlt pipeline` will extract and load this resource, if `False`, the resource will be ignored. 
@@ -464,6 +477,7 @@ def make_resource(_name: str, _section: str, _data: Any) -> TDltResourceImpl: merge_key=merge_key, schema_contract=schema_contract, table_format=table_format, + file_format=file_format, ) resource = _impl_cls.from_data( @@ -574,10 +588,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: str = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -591,10 +609,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: TTableHintTemplate[str] = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -612,10 +634,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: str = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: 
TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -629,10 +655,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: TTableHintTemplate[str] = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -646,10 +676,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: TTableHintTemplate[str] = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -692,6 +726,8 @@ def transformer( table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. 
+ max_table_nesting (int, optional): A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON. + write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. @@ -704,6 +740,14 @@ def transformer( merge_key (str | Sequence[str]): A column name or a list of column names that define a merge key. Typically used with "merge" write disposition to remove overlapping data ranges ie. to keep a single record for a given day. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + schema_contract (TSchemaContract, optional): Schema contract settings that will be applied to all resources of this source (if not overridden in the resource itself) + + table_format (Literal["iceberg", "delta"], optional): Defines the storage format of the table. Currently only "iceberg" is supported on Athena, and "delta" on the filesystem. + Other destinations ignore this hint. + + file_format (Literal["preferred", ...], optional): Format of the file in which resource data is stored. Useful when importing external files. Use `preferred` to force + a file format that is preferred by the destination used. This setting superseded the `load_file_format` passed to pipeline `run` method. + selected (bool, optional): When `True` `dlt pipeline` will extract and load this resource, if `False`, the resource will be ignored. 
spec (Type[BaseConfiguration], optional): A specification of configuration and secret values required by the source. @@ -722,10 +766,14 @@ def transformer( f, name=name, table_name=table_name, + max_table_nesting=max_table_nesting, write_disposition=write_disposition, columns=columns, primary_key=primary_key, merge_key=merge_key, + schema_contract=schema_contract, + table_format=table_format, + file_format=file_format, selected=selected, spec=spec, standalone=standalone, @@ -741,8 +789,11 @@ def _maybe_load_schema_for_callable(f: AnyFun, name: str) -> Optional[Schema]: try: file = inspect.getsourcefile(f) if file: - return SchemaStorage.load_schema_file(os.path.dirname(file), name) - + schema = SchemaStorage.load_schema_file( + os.path.dirname(file), name, remove_processing_hints=True + ) + schema.update_normalizers() + return schema except SchemaNotFoundError: pass return None diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index f8966c3ced..5769be1a8d 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -170,6 +170,9 @@ def add_item(item: Any) -> bool: class Extract(WithStepInfo[ExtractMetrics, ExtractInfo]): + original_data: Any + """Original data from which the extracted DltSource was created. 
Will be used to describe in extract info""" + def __init__( self, schema_storage: SchemaStorage, @@ -181,6 +184,7 @@ def __init__( self.collector = collector self.schema_storage = schema_storage self.extract_storage = ExtractStorage(normalize_storage_config) + # TODO: this should be passed together with DltSource to extract() self.original_data: Any = original_data super().__init__() @@ -370,7 +374,9 @@ def extract( load_package_state_update: Optional[Dict[str, Any]] = None, ) -> str: # generate load package to be able to commit all the sources together later - load_id = self.extract_storage.create_load_package(source.discover_schema()) + load_id = self.extract_storage.create_load_package( + source.discover_schema(), reuse_exiting_package=True + ) with Container().injectable_context( SourceSchemaInjectableContext(source.schema) ), Container().injectable_context( @@ -405,14 +411,10 @@ def extract( commit_load_package_state() return load_id - def commit_packages(self, pipline_state_doc: TPipelineStateDoc = None) -> None: - """Commits all extracted packages to normalize storage, and adds the pipeline state to the load package""" + def commit_packages(self) -> None: + """Commits all extracted packages to normalize storage""" # commit load packages for load_id, metrics in self._load_id_metrics.items(): - if pipline_state_doc: - package_state = self.extract_storage.new_packages.get_load_package_state(load_id) - package_state["pipeline_state"] = {**pipline_state_doc, "dlt_load_id": load_id} - self.extract_storage.new_packages.save_load_package_state(load_id, package_state) self.extract_storage.commit_new_load_package( load_id, self.schema_storage[metrics[0]["schema_name"]] ) diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index 48f0d6968e..4a1de2517d 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -1,14 +1,14 @@ from copy import copy -from typing import Set, Dict, Any, Optional, List +from typing import Set, Dict, Any, 
Optional, List, Union from dlt.common.configuration import known_sections, resolve_configuration, with_config from dlt.common import logger from dlt.common.configuration.specs import BaseConfiguration, configspec +from dlt.common.data_writers import DataWriterMetrics from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.exceptions import MissingDependencyException - from dlt.common.runtime.collector import Collector, NULL_COLLECTOR -from dlt.common.typing import TDataItems, TDataItem +from dlt.common.typing import TDataItems, TDataItem, TLoaderFileFormat from dlt.common.schema import Schema, utils from dlt.common.schema.typing import ( TSchemaContractDict, @@ -17,9 +17,9 @@ TTableSchemaColumns, TPartialTableSchema, ) -from dlt.extract.hints import HintsMeta +from dlt.extract.hints import HintsMeta, TResourceHints from dlt.extract.resource import DltResource -from dlt.extract.items import TableNameMeta +from dlt.extract.items import DataItemWithMeta, TableNameMeta from dlt.extract.storage import ExtractorItemStorage from dlt.normalize.configuration import ItemsNormalizerConfiguration @@ -47,6 +47,50 @@ def materialize_schema_item() -> MaterializedEmptyList: return MaterializedEmptyList() +class ImportFileMeta(HintsMeta): + __slots__ = ("file_path", "metrics", "file_format") + + def __init__( + self, + file_path: str, + metrics: DataWriterMetrics, + file_format: TLoaderFileFormat = None, + hints: TResourceHints = None, + create_table_variant: bool = None, + ) -> None: + super().__init__(hints, create_table_variant) + self.file_path = file_path + self.metrics = metrics + self.file_format = file_format + + +def with_file_import( + file_path: str, + file_format: TLoaderFileFormat, + items_count: int = 0, + hints: Union[TResourceHints, TDataItem] = None, +) -> DataItemWithMeta: + """Marks file under `file_path` to be associated with current resource and imported into the load package as a file of + type `file_format`. 
+ + You can provide optional `hints` that will be applied to the current resource. Note that you should avoid schema inference at + runtime if possible and if that is not possible - to do that only once per extract process. Use `make_hints` in `mark` module + to create hints. You can also pass Arrow table or Pandas data frame form which schema will be taken (but content discarded). + Create `TResourceHints` with `make_hints`. + + If number of records in `file_path` is known, pass it in `items_count` so `dlt` can generate correct extract metrics. + + Note that `dlt` does not sniff schemas from data and will not guess right file format for you. + """ + metrics = DataWriterMetrics(file_path, items_count, 0, 0, 0) + item: TDataItem = None + # if hints are dict assume that this is dlt schema, if not - that it is arrow table + if not isinstance(hints, dict): + item = hints + hints = None + return DataItemWithMeta(ImportFileMeta(file_path, metrics, file_format, hints, False), item) + + class Extractor: @configspec class ExtractorConfiguration(BaseConfiguration): @@ -78,7 +122,7 @@ def __init__( def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> None: """Write `items` to `resource` optionally computing table schemas and revalidating/filtering data""" - if isinstance(meta, HintsMeta): + if isinstance(meta, HintsMeta) and meta.hints: # update the resource with new hints, remove all caches so schema is recomputed # and contracts re-applied resource.merge_hints(meta.hints, meta.create_table_variant) @@ -93,7 +137,7 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No self._write_to_static_table(resource, table_name, items, meta) else: # table has name or other hints depending on data items - self._write_to_dynamic_table(resource, items) + self._write_to_dynamic_table(resource, items, meta) def write_empty_items_file(self, table_name: str) -> None: table_name = self.naming.normalize_table_identifier(table_name) @@ -129,7 
+173,24 @@ def _write_item( if isinstance(items, MaterializedEmptyList): self.resources_with_empty.add(resource_name) - def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems) -> None: + def _import_item( + self, + table_name: str, + resource_name: str, + meta: ImportFileMeta, + ) -> None: + metrics = self.item_storage.import_items_file( + self.load_id, + self.schema.name, + table_name, + meta.file_path, + meta.metrics, + meta.file_format, + ) + self.collector.update(table_name, inc=metrics.items_count) + self.resources_with_items.add(resource_name) + + def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems, meta: Any) -> None: if not isinstance(items, list): items = [items] @@ -143,7 +204,10 @@ def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems) -> N ) # write to storage with inferred table name if table_name not in self._filtered_tables: - self._write_item(table_name, resource.name, item) + if isinstance(meta, ImportFileMeta): + self._import_item(table_name, resource.name, meta) + else: + self._write_item(table_name, resource.name, item) def _write_to_static_table( self, resource: DltResource, table_name: str, items: TDataItems, meta: Any @@ -151,11 +215,16 @@ def _write_to_static_table( if table_name not in self._table_contracts: items = self._compute_and_update_table(resource, table_name, items, meta) if table_name not in self._filtered_tables: - self._write_item(table_name, resource.name, items) + if isinstance(meta, ImportFileMeta): + self._import_item(table_name, resource.name, meta) + else: + self._write_item(table_name, resource.name, items) def _compute_table(self, resource: DltResource, items: TDataItems, meta: Any) -> TTableSchema: """Computes a schema for a new or dynamic table and normalizes identifiers""" - return self.schema.normalize_table_identifiers(resource.compute_table_schema(items, meta)) + return utils.normalize_table_identifiers( + resource.compute_table_schema(items, meta), 
self.schema.naming + ) def _compute_and_update_table( self, resource: DltResource, table_name: str, items: TDataItems, meta: Any @@ -173,11 +242,11 @@ def _compute_and_update_table( # this is a new table so allow evolve once if schema_contract["columns"] != "evolve" and self.schema.is_new_table(table_name): - computed_table["x-normalizer"] = {"evolve-columns-once": True} # type: ignore[typeddict-unknown-key] + computed_table["x-normalizer"] = {"evolve-columns-once": True} existing_table = self.schema._schema_tables.get(table_name, None) if existing_table: # TODO: revise this. computed table should overwrite certain hints (ie. primary and merge keys) completely - diff_table = utils.diff_table(existing_table, computed_table) + diff_table = utils.diff_table(self.schema.name, existing_table, computed_table) else: diff_table = computed_table @@ -335,7 +404,7 @@ def _compute_table( computed_table = super()._compute_table(resource, item, Any) # Merge the columns to include primary_key and other hints that may be set on the resource if arrow_table: - utils.merge_table(computed_table, arrow_table) + utils.merge_table(self.schema.name, computed_table, arrow_table) else: arrow_table = copy(computed_table) arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(item.schema) @@ -353,8 +422,7 @@ def _compute_table( } # normalize arrow table before merging - arrow_table = self.schema.normalize_table_identifiers(arrow_table) - + arrow_table = utils.normalize_table_identifiers(arrow_table, self.schema.naming) # issue warnings when overriding computed with arrow override_warn: bool = False for col_name, column in arrow_table["columns"].items(): diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index 6fd1928970..bc10177223 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -5,6 +5,7 @@ from dlt.common.schema.typing import ( TColumnNames, TColumnProp, + TFileFormat, TPartialTableSchema, TTableSchema, TTableSchemaColumns, @@ -48,6 +49,7 @@ class 
TResourceHints(TypedDict, total=False): incremental: Incremental[Any] schema_contract: TTableHintTemplate[TSchemaContract] table_format: TTableHintTemplate[TTableFormat] + file_format: TTableHintTemplate[TFileFormat] validator: ValidateItem original_columns: TTableHintTemplate[TAnySchemaColumns] @@ -72,6 +74,7 @@ def make_hints( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, ) -> TResourceHints: """A convenience function to create resource hints. Accepts both static and dynamic hints based on data. @@ -91,6 +94,7 @@ def make_hints( columns=clean_columns, # type: ignore schema_contract=schema_contract, # type: ignore table_format=table_format, # type: ignore + file_format=file_format, # type: ignore ) if not table_name: new_template.pop("name") @@ -209,6 +213,7 @@ def apply_hints( schema_contract: TTableHintTemplate[TSchemaContract] = None, additional_table_hints: Optional[Dict[str, TTableHintTemplate[Any]]] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, create_table_variant: bool = False, ) -> None: """Creates or modifies existing table schema by setting provided hints. Accepts both static and dynamic hints based on data. 
@@ -256,6 +261,7 @@ def apply_hints( merge_key, schema_contract, table_format, + file_format, ) else: t = self._clone_hints(t) @@ -320,6 +326,11 @@ def apply_hints( t["table_format"] = table_format else: t.pop("table_format", None) + if file_format is not None: + if file_format: + t["file_format"] = file_format + else: + t.pop("file_format", None) # set properties that can't be passed to make_hints if incremental is not None: @@ -375,6 +386,7 @@ def merge_hints( incremental=hints_template.get("incremental"), schema_contract=hints_template.get("schema_contract"), table_format=hints_template.get("table_format"), + file_format=hints_template.get("file_format"), create_table_variant=create_table_variant, ) diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index 8b4cae4090..947e21f7b8 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -213,7 +213,7 @@ def compute_unique_values(self, item: "TAnyArrowItem", unique_columns: List[str] def compute_unique_values_with_index( self, item: "TAnyArrowItem", unique_columns: List[str] - ) -> List[Tuple[int, str]]: + ) -> List[Tuple[Any, str]]: if not unique_columns: return [] indices = item[self._dlt_index].to_pylist() @@ -318,12 +318,12 @@ def __call__( for i, uq_val in unique_values_index if uq_val in self.start_unique_hashes ] - # find rows with unique ids that were stored from previous run - remove_idx = pa.array(i for i, _ in unique_values_index) - # Filter the table - tbl = tbl.filter( - pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)) - ) + if len(unique_values_index) > 0: + # find rows with unique ids that were stored from previous run + remove_idx = pa.array(i for i, _ in unique_values_index) + tbl = tbl.filter( + pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)) + ) if ( self.last_value is None diff --git a/dlt/extract/items.py b/dlt/extract/items.py index fec31e2846..4cf8d2191f 100644 --- 
a/dlt/extract/items.py +++ b/dlt/extract/items.py @@ -160,6 +160,10 @@ class FilterItem(ItemTransform[bool]): def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: if isinstance(item, list): + # preserve empty lists + if len(item) == 0: + return item + if self._f_meta: item = [i for i in item if self._f_meta(i, meta)] else: diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index eecb570375..93eb9d1189 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -1,4 +1,3 @@ -from copy import deepcopy import inspect from functools import partial from typing import ( @@ -14,6 +13,7 @@ ) from typing_extensions import TypeVar, Self +from dlt.common import logger from dlt.common.configuration.inject import get_fun_spec, with_config from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs import known_sections @@ -394,6 +394,11 @@ def _gen_wrap(gen: TPipeStep) -> TPipeStep: else: # keep function as function to not evaluate generators before pipe starts self._pipe.replace_gen(partial(_gen_wrap, gen)) + else: + logger.warning( + f"Setting add_limit to a transformer {self.name} has no effect. Set the limit on" + " the top level resource." 
+ ) return self def parallelize(self: TDltResourceImpl) -> TDltResourceImpl: diff --git a/dlt/extract/source.py b/dlt/extract/source.py index 658f884c40..9953b56117 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -11,6 +11,7 @@ from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnName, TSchemaContract +from dlt.common.schema.utils import normalize_table_identifiers from dlt.common.typing import StrAny, TDataItem from dlt.common.configuration.container import Container from dlt.common.pipeline import ( @@ -245,26 +246,39 @@ def exhausted(self) -> bool: @property def root_key(self) -> bool: """Enables merging on all resources by propagating root foreign key to child tables. This option is most useful if you plan to change write disposition of a resource to disable/enable merge""" + # this also check the normalizer type config = RelationalNormalizer.get_normalizer_config(self._schema).get("propagation") + data_normalizer = self._schema.data_item_normalizer + assert isinstance(data_normalizer, RelationalNormalizer) return ( config is not None and "root" in config - and "_dlt_id" in config["root"] - and config["root"]["_dlt_id"] == "_dlt_root_id" + and data_normalizer.c_dlt_id in config["root"] + and config["root"][data_normalizer.c_dlt_id] == data_normalizer.c_dlt_root_id ) @root_key.setter def root_key(self, value: bool) -> None: + # this also check the normalizer type + config = RelationalNormalizer.get_normalizer_config(self._schema) + data_normalizer = self._schema.data_item_normalizer + assert isinstance(data_normalizer, RelationalNormalizer) + if value is True: RelationalNormalizer.update_normalizer_config( - self._schema, {"propagation": {"root": {"_dlt_id": TColumnName("_dlt_root_id")}}} + self._schema, + { + "propagation": { + "root": { + data_normalizer.c_dlt_id: TColumnName(data_normalizer.c_dlt_root_id) + } + } + }, ) 
else: if self.root_key: - propagation_config = RelationalNormalizer.get_normalizer_config(self._schema)[ - "propagation" - ] - propagation_config["root"].pop("_dlt_id") # type: ignore + propagation_config = config["propagation"] + propagation_config["root"].pop(data_normalizer.c_dlt_id) @property def resources(self) -> DltResourceDict: @@ -291,8 +305,8 @@ def discover_schema(self, item: TDataItem = None) -> Schema: for r in self.selected_resources.values(): # names must be normalized here with contextlib.suppress(DataItemRequiredForDynamicTableHints): - partial_table = self._schema.normalize_table_identifiers( - r.compute_table_schema(item) + partial_table = normalize_table_identifiers( + r.compute_table_schema(item), self._schema.naming ) schema.update_table(partial_table) return schema diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index 8abc679ea2..836da516e9 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -1,11 +1,10 @@ -from typing import TYPE_CHECKING, Literal, Optional +from typing import Optional from dlt.common.configuration import configspec +from dlt.common.destination.capabilities import TLoaderParallelismStrategy from dlt.common.storages import LoadStorageConfiguration from dlt.common.runners.configuration import PoolRunnerConfiguration, TPoolType -TLoaderParallelismStrategy = Literal["parallel", "table-sequential", "sequential"] - @configspec class LoaderConfiguration(PoolRunnerConfiguration): diff --git a/dlt/load/load.py b/dlt/load/load.py index abbeee5ddf..9d1d953f7f 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -80,7 +80,6 @@ def __init__( self.initial_client_config = initial_client_config self.initial_staging_client_config = initial_staging_client_config self.destination = destination - self.capabilities = destination.capabilities() self.staging_destination = staging_destination self.pool = NullExecutor() self.load_storage: LoadStorage = self.create_storage(is_storage_owner) @@ -88,7 +87,7 
@@ def __init__( super().__init__() def create_storage(self, is_storage_owner: bool) -> LoadStorage: - supported_file_formats = self.capabilities.supported_loader_file_formats + supported_file_formats = self.destination.capabilities().supported_loader_file_formats if self.staging_destination: supported_file_formats = ( self.staging_destination.capabilities().supported_loader_file_formats @@ -150,7 +149,7 @@ def w_spool_job( if job_info.file_format not in self.load_storage.supported_job_file_formats: raise LoadClientUnsupportedFileFormats( job_info.file_format, - self.capabilities.supported_loader_file_formats, + self.destination.capabilities().supported_loader_file_formats, file_path, ) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") @@ -197,7 +196,7 @@ def w_spool_job( def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJob]]: # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs load_files = filter_new_jobs( - self.load_storage.list_new_jobs(load_id), self.capabilities, self.config + self.load_storage.list_new_jobs(load_id), self.destination.capabilities(), self.config ) file_count = len(load_files) if file_count == 0: @@ -259,13 +258,20 @@ def create_followup_jobs( schema.tables, starting_job.job_file_info().table_name ) # if all tables of chain completed, create follow up jobs - all_jobs = self.load_storage.normalized_packages.list_all_jobs(load_id) + all_jobs_states = self.load_storage.normalized_packages.list_all_jobs_with_states( + load_id + ) if table_chain := get_completed_table_chain( - schema, all_jobs, top_job_table, starting_job.job_file_info().job_id() + schema, all_jobs_states, top_job_table, starting_job.job_file_info().job_id() ): table_chain_names = [table["name"] for table in table_chain] + # create job infos that contain full path to job table_chain_jobs = [ - job for job in all_jobs if job.job_file_info.table_name in table_chain_names + 
self.load_storage.normalized_packages.job_to_job_info(load_id, *job_state) + for job_state in all_jobs_states + if job_state[1].table_name in table_chain_names + # job being completed is still in started_jobs + and job_state[0] in ("completed_jobs", "started_jobs") ] if follow_up_jobs := client.create_table_chain_completed_followup_jobs( table_chain, table_chain_jobs @@ -359,7 +365,7 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) ) ): job_client.complete_load(load_id) - self._maybe_trancate_staging_dataset(schema, job_client) + self._maybe_truncate_staging_dataset(schema, job_client) self.load_storage.complete_load_package(load_id, aborted) # collect package info @@ -432,10 +438,10 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: self.complete_package(load_id, schema, False) return # update counter we only care about the jobs that are scheduled to be loaded - package_info = self.load_storage.normalized_packages.get_load_package_info(load_id) - total_jobs = reduce(lambda p, c: p + len(c), package_info.jobs.values(), 0) - no_failed_jobs = len(package_info.jobs["failed_jobs"]) - no_completed_jobs = len(package_info.jobs["completed_jobs"]) + no_failed_jobs + package_jobs = self.load_storage.normalized_packages.get_load_package_jobs(load_id) + total_jobs = reduce(lambda p, c: p + len(c), package_jobs.values(), 0) + no_failed_jobs = len(package_jobs["failed_jobs"]) + no_completed_jobs = len(package_jobs["completed_jobs"]) + no_failed_jobs self.collector.update("Jobs", no_completed_jobs, total_jobs) if no_failed_jobs > 0: self.collector.update( @@ -447,26 +453,28 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: remaining_jobs = self.complete_jobs(load_id, jobs, schema) if len(remaining_jobs) == 0: # get package status - package_info = self.load_storage.normalized_packages.get_load_package_info( + package_jobs = self.load_storage.normalized_packages.get_load_package_jobs( load_id ) # possibly 
raise on failed jobs if self.config.raise_on_failed_jobs: - if package_info.jobs["failed_jobs"]: - failed_job = package_info.jobs["failed_jobs"][0] + if package_jobs["failed_jobs"]: + failed_job = package_jobs["failed_jobs"][0] raise LoadClientJobFailed( load_id, - failed_job.job_file_info.job_id(), - failed_job.failed_message, + failed_job.job_id(), + self.load_storage.normalized_packages.get_job_failed_message( + load_id, failed_job + ), ) # possibly raise on too many retries if self.config.raise_on_max_retries: - for new_job in package_info.jobs["new_jobs"]: - r_c = new_job.job_file_info.retry_count + for new_job in package_jobs["new_jobs"]: + r_c = new_job.retry_count if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: raise LoadClientJobRetry( load_id, - new_job.job_file_info.job_id(), + new_job.job_id(), r_c, self.config.raise_on_max_retries, ) @@ -512,7 +520,7 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: return TRunMetrics(False, len(self.load_storage.list_normalized_packages())) - def _maybe_trancate_staging_dataset(self, schema: Schema, job_client: JobClientBase) -> None: + def _maybe_truncate_staging_dataset(self, schema: Schema, job_client: JobClientBase) -> None: """ Truncate the staging dataset if one used, and configuration requests truncation. 
diff --git a/dlt/load/utils.py b/dlt/load/utils.py index 4e5099855b..7db05674fa 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -1,8 +1,8 @@ -from typing import List, Set, Iterable, Callable, Optional, Sequence +from typing import List, Set, Iterable, Callable, Optional, Tuple, Sequence from itertools import groupby from dlt.common import logger -from dlt.common.storages.load_package import LoadJobInfo, PackageStorage +from dlt.common.storages.load_package import LoadJobInfo, PackageStorage, TJobState from dlt.common.schema.utils import ( fill_hints_from_parent_and_clone_table, get_child_tables, @@ -22,7 +22,7 @@ def get_completed_table_chain( schema: Schema, - all_jobs: Iterable[LoadJobInfo], + all_jobs: Iterable[Tuple[TJobState, ParsedLoadJobFileName]], top_merged_table: TTableSchema, being_completed_job_id: str = None, ) -> List[TTableSchema]: @@ -54,8 +54,8 @@ def get_completed_table_chain( else: # all jobs must be completed in order for merge to be created if any( - job.state not in ("failed_jobs", "completed_jobs") - and job.job_file_info.job_id() != being_completed_job_id + job[0] not in ("failed_jobs", "completed_jobs") + and job[1].job_id() != being_completed_job_id for job in table_jobs ): return None diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index 6678f6edee..5f84d57d7a 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -6,6 +6,7 @@ from dlt.common.data_writers import DataWriterMetrics from dlt.common.data_writers.writers import ArrowToObjectAdapter from dlt.common.json import custom_pua_decode, may_have_pua +from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer from dlt.common.runtime import signals from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns, TSchemaContractDict from dlt.common.schema.utils import has_table_seen_data @@ -149,7 +150,7 @@ def _normalize_chunk( continue # theres a new 
table or new columns in existing table # update schema and save the change - schema.update_table(partial_table) + schema.update_table(partial_table, normalize_identifiers=False) table_updates = schema_update.setdefault(table_name, []) table_updates.append(partial_table) @@ -200,6 +201,7 @@ def __call__( ) schema_updates.append(partial_update) logger.debug(f"Processed {line_no+1} lines from file {extracted_items_file}") + # empty json files are when replace write disposition is used in order to truncate table(s) if line is None and root_table_name in self.schema.tables: # TODO: we should push the truncate jobs via package state # not as empty jobs. empty jobs should be reserved for @@ -234,8 +236,9 @@ def _write_with_dlt_columns( schema = self.schema load_id = self.load_id schema_update: TSchemaUpdate = {} + data_normalizer = schema.data_item_normalizer - if add_dlt_id: + if add_dlt_id and isinstance(data_normalizer, RelationalNormalizer): table_update = schema.update_table( { "name": root_table_name, @@ -249,7 +252,7 @@ def _write_with_dlt_columns( new_columns.append( ( -1, - pa.field("_dlt_id", pyarrow.pyarrow.string(), nullable=False), + pa.field(data_normalizer.c_dlt_id, pyarrow.pyarrow.string(), nullable=False), lambda batch: pa.array(generate_dlt_ids(batch.num_rows)), ) ) @@ -375,3 +378,32 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSch ) return base_schema_update + + +class FileImportNormalizer(ItemsNormalizer): + def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSchemaUpdate]: + logger.info( + f"Table {root_table_name} {self.item_storage.writer_spec.file_format} file" + f" {extracted_items_file} will be directly imported without normalization" + ) + completed_columns = self.schema.get_table_columns(root_table_name) + if not completed_columns: + logger.warning( + f"Table {root_table_name} has no completed columns for imported file" + f" {extracted_items_file} and will not be created! 
Pass column hints to the" + " resource or with dlt.mark.with_hints or create the destination table yourself." + ) + with self.normalize_storage.extracted_packages.storage.open_file( + extracted_items_file, "rb" + ) as f: + # TODO: sniff the schema depending on a file type + file_metrics = DataWriterMetrics(extracted_items_file, 0, f.tell(), 0, 0) + parts = ParsedLoadJobFileName.parse(extracted_items_file) + self.item_storage.import_items_file( + self.load_id, + self.schema.name, + parts.table_name, + self.normalize_storage.extracted_packages.storage.make_full_path(extracted_items_file), + file_metrics, + ) + return [] diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 75cb9be707..98154cd5cf 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -1,33 +1,23 @@ import os import itertools -from typing import Callable, List, Dict, NamedTuple, Sequence, Tuple, Set, Optional +from typing import List, Dict, Sequence, Optional, Callable from concurrent.futures import Future, Executor from dlt.common import logger from dlt.common.runtime.signals import sleep from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config -from dlt.common.configuration.container import Container -from dlt.common.data_writers import ( - DataWriter, - DataWriterMetrics, - TDataItemFormat, - resolve_best_writer_spec, - get_best_writer_spec, - is_native_writer, -) +from dlt.common.data_writers import DataWriterMetrics from dlt.common.data_writers.writers import EMPTY_DATA_WRITER_METRICS from dlt.common.runners import TRunMetrics, Runnable, NullExecutor from dlt.common.runtime import signals from dlt.common.runtime.collector import Collector, NULL_COLLECTOR -from dlt.common.schema.typing import TStoredSchema, TTableSchema +from dlt.common.schema.typing import TStoredSchema from dlt.common.schema.utils import merge_schema_updates from dlt.common.storages import ( NormalizeStorage, SchemaStorage, 
LoadStorage, - LoadStorageConfiguration, - NormalizeStorageConfiguration, ParsedLoadJobFileName, ) from dlt.common.schema import TSchemaUpdate, Schema @@ -40,20 +30,10 @@ ) from dlt.common.storages.exceptions import LoadPackageNotFound from dlt.common.storages.load_package import LoadPackageInfo -from dlt.common.utils import chunks from dlt.normalize.configuration import NormalizeConfiguration from dlt.normalize.exceptions import NormalizeJobFailed -from dlt.normalize.items_normalizers import ( - ArrowItemsNormalizer, - JsonLItemsNormalizer, - ItemsNormalizer, -) - - -class TWorkerRV(NamedTuple): - schema_updates: List[TSchemaUpdate] - file_metrics: List[DataWriterMetrics] +from dlt.normalize.worker import w_normalize_files, group_worker_files, TWorkerRV # normalize worker wrapping function signature @@ -99,211 +79,19 @@ def create_storages(self) -> None: config=self.config._load_storage_config, ) - @staticmethod - def w_normalize_files( - config: NormalizeConfiguration, - normalize_storage_config: NormalizeStorageConfiguration, - loader_storage_config: LoadStorageConfiguration, - stored_schema: TStoredSchema, - load_id: str, - extracted_items_files: Sequence[str], - ) -> TWorkerRV: - destination_caps = config.destination_capabilities - schema_updates: List[TSchemaUpdate] = [] - # normalizers are cached per table name - item_normalizers: Dict[str, ItemsNormalizer] = {} - - preferred_file_format = ( - destination_caps.preferred_loader_file_format - or destination_caps.preferred_staging_file_format - ) - # TODO: capabilities.supported_*_formats can be None, it should have defaults - supported_file_formats = destination_caps.supported_loader_file_formats or [] - supported_table_formats = destination_caps.supported_table_formats or [] - - # process all files with data items and write to buffered item storage - with Container().injectable_context(destination_caps): - schema = Schema.from_stored_schema(stored_schema) - normalize_storage = NormalizeStorage(False, 
normalize_storage_config) - load_storage = LoadStorage(False, supported_file_formats, loader_storage_config) - - def _get_items_normalizer( - item_format: TDataItemFormat, table_schema: Optional[TTableSchema] - ) -> ItemsNormalizer: - table_name = table_schema["name"] - if table_name in item_normalizers: - return item_normalizers[table_name] - - if ( - "table_format" in table_schema - and table_schema["table_format"] not in supported_table_formats - ): - logger.warning( - "Destination does not support the configured `table_format` value " - f"`{table_schema['table_format']}` for table `{table_schema['name']}`. " - "The setting will probably be ignored." - ) - - items_preferred_file_format = preferred_file_format - items_supported_file_formats = supported_file_formats - if destination_caps.loader_file_format_adapter is not None: - items_preferred_file_format, items_supported_file_formats = ( - destination_caps.loader_file_format_adapter( - preferred_file_format, - ( - supported_file_formats.copy() - if isinstance(supported_file_formats, list) - else supported_file_formats - ), - table_schema=table_schema, - ) - ) - - # force file format - best_writer_spec = None - if config.loader_file_format: - if config.loader_file_format in items_supported_file_formats: - # TODO: pass supported_file_formats, when used in pipeline we already checked that - # but if normalize is used standalone `supported_loader_file_formats` may be unresolved - best_writer_spec = get_best_writer_spec( - item_format, config.loader_file_format - ) - else: - logger.warning( - f"The configured value `{config.loader_file_format}` " - "for `loader_file_format` is not supported for table " - f"`{table_schema['name']}` and will be ignored. Dlt " - "will use a supported format instead." 
- ) - - if best_writer_spec is None: - # find best spec among possible formats taking into account destination preference - best_writer_spec = resolve_best_writer_spec( - item_format, items_supported_file_formats, items_preferred_file_format - ) - # if best_writer_spec.file_format != preferred_file_format: - # logger.warning( - # f"For data items yielded as {item_format} jobs in file format" - # f" {preferred_file_format} cannot be created." - # f" {best_writer_spec.file_format} jobs will be used instead." - # " This may decrease the performance." - # ) - item_storage = load_storage.create_item_storage(best_writer_spec) - if not is_native_writer(item_storage.writer_cls): - logger.warning( - f"For data items yielded as {item_format} and job file format" - f" {best_writer_spec.file_format} native writer could not be found. A" - f" {item_storage.writer_cls.__name__} writer is used that internally" - f" converts {item_format}. This will degrade performance." - ) - cls = ArrowItemsNormalizer if item_format == "arrow" else JsonLItemsNormalizer - logger.info( - f"Created items normalizer {cls.__name__} with writer" - f" {item_storage.writer_cls.__name__} for item format {item_format} and file" - f" format {item_storage.writer_spec.file_format}" - ) - norm = item_normalizers[table_name] = cls( - item_storage, - normalize_storage, - schema, - load_id, - config, - ) - return norm - - def _gather_metrics_and_close( - parsed_fn: ParsedLoadJobFileName, in_exception: bool - ) -> List[DataWriterMetrics]: - writer_metrics: List[DataWriterMetrics] = [] - try: - try: - for normalizer in item_normalizers.values(): - normalizer.item_storage.close_writers(load_id, skip_flush=in_exception) - except Exception: - # if we had exception during flushing the writers, close them without flushing - if not in_exception: - for normalizer in item_normalizers.values(): - normalizer.item_storage.close_writers(load_id, skip_flush=True) - raise - finally: - # always gather metrics - for normalizer in 
item_normalizers.values(): - norm_metrics = normalizer.item_storage.closed_files(load_id) - writer_metrics.extend(norm_metrics) - for normalizer in item_normalizers.values(): - normalizer.item_storage.remove_closed_files(load_id) - except Exception as exc: - if in_exception: - # swallow exception if we already handle exceptions - return writer_metrics - else: - # enclose the exception during the closing in job failed exception - job_id = parsed_fn.job_id() if parsed_fn else "" - raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) - return writer_metrics - - parsed_file_name: ParsedLoadJobFileName = None - try: - root_tables: Set[str] = set() - for extracted_items_file in extracted_items_files: - parsed_file_name = ParsedLoadJobFileName.parse(extracted_items_file) - # normalize table name in case the normalization changed - # NOTE: this is the best we can do, until a full lineage information is in the schema - root_table_name = schema.naming.normalize_table_identifier( - parsed_file_name.table_name - ) - root_tables.add(root_table_name) - normalizer = _get_items_normalizer( - DataWriter.item_format_from_file_extension(parsed_file_name.file_format), - stored_schema["tables"].get(root_table_name, {"name": root_table_name}), - ) - logger.debug( - f"Processing extracted items in {extracted_items_file} in load_id" - f" {load_id} with table name {root_table_name} and schema {schema.name}" - ) - partial_updates = normalizer(extracted_items_file, root_table_name) - schema_updates.extend(partial_updates) - logger.debug(f"Processed file {extracted_items_file}") - except Exception as exc: - job_id = parsed_file_name.job_id() if parsed_file_name else "" - writer_metrics = _gather_metrics_and_close(parsed_file_name, in_exception=True) - raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) from exc - else: - writer_metrics = _gather_metrics_and_close(parsed_file_name, in_exception=False) - - logger.info(f"Processed all items in 
{len(extracted_items_files)} files") - return TWorkerRV(schema_updates, writer_metrics) - - def update_table(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: + def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: for schema_update in schema_updates: for table_name, table_updates in schema_update.items(): logger.info( f"Updating schema for table {table_name} with {len(table_updates)} deltas" ) for partial_table in table_updates: - # merge columns - schema.update_table(partial_table) - - @staticmethod - def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[str]]: - # sort files so the same tables are in the same worker - files = list(sorted(files)) - - chunk_size = max(len(files) // no_groups, 1) - chunk_files = list(chunks(files, chunk_size)) - # distribute the remainder files to existing groups starting from the end - remainder_l = len(chunk_files) - no_groups - l_idx = 0 - while remainder_l > 0: - for idx, file in enumerate(reversed(chunk_files.pop())): - chunk_files[-l_idx - idx - remainder_l].append(file) # type: ignore - remainder_l -= 1 - l_idx = idx + 1 - return chunk_files + # merge columns where we expect identifiers to be normalized + schema.update_table(partial_table, normalize_identifiers=False) def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: workers: int = getattr(self.pool, "_max_workers", 1) - chunk_files = self.group_worker_files(files, workers) + chunk_files = group_worker_files(files, workers) schema_dict: TStoredSchema = schema.to_dict() param_chunk = [ ( @@ -319,10 +107,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW # return stats summary = TWorkerRV([], []) # push all tasks to queue - tasks = [ - (self.pool.submit(Normalize.w_normalize_files, *params), params) - for params in param_chunk - ] + tasks = [(self.pool.submit(w_normalize_files, *params), params) for params in param_chunk] while 
len(tasks) > 0: sleep(0.3) @@ -337,7 +122,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW result: TWorkerRV = pending.result() try: # gather schema from all manifests, validate consistency and combine - self.update_table(schema, result[0]) + self.update_schema(schema, result[0]) summary.schema_updates.extend(result.schema_updates) summary.file_metrics.extend(result.file_metrics) # update metrics @@ -358,7 +143,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW # TODO: it's time for a named tuple params = params[:3] + (schema_dict,) + params[4:] retry_pending: Future[TWorkerRV] = self.pool.submit( - Normalize.w_normalize_files, *params + w_normalize_files, *params ) tasks.append((retry_pending, params)) # remove finished tasks @@ -368,7 +153,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW return summary def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: - result = Normalize.w_normalize_files( + result = w_normalize_files( self.config, self.normalize_storage.config, self.load_storage.config, @@ -376,7 +161,7 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWor load_id, files, ) - self.update_table(schema, result.schema_updates) + self.update_schema(schema, result.schema_updates) self.collector.update("Files", len(result.file_metrics)) self.collector.update( "Items", sum(result.file_metrics, EMPTY_DATA_WRITER_METRICS).items_count @@ -399,7 +184,7 @@ def spool_files( # update normalizer specific info for table_name in table_metrics: table = schema.tables[table_name] - x_normalizer = table.setdefault("x-normalizer", {}) # type: ignore[typeddict-item] + x_normalizer = table.setdefault("x-normalizer", {}) # drop evolve once for all tables that seen data x_normalizer.pop("evolve-columns-once", None) # mark that table have seen data only if there was data diff --git a/dlt/normalize/worker.py 
b/dlt/normalize/worker.py new file mode 100644 index 0000000000..d5d4a028d9 --- /dev/null +++ b/dlt/normalize/worker.py @@ -0,0 +1,254 @@ +from typing import Callable, List, Dict, NamedTuple, Sequence, Set, Optional, Type + +from dlt.common import logger +from dlt.common.configuration.container import Container +from dlt.common.data_writers import ( + DataWriter, + DataWriterMetrics, + create_import_spec, + resolve_best_writer_spec, + get_best_writer_spec, + is_native_writer, +) +from dlt.common.utils import chunks +from dlt.common.schema.typing import TStoredSchema, TTableSchema +from dlt.common.storages import ( + NormalizeStorage, + LoadStorage, + LoadStorageConfiguration, + NormalizeStorageConfiguration, + ParsedLoadJobFileName, +) +from dlt.common.schema import TSchemaUpdate, Schema + +from dlt.normalize.configuration import NormalizeConfiguration +from dlt.normalize.exceptions import NormalizeJobFailed +from dlt.normalize.items_normalizers import ( + ArrowItemsNormalizer, + FileImportNormalizer, + JsonLItemsNormalizer, + ItemsNormalizer, +) + + +class TWorkerRV(NamedTuple): + schema_updates: List[TSchemaUpdate] + file_metrics: List[DataWriterMetrics] + + +def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[str]]: + # sort files so the same tables are in the same worker + files = list(sorted(files)) + + chunk_size = max(len(files) // no_groups, 1) + chunk_files = list(chunks(files, chunk_size)) + # distribute the remainder files to existing groups starting from the end + remainder_l = len(chunk_files) - no_groups + l_idx = 0 + while remainder_l > 0: + for idx, file in enumerate(reversed(chunk_files.pop())): + chunk_files[-l_idx - idx - remainder_l].append(file) # type: ignore + remainder_l -= 1 + l_idx = idx + 1 + return chunk_files + + +def w_normalize_files( + config: NormalizeConfiguration, + normalize_storage_config: NormalizeStorageConfiguration, + loader_storage_config: LoadStorageConfiguration, + stored_schema: TStoredSchema, + 
load_id: str, + extracted_items_files: Sequence[str], +) -> TWorkerRV: + destination_caps = config.destination_capabilities + schema_updates: List[TSchemaUpdate] = [] + # normalizers are cached per table name + item_normalizers: Dict[str, ItemsNormalizer] = {} + + preferred_file_format = ( + destination_caps.preferred_loader_file_format + or destination_caps.preferred_staging_file_format + ) + # TODO: capabilities.supported_*_formats can be None, it should have defaults + supported_file_formats = destination_caps.supported_loader_file_formats or [] + supported_table_formats = destination_caps.supported_table_formats or [] + + # process all files with data items and write to buffered item storage + with Container().injectable_context(destination_caps): + schema = Schema.from_stored_schema(stored_schema) + normalize_storage = NormalizeStorage(False, normalize_storage_config) + load_storage = LoadStorage(False, supported_file_formats, loader_storage_config) + + def _get_items_normalizer( + parsed_file_name: ParsedLoadJobFileName, table_schema: TTableSchema + ) -> ItemsNormalizer: + item_format = DataWriter.item_format_from_file_extension(parsed_file_name.file_format) + + table_name = table_schema["name"] + if table_name in item_normalizers: + return item_normalizers[table_name] + + if ( + "table_format" in table_schema + and table_schema["table_format"] not in supported_table_formats + ): + logger.warning( + "Destination does not support the configured `table_format` value " + f"`{table_schema['table_format']}` for table `{table_schema['name']}`. " + "The setting will probably be ignored." 
+ ) + + items_preferred_file_format = preferred_file_format + items_supported_file_formats = supported_file_formats + if destination_caps.loader_file_format_adapter is not None: + items_preferred_file_format, items_supported_file_formats = ( + destination_caps.loader_file_format_adapter( + preferred_file_format, + ( + supported_file_formats.copy() + if isinstance(supported_file_formats, list) + else supported_file_formats + ), + table_schema=table_schema, + ) + ) + + best_writer_spec = None + if item_format == "file": + # if we want to import file, create a spec that may be used only for importing + best_writer_spec = create_import_spec( + parsed_file_name.file_format, items_supported_file_formats # type: ignore[arg-type] + ) + + config_loader_file_format = config.loader_file_format + if file_format := table_schema.get("file_format"): + # resource has a file format defined so use it + if file_format == "preferred": + # use destination preferred + config_loader_file_format = items_preferred_file_format + else: + # use resource format + config_loader_file_format = file_format + logger.info( + f"A file format for table {table_name} was specified to {file_format} in the" + f" resource so {config_loader_file_format} format being used." + ) + + if config_loader_file_format and best_writer_spec is None: + # force file format + if config_loader_file_format in items_supported_file_formats: + # TODO: pass supported_file_formats, when used in pipeline we already checked that + # but if normalize is used standalone `supported_loader_file_formats` may be unresolved + best_writer_spec = get_best_writer_spec(item_format, config_loader_file_format) + else: + logger.warning( + f"The configured value `{config_loader_file_format}` " + "for `loader_file_format` is not supported for table " + f"`{table_name}` and will be ignored. Dlt " + "will use a supported format instead." 
+ ) + + if best_writer_spec is None: + # find best spec among possible formats taking into account destination preference + best_writer_spec = resolve_best_writer_spec( + item_format, items_supported_file_formats, items_preferred_file_format + ) + # if best_writer_spec.file_format != preferred_file_format: + # logger.warning( + # f"For data items yielded as {item_format} jobs in file format" + # f" {preferred_file_format} cannot be created." + # f" {best_writer_spec.file_format} jobs will be used instead." + # " This may decrease the performance." + # ) + item_storage = load_storage.create_item_storage(best_writer_spec) + if not is_native_writer(item_storage.writer_cls): + logger.warning( + f"For data items yielded as {item_format} and job file format" + f" {best_writer_spec.file_format} native writer could not be found. A" + f" {item_storage.writer_cls.__name__} writer is used that internally" + f" converts {item_format}. This will degrade performance." + ) + cls: Type[ItemsNormalizer] + if item_format == "arrow": + cls = ArrowItemsNormalizer + elif item_format == "object": + cls = JsonLItemsNormalizer + else: + cls = FileImportNormalizer + logger.info( + f"Created items normalizer {cls.__name__} with writer" + f" {item_storage.writer_cls.__name__} for item format {item_format} and file" + f" format {item_storage.writer_spec.file_format}" + ) + norm = item_normalizers[table_name] = cls( + item_storage, + normalize_storage, + schema, + load_id, + config, + ) + return norm + + def _gather_metrics_and_close( + parsed_fn: ParsedLoadJobFileName, in_exception: bool + ) -> List[DataWriterMetrics]: + writer_metrics: List[DataWriterMetrics] = [] + try: + try: + for normalizer in item_normalizers.values(): + normalizer.item_storage.close_writers(load_id, skip_flush=in_exception) + except Exception: + # if we had exception during flushing the writers, close them without flushing + if not in_exception: + for normalizer in item_normalizers.values(): + 
normalizer.item_storage.close_writers(load_id, skip_flush=True) + raise + finally: + # always gather metrics + for normalizer in item_normalizers.values(): + norm_metrics = normalizer.item_storage.closed_files(load_id) + writer_metrics.extend(norm_metrics) + for normalizer in item_normalizers.values(): + normalizer.item_storage.remove_closed_files(load_id) + except Exception as exc: + if in_exception: + # swallow exception if we already handle exceptions + return writer_metrics + else: + # enclose the exception during the closing in job failed exception + job_id = parsed_fn.job_id() if parsed_fn else "" + raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) + return writer_metrics + + parsed_file_name: ParsedLoadJobFileName = None + try: + root_tables: Set[str] = set() + for extracted_items_file in extracted_items_files: + parsed_file_name = ParsedLoadJobFileName.parse(extracted_items_file) + # normalize table name in case the normalization changed + # NOTE: this is the best we can do, until a full lineage information is in the schema + root_table_name = schema.naming.normalize_table_identifier( + parsed_file_name.table_name + ) + root_tables.add(root_table_name) + normalizer = _get_items_normalizer( + parsed_file_name, + stored_schema["tables"].get(root_table_name, {"name": root_table_name}), + ) + logger.debug( + f"Processing extracted items in {extracted_items_file} in load_id" + f" {load_id} with table name {root_table_name} and schema {schema.name}" + ) + partial_updates = normalizer(extracted_items_file, root_table_name) + schema_updates.extend(partial_updates) + logger.debug(f"Processed file {extracted_items_file}") + except Exception as exc: + job_id = parsed_file_name.job_id() if parsed_file_name else "" + writer_metrics = _gather_metrics_and_close(parsed_file_name, in_exception=True) + raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) from exc + else: + writer_metrics = _gather_metrics_and_close(parsed_file_name, 
in_exception=False) + + logger.info(f"Processed all items in {len(extracted_items_files)} files") + return TWorkerRV(schema_updates, writer_metrics) diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 20ba0b07d0..4efc7716e6 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -173,31 +173,39 @@ def attach( pipeline_name: str = None, pipelines_dir: str = None, pipeline_salt: TSecretValue = None, - full_refresh: Optional[bool] = None, - dev_mode: bool = False, - credentials: Any = None, + destination: TDestinationReferenceArg = None, + staging: TDestinationReferenceArg = None, progress: TCollectorArg = _NULL_COLLECTOR, **injection_kwargs: Any, ) -> Pipeline: - """Attaches to the working folder of `pipeline_name` in `pipelines_dir` or in default directory. Requires that valid pipeline state exists in working folder.""" + """Attaches to the working folder of `pipeline_name` in `pipelines_dir` or in default directory. Requires that valid pipeline state exists in working folder. + Pre-configured `destination` and `staging` factories may be provided. If not present, default factories are created from pipeline state. 
+ """ ensure_correct_pipeline_kwargs(attach, **injection_kwargs) - full_refresh_argument_deprecated("attach", full_refresh) # if working_dir not provided use temp folder if not pipelines_dir: pipelines_dir = get_dlt_pipelines_dir() progress = collector_from_name(progress) + destination = Destination.from_reference( + destination or injection_kwargs["destination_type"], + destination_name=injection_kwargs["destination_name"], + ) + staging = Destination.from_reference( + staging or injection_kwargs.get("staging_type", None), + destination_name=injection_kwargs.get("staging_name", None), + ) # create new pipeline instance p = Pipeline( pipeline_name, pipelines_dir, pipeline_salt, + destination, + staging, None, None, None, - credentials, - None, None, - full_refresh if full_refresh is not None else dev_mode, + False, # always False as dev_mode so we do not wipe the working folder progress, True, last_config(**injection_kwargs), diff --git a/dlt/pipeline/dbt.py b/dlt/pipeline/dbt.py index ee900005fd..0b6ec5f896 100644 --- a/dlt/pipeline/dbt.py +++ b/dlt/pipeline/dbt.py @@ -38,7 +38,7 @@ def get_venv( # keep venv inside pipeline if path is relative if not os.path.isabs(venv_path): pipeline._pipeline_storage.create_folder(venv_path, exists_ok=True) - venv_dir = pipeline._pipeline_storage.make_full_path(venv_path) + venv_dir = pipeline._pipeline_storage.make_full_path_safe(venv_path) else: venv_dir = venv_path # try to restore existing venv diff --git a/dlt/pipeline/mark.py b/dlt/pipeline/mark.py index 3956d9bbe2..5f3122e7a5 100644 --- a/dlt/pipeline/mark.py +++ b/dlt/pipeline/mark.py @@ -2,6 +2,7 @@ from dlt.extract import ( with_table_name, with_hints, + with_file_import, make_hints, materialize_schema_item as materialize_table_schema, ) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 8dfb93b8da..11f8d6223e 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -1,6 +1,5 @@ import contextlib import os -import datetime # noqa: 251 
from contextlib import contextmanager from functools import wraps from typing import ( @@ -38,7 +37,6 @@ DestinationUndefinedEntity, ) from dlt.common.exceptions import MissingDependencyException -from dlt.common.normalizers import explicit_normalizers, import_normalizers from dlt.common.runtime import signals, initialize_runtime from dlt.common.schema.typing import ( TColumnNames, @@ -84,12 +82,12 @@ DestinationClientStagingConfiguration, DestinationClientDwhWithStagingConfiguration, ) +from dlt.common.normalizers.naming import NamingConvention from dlt.common.pipeline import ( ExtractInfo, LoadInfo, NormalizeInfo, PipelineContext, - StepInfo, TStepInfo, SupportsPipeline, TPipelineLocalState, @@ -104,7 +102,7 @@ from dlt.common.warnings import deprecated, Dlt04DeprecationWarning from dlt.common.versioned_state import json_encode_state, json_decode_state -from dlt.extract import DltSource, DltResource +from dlt.extract import DltSource from dlt.extract.exceptions import SourceExhausted from dlt.extract.extract import Extract, data_to_sources from dlt.normalize import Normalize @@ -125,7 +123,6 @@ PipelineStepFailed, SqlClientNotAvailable, FSClientNotAvailable, - PipelineNeverRan, ) from dlt.pipeline.trace import ( PipelineTrace, @@ -360,14 +357,14 @@ def __init__( self._init_working_dir(pipeline_name, pipelines_dir) with self.managed_state() as state: + self.credentials = credentials + self._configure(import_schema_path, export_schema_path, must_attach_to_local_pipeline) # changing the destination could be dangerous if pipeline has pending load packages - self._set_destinations(destination=destination, staging=staging) + self._set_destinations(destination=destination, staging=staging, initializing=True) # set the pipeline properties from state, destination and staging will not be set self._state_to_props(state) # we overwrite the state with the values from init self._set_dataset_name(dataset_name) - self.credentials = credentials - 
self._configure(import_schema_path, export_schema_path, must_attach_to_local_pipeline) def drop(self, pipeline_name: str = None) -> "Pipeline": """Deletes local pipeline state, schemas and any working files. @@ -448,14 +445,13 @@ def extract( refresh=refresh or self.refresh, ) # extract state - state: TPipelineStateDoc = None if self.config.restore_from_destination: # this will update state version hash so it will not be extracted again by with_state_sync - state = self._bump_version_and_extract_state( + self._bump_version_and_extract_state( self._container[StateInjectableContext].state, True, extract_step ) # commit load packages with state - extract_step.commit_packages(state) + extract_step.commit_packages() return self._get_step_info(extract_step) except Exception as exc: # emit step info @@ -867,6 +863,11 @@ def state(self) -> TPipelineState: """Returns a dictionary with the pipeline state""" return self._get_state() + @property + def naming(self) -> NamingConvention: + """Returns naming convention of the default schema""" + return self._get_schema_or_create().naming + @property def last_trace(self) -> PipelineTrace: """Returns or loads last trace generated by pipeline. 
The trace is loaded from standard location.""" @@ -1021,7 +1022,7 @@ def _get_schema_or_create(self, schema_name: str = None) -> Schema: return Schema(self.pipeline_name) def _sql_job_client(self, schema: Schema, credentials: Any = None) -> SqlJobClientBase: - client_config = self._get_destination_client_initial_config(credentials) + client_config = self._get_destination_client_initial_config(credentials=credentials) client = self._get_destination_clients(schema, client_config)[0] if isinstance(client, SqlJobClientBase): return client @@ -1142,10 +1143,6 @@ def _extract_source( source, max_parallel_items, workers, load_package_state_update=load_package_state_update ) - # save import with fully discovered schema - # NOTE: moved to with_schema_sync, remove this if all test pass - # self._schema_storage.save_import_schema_if_not_exists(source.schema) - # update live schema but not update the store yet source.schema = self._schema_storage.set_live_schema(source.schema) @@ -1253,10 +1250,28 @@ def _get_destination_capabilities(self) -> DestinationCapabilitiesContext: "Please provide `destination` argument to `pipeline`, `run` or `load` method" " directly or via .dlt config.toml file or environment variable.", ) - return self.destination.capabilities() + # check if default schema is present + if ( + self.default_schema_name is not None + and self.default_schema_name in self._schema_storage + ): + naming = self.default_schema.naming + else: + naming = None + return self.destination.capabilities(naming=naming) def _get_staging_capabilities(self) -> Optional[DestinationCapabilitiesContext]: - return self.staging.capabilities() if self.staging is not None else None + if self.staging is None: + return None + # check if default schema is present + if ( + self.default_schema_name is not None + and self.default_schema_name in self._schema_storage + ): + naming = self.default_schema.naming + else: + naming = None + return self.staging.capabilities(naming=naming) def 
_validate_pipeline_name(self) -> None: try: @@ -1292,9 +1307,11 @@ def _set_destinations( destination_name: Optional[str] = None, staging: Optional[TDestinationReferenceArg] = None, staging_name: Optional[str] = None, + initializing: bool = False, ) -> None: - # destination_mod = DestinationReference.from_name(destination) - if destination: + destination_changed = destination is not None and destination != self.destination + # set destination if provided but do not swap if factory is the same + if destination_changed: self.destination = Destination.from_reference( destination, destination_name=destination_name ) @@ -1313,7 +1330,8 @@ def _set_destinations( staging = "filesystem" staging_name = "filesystem" - if staging: + staging_changed = staging is not None and staging != self.staging + if staging_changed: staging_module = Destination.from_reference(staging, destination_name=staging_name) if staging_module and not issubclass( staging_module.spec, DestinationClientStagingConfiguration @@ -1321,9 +1339,16 @@ def _set_destinations( raise DestinationNoStagingMode(staging_module.destination_name) self.staging = staging_module - with self._maybe_destination_capabilities(): - # default normalizers must match the destination - self._set_default_normalizers() + if staging_changed or destination_changed: + # make sure that capabilities can be generated + with self._maybe_destination_capabilities(): + # update normalizers in all live schemas, only when destination changed + if destination_changed and not initializing: + for schema in self._schema_storage.live_schemas.values(): + schema.update_normalizers() + # set new context + if not initializing: + self._set_context(is_active=True) @contextmanager def _maybe_destination_capabilities( @@ -1351,9 +1376,6 @@ def _maybe_destination_capabilities( if injected_caps: injected_caps.__exit__(None, None, None) - def _set_default_normalizers(self) -> None: - _, self._default_naming, _ = import_normalizers(explicit_normalizers()) - 
def _set_dataset_name(self, new_dataset_name: str) -> None: if not new_dataset_name and not self.dataset_name: # dataset name is required but not provided - generate the default now @@ -1600,7 +1622,7 @@ def _bump_version_and_extract_state( extract: Extract = None, load_package_state_update: Optional[Dict[str, Any]] = None, schema: Optional[Schema] = None, - ) -> TPipelineStateDoc: + ) -> None: """Merges existing state into `state` and extracts state using `storage` if extract_state is True. Storage will be created on demand. In that case the extracted package will be immediately committed. @@ -1608,13 +1630,24 @@ def _bump_version_and_extract_state( _, hash_, _ = bump_pipeline_state_version_if_modified(self._props_to_state(state)) should_extract = hash_ != state["_local"].get("_last_extracted_hash") if should_extract and extract_state: - data, doc = state_resource(state) - extract_ = extract or Extract( - self._schema_storage, self._normalize_storage_config(), original_data=data + extract_ = extract or Extract(self._schema_storage, self._normalize_storage_config()) + # create or get load package upfront to get load_id to create state doc + schema = schema or self.default_schema + # note that we preferably retrieve existing package for `schema` + # same thing happens in extract_.extract so the load_id is preserved + load_id = extract_.extract_storage.create_load_package( + schema, reuse_exiting_package=True ) + data, doc = state_resource(state, load_id) + # keep the original data to be used in the metrics + if extract_.original_data is None: + extract_.original_data = data + # append pipeline state to package state + load_package_state_update = load_package_state_update or {} + load_package_state_update["pipeline_state"] = doc self._extract_source( extract_, - data_to_sources(data, self, schema or self.default_schema)[0], + data_to_sources(data, self, schema)[0], 1, 1, load_package_state_update=load_package_state_update, @@ -1623,9 +1656,7 @@ def 
_bump_version_and_extract_state( mark_state_extracted(state, hash_) # commit only if we created storage if not extract: - extract_.commit_packages(doc) - return doc - return None + extract_.commit_packages() def _list_schemas_sorted(self) -> List[str]: """Lists schema names sorted to have deterministic state""" diff --git a/dlt/pipeline/state_sync.py b/dlt/pipeline/state_sync.py index 41009f2909..11648328f2 100644 --- a/dlt/pipeline/state_sync.py +++ b/dlt/pipeline/state_sync.py @@ -4,7 +4,8 @@ import dlt from dlt.common.pendulum import pendulum from dlt.common.typing import DictStrAny -from dlt.common.schema.typing import STATE_TABLE_NAME, TTableSchemaColumns +from dlt.common.schema.typing import PIPELINE_STATE_TABLE_NAME +from dlt.common.schema.utils import pipeline_state_table from dlt.common.destination.reference import WithStateSync, Destination, StateInfo from dlt.common.versioned_state import ( generate_state_version_hash, @@ -24,20 +25,6 @@ PIPELINE_STATE_ENGINE_VERSION = 4 LOAD_PACKAGE_STATE_KEY = "pipeline_state" -# state table columns -STATE_TABLE_COLUMNS: TTableSchemaColumns = { - "version": {"name": "version", "data_type": "bigint", "nullable": False}, - "engine_version": {"name": "engine_version", "data_type": "bigint", "nullable": False}, - "pipeline_name": {"name": "pipeline_name", "data_type": "text", "nullable": False}, - "state": {"name": "state", "data_type": "text", "nullable": False}, - "created_at": {"name": "created_at", "data_type": "timestamp", "nullable": False}, - "version_hash": { - "name": "version_hash", - "data_type": "text", - "nullable": True, - }, # set to nullable so we can migrate existing tables -} - def generate_pipeline_state_version_hash(state: TPipelineState) -> str: return generate_state_version_hash(state, exclude_attrs=["_local"]) @@ -98,27 +85,28 @@ def state_doc(state: TPipelineState, load_id: str = None) -> TPipelineStateDoc: state = copy(state) state.pop("_local") state_str = compress_state(state) - doc: 
TPipelineStateDoc = { - "version": state["_state_version"], - "engine_version": state["_state_engine_version"], - "pipeline_name": state["pipeline_name"], - "state": state_str, - "created_at": pendulum.now(), - "version_hash": state["_version_hash"], - } - if load_id: - doc["dlt_load_id"] = load_id - return doc + info = StateInfo( + version=state["_state_version"], + engine_version=state["_state_engine_version"], + pipeline_name=state["pipeline_name"], + state=state_str, + created_at=pendulum.now(), + version_hash=state["_version_hash"], + _dlt_load_id=load_id, + ) + return info.as_doc() -def state_resource(state: TPipelineState) -> Tuple[DltResource, TPipelineStateDoc]: - doc = state_doc(state) +def state_resource(state: TPipelineState, load_id: str) -> Tuple[DltResource, TPipelineStateDoc]: + doc = state_doc(state, load_id) + state_table = pipeline_state_table() return ( dlt.resource( [doc], - name=STATE_TABLE_NAME, - write_disposition="append", - columns=STATE_TABLE_COLUMNS, + name=PIPELINE_STATE_TABLE_NAME, + write_disposition=state_table["write_disposition"], + file_format=state_table["file_format"], + columns=state_table["columns"], ), doc, ) diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index 22cdc9b415..b6702797e9 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -420,6 +420,10 @@ def update_request(self, request: Request) -> None: request.url = self._next_reference + # Clear the query parameters from the previous request otherwise they + # will be appended to the next URL in Session.prepare_request + request.params = None + class HeaderLinkPaginator(BaseNextUrlPaginator): """A paginator that uses the 'Link' header in HTTP responses diff --git a/docs/examples/conftest.py b/docs/examples/conftest.py index 87ccffe53b..be1a03990b 100644 --- a/docs/examples/conftest.py +++ b/docs/examples/conftest.py @@ -1,3 +1,4 @@ +import sys import os 
import pytest from unittest.mock import patch @@ -47,7 +48,11 @@ def _initial_providers(): ): # extras work when container updated glob_ctx.add_extras() - yield + try: + sys.path.insert(0, dname) + yield + finally: + sys.path.pop(0) def pytest_configure(config): diff --git a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py index 380912a9a7..ce4b2a12d0 100644 --- a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py +++ b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py @@ -86,7 +86,7 @@ def bigquery_insert( pipeline_name="csv_to_bigquery_insert", destination=bigquery_insert, dataset_name="mydata", - full_refresh=True, + dev_mode=True, ) load_info = pipeline.run(resource(url=OWID_DISASTERS_URL)) diff --git a/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py index 9d75d90f99..ba815d4fcd 100644 --- a/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py +++ b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py @@ -38,7 +38,9 @@ from dlt.sources.helpers.rest_client import RESTClient, AuthConfigBase # access secrets to get openai key and instantiate embedding function -openai_api_key: str = dlt.secrets.get("destination.lancedb.credentials.embedding_model_provider_api_key") +openai_api_key: str = dlt.secrets.get( + "destination.lancedb.credentials.embedding_model_provider_api_key" +) func = get_registry().get("openai").create(name="text-embedding-3-small", api_key=openai_api_key) diff --git a/docs/examples/custom_naming/.dlt/config.toml b/docs/examples/custom_naming/.dlt/config.toml new file mode 100644 index 0000000000..ba5c8ab73a --- /dev/null +++ b/docs/examples/custom_naming/.dlt/config.toml @@ -0,0 +1,2 @@ +[sources.sql_ci_no_collision.schema] +naming="sql_ci_no_collision" \ No newline at end of 
file diff --git a/docs/examples/custom_naming/__init__.py b/docs/examples/custom_naming/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/custom_naming/custom_naming.py b/docs/examples/custom_naming/custom_naming.py new file mode 100644 index 0000000000..e99e582213 --- /dev/null +++ b/docs/examples/custom_naming/custom_naming.py @@ -0,0 +1,90 @@ +""" +--- +title: Create and use own naming convention +description: We demonstrate how to create naming conventions that allow UNICODE letters and never generate collisions +keywords: [example] +--- + +This example shows how to add and use a custom naming convention. Naming conventions translate identifiers found in source data into identifiers in +destination, where rules for a valid identifier are constrained. + +Custom naming conventions are classes that derive from `NamingConvention` that you can import from `dlt.common.normalizers.naming`. We recommend the following module layout: +1. Each naming convention resides in a separate Python module (file) +2. The class is always named `NamingConvention` + +There are two naming conventions in this example: +1. A variant of `sql_ci` that generates identifier collisions with a low (user defined) probability by appending a deterministic tag to each name. +2. A variant of `sql_cs` that allows for LATIN (ie. umlaut) characters + +With this example you will learn to: +* Create a naming convention module with a recommended layout +* Use a naming convention by explicitly passing it to the `postgres` destination factory +* Use a naming convention by configuring it in config.toml +* Change the declared case sensitivity by overriding the `is_case_sensitive` property +* Provide custom normalization logic by overriding the `normalize_identifier` method + +""" + +import dlt + +if __name__ == "__main__": + # sql_cs_latin2 module + import sql_cs_latin2 # type: ignore[import-not-found] + + # create postgres destination with a custom naming convention.
pass sql_cs_latin2 as module + # NOTE: sql_cs_latin2 is case sensitive and postgres accepts UNICODE letters in identifiers + dest_ = dlt.destinations.postgres( + "postgresql://loader:loader@localhost:5432/dlt_data", naming_convention=sql_cs_latin2 + ) + # run a pipeline + pipeline = dlt.pipeline( + pipeline_name="sql_cs_latin2_pipeline", + destination=dest_, + dataset_name="example_data", + dev_mode=True, + ) + # Extract, normalize, and load the data + load_info = pipeline.run([{"StückId": 1}], table_name="Ausrüstung") + print(load_info) + # make sure nothing failed + load_info.raise_on_failed_jobs() + with pipeline.sql_client() as client: + # NOTE: we quote case sensitive identifiers + with client.execute_query('SELECT "StückId" FROM "Ausrüstung"') as cur: + print(cur.description) + print(cur.fetchone()) + + # sql_ci_no_collision (configured in config.toml) + # NOTE: pipeline with name `sql_ci_no_collision` will create default schema with the same name + # so we are free to use it in config.toml to just affect this pipeline and leave the postgres pipeline as it is + pipeline = dlt.pipeline( + pipeline_name="sql_ci_no_collision", + destination="duckdb", + dataset_name="example_data", + dev_mode=True, + ) + # duckdb is case insensitive so tables and columns below would clash but sql_ci_no_collision prevents that + data_1 = {"ItemID": 1, "itemid": "collides"} + load_info = pipeline.run([data_1], table_name="BigData") + load_info.raise_on_failed_jobs() + + data_2 = {"1Data": 1, "_1data": "collides"} + # use colliding table + load_info = pipeline.run([data_2], table_name="bigdata") + load_info.raise_on_failed_jobs() + + with pipeline.sql_client() as client: + from duckdb import DuckDBPyConnection + + conn: DuckDBPyConnection = client.native_connection + # tags are deterministic so we can just use the naming convention to get table names to select + first_table = pipeline.default_schema.naming.normalize_table_identifier("BigData") + sql = f"DESCRIBE TABLE {first_table}" +
print(sql) + print(conn.sql(sql)) + second_table = pipeline.default_schema.naming.normalize_table_identifier("bigdata") + sql = f"DESCRIBE TABLE {second_table}" + print(sql) + print(conn.sql(sql)) + + # print(pipeline.default_schema.to_pretty_yaml()) diff --git a/docs/examples/custom_naming/sql_ci_no_collision.py b/docs/examples/custom_naming/sql_ci_no_collision.py new file mode 100644 index 0000000000..276107ea2b --- /dev/null +++ b/docs/examples/custom_naming/sql_ci_no_collision.py @@ -0,0 +1,34 @@ +from typing import ClassVar + +from dlt.common.normalizers.naming.sql_cs_v1 import NamingConvention as SqlNamingConvention +from dlt.common.schema.typing import DLT_NAME_PREFIX + + +class NamingConvention(SqlNamingConvention): + """Case insensitive naming convention with all identifiers lowercases but with unique short tag added""" + + # we will reuse the code we use for shortening + # 1 in 100 prob of collision for identifiers identical after normalization + _DEFAULT_COLLISION_PROB: ClassVar[float] = 0.01 + + def normalize_identifier(self, identifier: str) -> str: + # compute unique tag on original (not normalized) identifier + # NOTE: you may wrap method below in lru_cache if you often normalize the same names + tag = self._compute_tag(identifier, self._DEFAULT_COLLISION_PROB) + # lower case + norm_identifier = identifier.lower() + # add tag if (not a dlt identifier) and tag was not added before (simple heuristics) + if "_4" in norm_identifier: + _, existing_tag = norm_identifier.rsplit("_4", 1) + has_tag = len(existing_tag) == len(tag) + else: + has_tag = False + if not norm_identifier.startswith(DLT_NAME_PREFIX) and not has_tag: + norm_identifier = norm_identifier + "_4" + tag + # run identifier through standard sql cleaning and shortening + return super().normalize_identifier(norm_identifier) + + @property + def is_case_sensitive(self) -> bool: + # switch the naming convention to case insensitive + return False diff --git 
a/docs/examples/custom_naming/sql_cs_latin2.py b/docs/examples/custom_naming/sql_cs_latin2.py new file mode 100644 index 0000000000..7cf31cc76a --- /dev/null +++ b/docs/examples/custom_naming/sql_cs_latin2.py @@ -0,0 +1,21 @@ +from typing import ClassVar + +# NOTE: we use regex library that supports unicode +import regex as re + +from dlt.common.normalizers.naming.sql_cs_v1 import NamingConvention as SqlNamingConvention +from dlt.common.typing import REPattern + + +class NamingConvention(SqlNamingConvention): + """Case sensitive naming convention which allows basic unicode characters, including latin 2 characters""" + + RE_NON_ALPHANUMERIC: ClassVar[REPattern] = re.compile(r"[^\p{Latin}\d_]+") # type: ignore + + def normalize_identifier(self, identifier: str) -> str: + # typically you'd change how a single identifier is normalized here; this example just delegates to the parent + return super().normalize_identifier(identifier) + + @property + def is_case_sensitive(self) -> bool: + return True diff --git a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py index 809a6cfbd6..5fbba98a21 100644 --- a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py +++ b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py @@ -25,7 +25,7 @@ import os import dlt -from dlt.destinations.impl.weaviate import weaviate_adapter +from dlt.destinations.adapters import weaviate_adapter from PyPDF2 import PdfReader diff --git a/docs/examples/postgres_to_postgres/postgres_to_postgres.py b/docs/examples/postgres_to_postgres/postgres_to_postgres.py index f5327ee236..c6502f236a 100644 --- a/docs/examples/postgres_to_postgres/postgres_to_postgres.py +++ b/docs/examples/postgres_to_postgres/postgres_to_postgres.py @@ -91,7 +91,7 @@ def pg_resource_chunked( order_date: str, load_type: str = "merge", columns: str = "*", - credentials: ConnectionStringCredentials = dlt.secrets["sources.postgres.credentials"], + credentials: ConnectionStringCredentials = None, ): print( f"dlt.resource write_disposition: `{load_type}` -- ", @@ -162,6
+162,7 @@ def table_desc(table_name, pk, schema_name, order_date, columns="*"): table["order_date"], load_type=load_type, columns=table["columns"], + credentials=dlt.secrets["sources.postgres.credentials"], ) ) @@ -170,7 +171,7 @@ def table_desc(table_name, pk, schema_name, order_date, columns="*"): pipeline_name=pipeline_name, destination="duckdb", dataset_name=target_schema_name, - full_refresh=True, + dev_mode=True, progress="alive_progress", ) else: @@ -178,8 +179,8 @@ def table_desc(table_name, pk, schema_name, order_date, columns="*"): pipeline_name=pipeline_name, destination="postgres", dataset_name=target_schema_name, - full_refresh=False, - ) # full_refresh=False + dev_mode=False, + ) # dev_mode=False # start timer startTime = pendulum.now() diff --git a/docs/technical/README.md b/docs/technical/README.md deleted file mode 100644 index 6e2b5048a8..0000000000 --- a/docs/technical/README.md +++ /dev/null @@ -1,10 +0,0 @@ -## Finished documents - -1. [general_usage.md](general_usage.md) -2. [create_pipeline.md](create_pipeline.md) -3. [secrets_and_config.md](secrets_and_config.md) -4. [working_with_schemas.md](working_with_schemas.md) - -## In progress - -5. [customization_and_hacking.md](customization_and_hacking.md) diff --git a/docs/technical/create_pipeline.md b/docs/technical/create_pipeline.md deleted file mode 100644 index f6603d08b8..0000000000 --- a/docs/technical/create_pipeline.md +++ /dev/null @@ -1,441 +0,0 @@ -# Create Pipeline -marks features that are: - -⛔ not implemented, hard to add - -☮️ not implemented, easy to add - - -## Example from `dlt` module docstring -It is possible to create "intuitive" pipeline just by providing a list of objects to `dlt.run` methods No decorators and secret files, configurations are necessary. 
- -```python -import dlt -from dlt.sources.helpers import requests - -dlt.run( - requests.get("https://api.chess.com/pub/player/magnuscarlsen/games/2022/11").json()["games"], - destination="duckdb", - table_name="magnus_games" -) -``` - -Run your pipeline script -`$ python magnus_games.py` - -See and query your data with autogenerated Streamlit app -`$ dlt pipeline dlt_magnus_games show` - -## Source extractor function the preferred way -General guidelines: -1. the source extractor is a function decorated with `@dlt.source`. that function **yields** or **returns** a list of resources. -2. resources are generator functions that always **yield** data (enforced by exception which I hope is user friendly). Access to external endpoints, databases etc. should happen from that generator function. Generator functions may be decorated with `@dlt.resource` to provide alternative names, write disposition etc. -3. resource generator functions can be OFC parametrized and resources may be created dynamically -4. the resource generator function may yield **anything that is json serializable**. we prefer to yield _dict_ or list of dicts. -> yielding lists is much more efficient in terms of processing! -5. like any other iterator, the @dlt.source and @dlt.resource **can be iterated and thus extracted and loaded only once**, see example below. - -**Remarks:** - -1. the **@dlt.resource** let's you define the table schema hints: `name`, `write_disposition`, `columns` -2. the **@dlt.source** let's you define global schema props: `name` (which is also source name), `schema` which is Schema object if explicit schema is provided `nesting` to set nesting level etc. -3. 
decorators can also be used as functions ie in case of dlt.resource and `lazy_function` (see examples) - -```python -endpoints = ["songs", "playlist", "albums"] -# return list of resourced -return [dlt.resource(lazy_function(endpoint, name=endpoint) for endpoint in endpoints)] - -``` - -### Extracting data -Source function is not meant to extract the data, but in many cases getting some metadata ie. to generate dynamic resources (like in case of google sheets example) is unavoidable. The source function's body is evaluated **outside** the pipeline `run` (if `dlt.source` is a generator, it is immediately consumed). - -Actual extraction of the data should happen inside the `dlt.resource` which is lazily executed inside the `dlt` pipeline. - -> both a `dlt` source and resource are regular Python iterators and can be passed to any python function that accepts them ie to `list`. `dlt` will evaluate such iterators, also parallel and async ones and provide mock state to it. - -## Multiple resources and resource selection when loading -The source extraction function may contain multiple resources. The resources can be defined as multiple resource functions or created dynamically ie. with parametrized generators. -The user of the pipeline can check what resources are available and select the resources to load. - - -**each resource has a a separate resource function** -```python -from dlt.sources.helpers import requests -import dlt - -@dlt.source -def hubspot(...): - - @dlt.resource(write_disposition="replace") - def users(): - # calls to API happens here - ... - yield users - - @dlt.resource(write_disposition="append") - def transactions(): - ... - yield transactions - - # return a list of resources - return users, transactions - -# load all resources -taktile_data(1).run(destination=bigquery) -# load only decisions -taktile_data(1).with_resources("decisions").run(....) 
- -# alternative form: -source = taktile_data(1) -# select only decisions to be loaded -source.resources.select("decisions") -# see what is selected -print(source.selected_resources) -# same as this -print(source.resources.selected) -``` - -Except being accessible via `source.resources` dictionary, **every resource is available as an attribute of the source**. For the example above -```python -print(list(source.decisions)) # will iterate decisions resource -source.logs.selected = False # deselect resource -``` - -## Resources may be created dynamically -Here we implement a single parametrized function that **yields** data and we call it repeatedly. Mind that the function body won't be executed immediately, only later when generator is consumed in extract stage. - -```python - -@dlt.source -def spotify(): - - endpoints = ["songs", "playlists", "albums"] - - def get_resource(endpoint): - # here we yield the whole response - yield requests.get(url + "/" + endpoint).json() - - # here we yield resources because this produces cleaner code - for endpoint in endpoints: - # calling get_resource creates generator, the actual code of the function will be executed in extractor - yield dlt.resource(get_resource(endpoint), name=endpoint) - -``` - -## Unbound (parametrized) resources -Imagine the situation in which you have a resource for which you want (or require) user to pass some options ie. the number of records returned. - -> try it, it is ⚡ powerful - -1. In all examples above you do that via the source and returned resources are not parametrized. -OR -2. You can return a **parametrized (unbound)** resources from the source. - -```python - -@dlt.source -def chess(chess_api_url): - - # let people choose player title, the default is grand master - @dlt.resource - def players(title_filter="GM", max_results=10): - yield - - # ❗ return the players without the calling - return players - -s = chess("url") -# let's parametrize the resource to select masters. 
you simply call `bind` method on the resource to bind it -# if you do not bind it, the default values are used -s.players.bind("M", max_results=1000) -# load the masters -s.run() - -``` - -## A standalone @resource -A general purpose resource (ie. jsonl reader, generic sql query reader etc.) that you want to add to any of your sources or multiple instances of it to your pipelines? -Yeah definitely possible. Just replace `@source` with `@resource` decorator. - -```python -@dlt.resource(name="logs", write_disposition="append") -def taktile_data(initial_log_id, taktile_api_key=dlt.secret.value): - - # yes, this will also work but data will be obtained immediately when taktile_data() is called. - resp = requests.get( - "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, - headers={"Authorization": taktile_api_key}) - resp.raise_for_status() - for item in resp.json()["result"]: - yield item - -# this will load the resource into default schema. see `general_usage.md) -dlt.run(source=taktile_data(1), destination=bigquery) - -``` -How standalone resource works: -1. It can be used like a source that contains only one resource (ie. single endpoint) -2. The main difference is that when extracted it will join the default schema in the pipeline (or explicitly passed schema) -3. It can be called from a `@source` function and then it becomes a resource of that source and joins the source schema - -## `dlt` state availability - -The state is a python dictionary-like object that is available within the `@dlt.source` and `@dlt.resource` decorated functions and may be read and written to. -The data within the state is loaded into destination together with any other extracted data and made automatically available to the source/resource extractor functions when they are run next time. -When using the state: -* Any JSON-serializable values can be written and the read from the state. -* The state available in the `dlt source` is read only and any changes will be discarded. 
Still it may be used to initialize the resources. -* The state available in the `dlt resource` is writable and written values will be available only once - -### State sharing and isolation across sources - -1. Each source and resources **in the same Python module** (no matter if they are standalone, inner or created dynamically) share the same state dictionary and is separated from other sources -2. Source accepts `section` argument which creates a separate state for that resource (and separate configuration as well). All sources with the same `section` share the state. -2. All the standalone resources and generators that do not belong to any source share the same state when being extracted (they are extracted withing ad-hoc created source) - -## Stream resources: dispatching data to several tables from single resources -What about resource like rasa tracker or singer tap that send a stream of events that should be routed to different tables? we have an answer (actually two): -1. in many cases the table name is based on the data item content (ie. you dispatch events of given type to different tables by event type). We can pass a function that takes the data item as input and returns table name. -```python -# send item to a table with name item["type"] -@dlt.resource(table_name=lambda i: i['type']) -def repo_events() -> Iterator[TDataItems]: - yield item -``` - -2. You can mark the yielded data with a table name (`dlt.mark.with_table_name`). This gives you full control on the name of the table - -see [here](docs/examples/sources/rasa/rasa.py) and [here](docs/examples/sources/singer_tap.py). - -## Source / resource config sections and arguments injection -You should read [secrets_and_config](secrets_and_config.md) now to understand how configs and credentials are passed to the decorated functions and how the users of them can configure their projects. 
- -Also look at the following [test](/tests/extract/test_decorators.py) : `test_source_sections` - -## Example sources and resources - -### With inner resource function -Resource functions can be placed inside the source extractor function. That lets them get access to source function input arguments and all the computations within the source function via so called closure. - -```python -from dlt.sources.helpers import requests -import dlt - -# the `dlt.source` tell the library that the decorated function is a source -# it will use function name `taktile_data` to name the source and the generated schema by default -# in general `@source` should **return** a list of resources or list of generators (function that yield data) -# @source may also **yield** resources or generators - if yielding is more convenient -# if @source returns or yields data - this will generate exception with a proper explanation. dlt user can always load the data directly without any decorators like in the previous example! -@dlt.source -def taktile_data(initial_log_id, taktile_api_key=dlt.secret.value): - - # the `dlt.resource` tells the `dlt.source` that the function defines a resource - # will use function name `logs` as resource/table name by default - # the function should **yield** the data items one by one or **yield** a list. 
- # here the decorator is optional: there are no parameters to `dlt.resource` - @dlt.resource - def logs(): - resp = requests.get( - "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, - headers={"Authorization": taktile_api_key}) - resp.raise_for_status() - # option 1: yield the whole list - yield resp.json()["result"] - # or -> this is useful if you deal with a stream of data and for that you need an API that supports that, for example you could yield lists containing paginated results - for item in resp.json()["result"]: - yield item - - # as mentioned we return a resource or a list of resources - return logs - # this will also work - # return logs() -``` - -### With outer generator yielding data, and @resource created dynamically -```python - -def taktile_logs_data(initial_log_id, taktile_api_key=dlt.secret.value) - yield data - - -@dlt.source -def taktile_data(initial_log_id, taktile_api_key): - # pass the arguments and convert to resource - return dlt.resource(taktile_logs_data(initial_log_id, taktile_api_key), name="logs", write_disposition="append") -``` - -### A source with resources defined elsewhere -Example of the above -```python -from taktile.resources import logs - -@dlt.source -def taktile_data(initial_log_id, taktile_api_key=dlt.secret.value): - return logs(initial_log_id, taktile_api_key) -``` - -## Advanced Topics - -### Transformers ⚡ -This happens all the time: -1. We have an endpoint that returns a list of users and then we must get each profile with a separate call. -2. The situation above is getting even more complicated when we need that list in two places in our source ie. we want to get the profiles but also a list of transactions per user. - -Ideally we would obtain the list only once and then call and yield from the profiles and transactions endpoint in parallel so the extraction time is minimized. 
- -Here's example how to do that: [run resources and transformers in parallel threads](/docs/examples/chess/chess.py) and test named `test_evolve_schema` - -More on transformers: -1. you can have unbound (parametrized) transformers as well -2. you can use pipe '|' operator to pipe data from resources to transformers instead of binding them statically with `data_from`. -> see our [singer tap](/docs/examples/singer_tap_jsonl_example.py) example where we pipe a stream of document from `jsonl` into `raw_singer_tap` which is a standalone, unbound ⚡ transformer. -3. If transformer yields just one element you can `return` it instead. This allows you to apply the `retry` and `defer` (parallel execution) decorators directly to it. - -#### Transformer example - -Here we have a list of huge documents and we want to load into several tables. - -```python -@dlt.source -def spotify(): - - # deselect by default, we do not want to load the huge doc - @dlt.resource(selected=False) - def get_huge_doc(): - return requests.get(...) - - # make songs and playlists to be dependent on get_huge_doc - @dlt.transformer(data_from=get_huge_doc) - def songs(huge_doc): - yield huge_doc["songs"] - - @dlt.transformer(data_from=get_huge_doc) - def playlists(huge_doc): - yield huge_doc["playlists"] - - # as you can see the get_huge_doc is not even returned, nevertheless it will be evaluated (only once) - # the huge doc will not be extracted and loaded - return songs, playlists - # we could also use the pipe operator, intead of providing_data from - # return get_huge_doc | songs, get_huge_doc | playlists -``` - -## Data item transformations - -You can attach any number of transformations to your resource that are evaluated on item per item basis. The available transformation types: -* map - transform the data item -* filter - filter the data item -* yield map - a map that returns iterator (so single row may generate many rows) - -You can add and insert transformations on the `DltResource` object (ie. 
decorated function) -* resource.add_map -* resource.add_filter -* resource.add_yield_map - -> Transformations always deal with single items even if you return lists. - -You can add transformations to a resource (also within a source) **after it is created**. This allows to customize existing pipelines. The transformations may -be distributed with the pipeline or written ad hoc in pipeline script. -```python -# anonymize creates nice deterministic hash for any hashable data type (not implemented yet:) -from dlt.helpers import anonymize - -# example transformation provided by the user -def anonymize_user(user_data): - user_data["user_id"] = anonymize(user_data["user_id"]) - user_data["user_email"] = anonymize(user_data["user_email"]) - return user_data - -@dlt.source -def pipedrive(...): - ... - - @dlt.resource(write_disposition="replace") - def users(): - ... - users = requests.get(...) - ... - yield users - - return users, deals, customers -``` - -in pipeline script: -1. we want to remove user with id == "me" -2. we want to anonymize user data -3. we want to pivot `user_props` into KV table - -```python -from pipedrive import pipedrive, anonymize_user - -source = pipedrive() -# access resource in the source by name and add filter and map transformation -source.users.add_filter(lambda user: user["user_id"] != "me").add_map(anonymize_user) -# now we want to yield user props to separate table. we define our own generator function -def pivot_props(user): - # keep user - yield user - # yield user props to user_props table - yield from [ - dlt.mark.with_table_name({"user_id": user["user_id"], "name": k, "value": v}, "user_props") for k, v in user["props"] - ] - -source.user.add_yield_map(pivot_props) -pipeline.run(source) -``` - -We provide a library of various concrete transformations: - -* ☮️ a recursive versions of the map, filter and flat map which can be applied to any nesting level of the data item (the standard transformations work on recursion level 0). 
Possible applications - - ☮️ recursive rename of dict keys - - ☮️ converting all values to strings - - etc. - -## Some CS Theory - -### The power of decorators - -With decorators dlt can inspect and modify the code being decorated. -1. it knows what are the sources and resources without running them -2. it knows input arguments so it knows the config values and secret values (see `secrets_and_config`). with those we can generate deployments automatically -3. it can inject config and secret values automatically -4. it wraps the functions into objects that provide additional functionalities -- sources and resources are iterators so you can write -```python -items = list(source()) - -for item in source()["logs"]: - ... -``` -- you can select which resources to load with `source().select(*names)` -- you can add mappings and filters to resources - -### The power of yielding: The preferred way to write resources - -The Python function that yields is not a function but magical object that `dlt` can control: -1. it is not executed when you call it! the call just creates a generator (see below). in the example above `taktile_data(1)` will not execute the code inside, it will just return an object composed of function code and input parameters. dlt has control over the object and can execute the code later. this is called `lazy execution` -2. i can control when and how much of the code is executed. the function that yields typically looks like that - -```python -def lazy_function(endpoint_name): - # INIT - this will be executed only once when dlt wants! - get_configuration() - from_item = dlt.current.state.get("last_item", 0) - l = get_item_list_from_api(api_key, endpoint_name) - - # ITERATOR - this will be executed many times also when dlt wants more data! - for item in l: - yield requests.get(url, api_key, "%s?id=%s" % (endpoint_name, item["id"])).json() - # CLEANUP - # this will be executed only once after the last item was yielded! 
- dlt.current.state["last_item"] = item["id"] -``` - -3. dlt will execute this generator in extractor. the whole execution is atomic (including writing to state). if anything fails with exception the whole extract function fails. -4. the execution can be parallelized by using a decorator or a simple modifier function ie: -```python -for item in l: - yield deferred(requests.get(url, api_key, "%s?id=%s" % (endpoint_name, item["id"])).json()) -``` \ No newline at end of file diff --git a/docs/technical/general_usage.md b/docs/technical/general_usage.md index 19c93bcf38..336c892c66 100644 --- a/docs/technical/general_usage.md +++ b/docs/technical/general_usage.md @@ -90,7 +90,7 @@ p.extract([label1, label2, label3], name="labels") # will use default schema "s **By default, one dataset can handle multiple schemas**. The pipeline configuration option `use_single_dataset` controls the dataset layout in the destination. By default it is set to True. In that case only one dataset is created at the destination - by default dataset name which is the same as pipeline name. The dataset name can also be explicitly provided into `dlt.pipeline` `dlt.run` and `Pipeline::load` methods. -All the tables from all the schemas are stored in that dataset. The table names are **not prefixed** with schema names!. If there are any name clashes, tables in the destination will be unions of the fields of all the tables with same name in the schemas. +All the tables from all the schemas are stored in that dataset. The table names are **not prefixed** with schema names! If there are any name collisions, tables in the destination will be unions of the fields of all the tables with the same name in the schemas. **Enabling one dataset per schema layout** If you set `use_single_dataset` to False: @@ -181,44 +181,6 @@ The `run`, `extract`, `normalize` and `load` method raise `PipelineStepFailed` w > should we add it?
I have a runner in `dlt` that would be easy to modify -## the `Pipeline` object -There are many ways to create or get current pipeline object. -```python - -# create and get default pipeline -p1 = dlt.pipeline() -# create explicitly configured pipeline -p2 = dlt.pipeline(pipeline_name="pipe", destination=bigquery) -# get recently created pipeline -assert dlt.pipeline() is p2 -# load data with recently created pipeline -assert dlt.run(taktile_data()) is p2 -assert taktile_data().run() is p2 - -``` - -The `Pipeline` object provides following functionalities: -1. `run`, `extract`, `normalize` and `load` methods -2. a `pipeline.schema` dictionary-like object to enumerate and get the schemas in pipeline -3. schema get with `pipeline.schemas[name]` is a live object: any modification to it is automatically applied to the pipeline with the next `run`, `load` etc. see [working_with_schemas.md](working_with_schemas.md) -4. it returns `sql_client` and `native_client` to get direct access to the destination (if destination supports SQL - currently all of them do) -5. it has several methods to inspect the pipeline state and I think those should be exposed via `dlt pipeline` CLI - -for example: -- list the extracted files if any -- list the load packages ready to load -- list the failed jobs in package -- show info on destination: what are the datasets, the current load_id, the current schema etc. - - -## Examples -[we have some here](/examples/) - -## command line interface - - -## logging -I need your input for user friendly logging. What should we log? What is important to see? 
## pipeline runtime setup diff --git a/docs/technical/working_with_schemas.md b/docs/technical/working_with_schemas.md index d94edb8727..532f0e5a1d 100644 --- a/docs/technical/working_with_schemas.md +++ b/docs/technical/working_with_schemas.md @@ -1,134 +1,7 @@ -## General approach to define schemas -marks features that are: - -⛔ not implemented, hard to add - -☮️ not implemented, easy to add - -## Schema components - -### Schema content hash and version -Each schema file contains content based hash `version_hash` that is used to -1. detect manual changes to schema (ie. user edits content) -2. detect if the destination database schema is synchronized with the file schema - -Each time the schema is saved, the version hash is updated. - -Each schema contains also numeric version which increases automatically whenever schema is updated and saved. This version is mostly for informative purposes and there are cases where the increasing order will be lost. - -> Schema in the database is only updated if its hash is not stored in `_dlt_versions` table. In principle many pipelines may send data to a single dataset. If table name clash then a single table with the union of the columns will be created. If columns clash and they have different types etc. then the load will fail. - -### ❗ Normalizer and naming convention - -The parent table is created from all top level fields, if field are dictionaries they will be flattened. **all the key names will be converted with the configured naming convention**. The current naming convention -1. converts to snake_case, small caps. removes all ascii characters except alphanum and underscore -2. add `_` if name starts with number -3. multiples of `_` are converted into single `_` -4. the parent-child relation is expressed as double `_` in names. - -The nested lists will be converted into child tables. - -The data normalizer and the naming convention are part of the schema configuration. 
In principle the source can set own naming convention or json unpacking mechanism. Or user can overwrite those in `config.toml` - -> The table and column names are mapped automatically. **you cannot rename the columns or tables by changing the `name` property - you must rename your source documents** - -> if you provide any schema elements that contain identifiers via decorators or arguments (ie. `table_name` or `columns`) all the names used will be converted via the naming convention when adding to the schema. For example if you execute `dlt.run(... table_name="CamelCase")` the data will be loaded into `camel_case` - -> 💡 use simple, short small caps identifiers for everything! - -☠️ not implemented! - -⛔ The schema holds lineage information (from json paths to tables/columns) and (1) automatically adapts to destination limits ie. postgres 64 chars by recomputing all names (2) let's user to change the naming convention ie. to verbatim naming convention of `duckdb` where everything is allowed as identifier. - -⛔ Any naming convention generates name clashes. `dlt` detects and fixes name clashes using lineage information - - -#### JSON normalizer settings -Yes those are part of the normalizer module and can be plugged in. -1. column propagation from parent to child tables -2. 
nesting level - -```yaml -normalizers: - names: dlt.common.normalizers.names.snake_case - json: - module: dlt.common.normalizers.json.relational - config: - max_nesting: 5 - propagation: - # for all root tables - root: - # propagate root dlt id - _dlt_id: _dlt_root_id - tables: - # for particular tables - blocks: - # propagate timestamp as block_timestamp to child tables - timestamp: block_timestamp - hash: block_hash -``` - -## Data types -"text", "double", "bool", "timestamp", "bigint", "binary", "complex", "decimal", "wei" -⛔ you cannot specify scale and precision for bigint, binary, text and decimal - -☮️ there's no time and date type - -wei is a datatype that tries to best represent native Ethereum 256bit integers and fixed point decimals. it works correcly on postgres and bigquery ## Schema settings The `settings` section of schema let's you define various global rules that impact how tables and columns are inferred from data. -> 💡 it is the best practice to use those instead of providing the exact column schemas via `columns` argument or by pasting them in `yaml`. Any ideas for improvements? tell me. - -### Column hint rules -You can define a global rules that will apply hints to a newly inferred columns. Those rules apply to normalized column names. You can use column names directly or with regular expressions. ❗ when lineages are implemented the regular expressions will apply to lineages not to column names. - -Example from ethereum schema -```yaml -settings: - default_hints: - foreign_key: - - _dlt_parent_id - not_null: - - re:^_dlt_id$ - - _dlt_root_id - - _dlt_parent_id - - _dlt_list_idx - unique: - - _dlt_id - cluster: - - block_hash - partition: - - block_timestamp -``` - -### Preferred data types -You can define rules that will set the data type for newly created columns. Put the rules under `preferred_types` key of `settings`. On the left side there's a rule on a column name, on the right side is the data type. 
❗See the column hint rules for naming convention! - -Example: -```yaml -settings: - preferred_types: - timestamp: timestamp - re:^inserted_at$: timestamp - re:^created_at$: timestamp - re:^updated_at$: timestamp - re:^_dlt_list_idx$: bigint -``` - -### data type autodetectors -You can define a set of functions that will be used to infer the data type of the column from a value. The functions are run from top to bottom on the lists. Look in `detections.py` to see what is available. -```yaml -settings: - detections: - - timestamp - - iso_timestamp - - iso_date -``` - -⛔ we may define `all_text` function that will generate string only schemas by telling `dlt` that all types should be coerced to strings. - ### Table exclude and include filters You can define the include and exclude filters on tables but you are much better off transforming and filtering your source data in python. The current implementation is both weird and quite powerful. In essence you can exclude columns and whole tables with regular expressions to which the inputs are normalized lineages of the values. Example @@ -191,54 +64,3 @@ p.run() ``` > The `normalize` stage creates standalone load packages each containing data and schema with particular version. Those packages are of course not impacted by the "live" schema changes. - -## Attaching schemas to sources -The general approach when creating a new pipeline is to setup a few global schema settings and then let the table and column schemas to be generated from the resource hints and data itself. - -> ⛔ I do not have any cool "schema builder" api yet to see the global settings. - -The `dlt.source` decorator accepts a schema instance that you can create yourself and whatever you want. It also support a few typical use cases: - -### Schema created implicitly by decorator -If no schema instance is passed, the decorator creates a schema with the name set to source name and all the settings to default. 
- -### Automatically load schema file stored with source python module -If no schema instance is passed, and a file with a name `{source name}_schema.yml` exists in the same folder as the module with the decorated function, it will be automatically loaded and used as the schema. - -This should make easier to bundle a fully specified (or non trivially configured) schema with a source. - -### Schema is modified in the source function body -What if you can configure your schema or add some tables only inside your schema function, when ie. you have the source credentials and user settings? You could for example add detailed schemas of all the database tables when someone requests a table data to be loaded. This information is available only at the moment source function is called. - -Similarly to the `state`, source and resource function has current schema available via `dlt.current.source_schema` - -Example: - -```python - -# apply schema to the source -@dlt.source -def createx(nesting_level: int): - - schema = dlt.current.source_schema() - - # get default normalizer config - normalizer_conf = dlt.schema.normalizer_config() - # set hash names convention which produces short names without clashes but very ugly - if short_names_convention: - normalizer_conf["names"] = dlt.common.normalizers.names.hash_names - - # apply normalizer conf - schema = Schema("createx", normalizer_conf) - # set nesting level, yeah it's ugly - schema._normalizers_config["json"].setdefault("config", {})["max_nesting"] = nesting_level - # remove date detector and add type detector that forces all fields to strings - schema._settings["detections"].remove("iso_timestamp") - schema._settings["detections"].insert(0, "all_text") - schema.compile_settings() - - return dlt.resource(...) 
- -``` - -Also look at the following [test](/tests/extract/test_decorators.py) : `test_source_schema_context` diff --git a/docs/tools/prepare_examples_tests.py b/docs/tools/prepare_examples_tests.py index a300b1eb8f..d39d311a50 100644 --- a/docs/tools/prepare_examples_tests.py +++ b/docs/tools/prepare_examples_tests.py @@ -17,6 +17,8 @@ # some stuff to insert for setting up and tearing down fixtures TEST_HEADER = """ +import pytest + from tests.utils import skipifgithubfork """ @@ -52,8 +54,12 @@ os.unlink(test_example_file) continue - with open(example_file, "r", encoding="utf-8") as f: - lines = f.read().split("\n") + try: + with open(example_file, "r", encoding="utf-8") as f: + lines = f.read().split("\n") + except FileNotFoundError: + print(f"Example file {example_file} not found, test prep will be skipped") + continue processed_lines = TEST_HEADER.split("\n") main_clause_found = False @@ -62,7 +68,8 @@ # convert the main clause to a test function if line.startswith(MAIN_CLAUSE): main_clause_found = True - processed_lines.append("@skipifgithubfork") + processed_lines.append("@skipifgithubfork") # skip on forks + processed_lines.append("@pytest.mark.forked") # run each example test in a forked subprocess processed_lines.append(f"def test_{example}():") else: processed_lines.append(line) diff --git a/docs/website/blog/2023-09-05-mongo-etl.md b/docs/website/blog/2023-09-05-mongo-etl.md index cd102c8895..8dfd953be4 100644 --- a/docs/website/blog/2023-09-05-mongo-etl.md +++ b/docs/website/blog/2023-09-05-mongo-etl.md @@ -168,7 +168,7 @@ Here's a code explanation of how it works under the hood: pipeline_name='from_json', destination='duckdb', dataset_name='mydata', - full_refresh=True, + dev_mode=True, ) # dlt works with lists of dicts, so wrap data to the list load_info = pipeline.run([data], table_name="json_data") diff --git a/docs/website/blog/2023-10-23-arrow-loading.md b/docs/website/blog/2023-10-23-arrow-loading.md index 2cdf4d90e7..25962c932e 100644 ---
a/docs/website/blog/2023-10-23-arrow-loading.md +++ b/docs/website/blog/2023-10-23-arrow-loading.md @@ -50,7 +50,7 @@ chat_messages = dlt.resource( In this demo I just extract and normalize data and skip the loading step. ```py -pipeline = dlt.pipeline(destination="duckdb", full_refresh=True) +pipeline = dlt.pipeline(destination="duckdb", dev_mode=True) # extract first pipeline.extract(chat_messages) info = pipeline.normalize() @@ -98,7 +98,7 @@ chat_messages = dlt.resource( write_disposition="append", )("postgresql://loader:loader@localhost:5432/dlt_data") -pipeline = dlt.pipeline(destination="duckdb", full_refresh=True) +pipeline = dlt.pipeline(destination="duckdb", dev_mode=True) # extract first pipeline.extract(chat_messages) info = pipeline.normalize(workers=3, loader_file_format="parquet") diff --git a/docs/website/blog/2023-12-01-dlt-kestra-demo.md b/docs/website/blog/2023-12-01-dlt-kestra-demo.md index 9f1d7acba2..1b1c79562d 100644 --- a/docs/website/blog/2023-12-01-dlt-kestra-demo.md +++ b/docs/website/blog/2023-12-01-dlt-kestra-demo.md @@ -45,7 +45,7 @@ Wanna jump to the [GitHub repo](https://github.com/dlt-hub/dlt-kestra-demo)? ## HOW IT WORKS -To lay it all out clearly: Everything's automated in **`Kestra`**, with hassle-free data loading thanks to **`dlt`**, and the analytical thinking handled by OpenAI. Here's a diagram to help you understand the general outline of the entire process. +To lay it all out clearly: Everything's automated in **`Kestra`**, with hassle-free data loading thanks to **`dlt`**, and the analytical thinking handled by OpenAI. Here's a diagram to help you understand the general outline of the entire process. 
![overview](https://storage.googleapis.com/dlt-blog-images/dlt_kestra_workflow_overview.png) @@ -59,12 +59,12 @@ Once you’ve opened [http://localhost:8080/](http://localhost:8080/) in your br ![Kestra](https://storage.googleapis.com/dlt-blog-images/dlt_kestra_kestra_ui.png) -Now, all you need to do is [create your flows](https://github.com/dlt-hub/dlt-kestra-demo/blob/main/README.md) and execute them. +Now, all you need to do is [create your flows](https://github.com/dlt-hub/dlt-kestra-demo/blob/main/README.md) and execute them. The great thing about **`Kestra`** is its ease of use - it's UI-based, declarative, and language-agnostic. Unless you're using a task like a [Python script](https://kestra.io/plugins/plugin-script-python/tasks/io.kestra.plugin.scripts.python.script), you don't even need to know how to code. -:::tip +:::tip If you're already considering ways to use **`Kestra`** for your projects, consult their [documentation](https://kestra.io/docs) and the [plugin](https://kestra.io/plugins) pages for further insights. ::: @@ -84,7 +84,7 @@ pipeline = dlt.pipeline( pipeline_name="standard_inbox", destination='bigquery', dataset_name="messages_data", - full_refresh=False, + dev_mode=False, ) # Set table name diff --git a/docs/website/docs/dlt-ecosystem/destinations/athena.md b/docs/website/docs/dlt-ecosystem/destinations/athena.md index 93291bfe9a..a723e3554c 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/athena.md +++ b/docs/website/docs/dlt-ecosystem/destinations/athena.md @@ -141,7 +141,7 @@ For every table created as an iceberg table, the Athena destination will create The `merge` write disposition is supported for Athena when using iceberg tables. > Note that: -> 1. there is a risk of tables ending up in inconsistent state in case a pipeline run fails mid flight, because Athena doesn't support transactions, and `dlt` uses multiple DELETE/UPDATE/INSERT statements to implement `merge`, +> 1. 
there is a risk of tables ending up in inconsistent state in case a pipeline run fails mid flight, because Athena doesn't support transactions, and `dlt` uses multiple DELETE/UPDATE/INSERT statements to implement `merge`, > 2. `dlt` creates additional helper tables called `insert_<table name>` and `delete_<table name>
` in the staging schema to work around Athena's lack of temporary tables. ### dbt support @@ -183,7 +183,7 @@ Here is an example of how to use the adapter to partition a table: from datetime import date import dlt -from dlt.destinations.impl.athena.athena_adapter import athena_partition, athena_adapter +from dlt.destinations.adapters import athena_partition, athena_adapter data_items = [ (1, "A", date(2021, 1, 1)), diff --git a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md index 4f99901e37..f97a4a96bb 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md +++ b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md @@ -232,7 +232,7 @@ Here is an example of how to use the `bigquery_adapter` method to apply hints to from datetime import date, timedelta import dlt -from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter +from dlt.destinations.adapters import bigquery_adapter @dlt.resource( diff --git a/docs/website/docs/dlt-ecosystem/destinations/dremio.md b/docs/website/docs/dlt-ecosystem/destinations/dremio.md index 546f470938..c087d5dc0a 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/dremio.md +++ b/docs/website/docs/dlt-ecosystem/destinations/dremio.md @@ -86,7 +86,7 @@ Data loading happens by copying a staged parquet files from an object storage bu Dremio does not support `CREATE SCHEMA` DDL statements. -Therefore, "Metastore" data sources, such as Hive or Glue, require that the dataset schema exists prior to running the dlt pipeline. `full_refresh=True` is unsupported for these data sources. +Therefore, "Metastore" data sources, such as Hive or Glue, require that the dataset schema exists prior to running the dlt pipeline. `dev_mode=True` is unsupported for these data sources. "Object Storage" data sources do not have this limitation. 
diff --git a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md index d6ec36ae49..023f3e35bc 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md @@ -37,7 +37,7 @@ All write dispositions are supported. ### Names normalization `dlt` uses the standard **snake_case** naming convention to keep identical table and column identifiers across all destinations. If you want to use the **duckdb** wide range of characters (i.e., emojis) for table and column names, you can switch to the **duck_case** naming convention, which accepts almost any string as an identifier: -* `\n` `\r` and `" are translated to `_` +* `\n` `\r` and `"` are translated to `_` * multiple `_` are translated to a single `_` Switch the naming convention using `config.toml`: @@ -51,7 +51,7 @@ or via the env variable `SCHEMA__NAMING` or directly in the code: dlt.config["schema.naming"] = "duck_case" ``` :::caution -**duckdb** identifiers are **case insensitive** but display names preserve case. This may create name clashes if, for example, you load JSON with +**duckdb** identifiers are **case insensitive** but display names preserve case. This may create name collisions if, for example, you load JSON with `{"Column": 1, "column": 2}` as it will map data to a single column. ::: diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md new file mode 100644 index 0000000000..dbf90da4b9 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -0,0 +1,211 @@ +--- +title: LanceDB +description: LanceDB is an open source vector database that can be used as a destination in dlt. +keywords: [ lancedb, vector database, destination, dlt ] +--- + +# LanceDB + +[LanceDB](https://lancedb.com/) is an open-source, high-performance vector database. 
It allows you to store data objects and perform similarity searches over them. +This destination helps you load data into LanceDB from [dlt resources](../../general-usage/resource.md). + +## Setup Guide + +### Choosing a Model Provider + +First, you need to decide which embedding model provider to use. You can find all supported providers by visiting the official [LanceDB docs](https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/). + + +### Install dlt with LanceDB + +To use LanceDB as a destination, make sure `dlt` is installed with the `lancedb` extra: + +```sh +pip install "dlt[lancedb]" +``` + +The `lancedb` extra only installs `dlt` and `lancedb`. You will need to install your model provider's SDK separately. + +You can find out which libraries you need by referring to the [LanceDB docs](https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/). + +### Configure the destination + +Configure the destination in the dlt secrets file located at `~/.dlt/secrets.toml` by default. Add the following section: + +```toml +[destination.lancedb] +embedding_model_provider = "cohere" +embedding_model = "embed-english-v3.0" +[destination.lancedb.credentials] +uri = ".lancedb" +api_key = "api_key" # API key to connect to LanceDB Cloud. Leave out if you are using LanceDB OSS. +embedding_model_provider_api_key = "embedding_model_provider_api_key" # Not needed for providers that don't need authentication (ollama, sentence-transformers). +``` + +- The `uri` specifies the location of your LanceDB instance. It defaults to a local, on-disk instance if not provided. +- The `api_key` is your API key for LanceDB Cloud connections. If you're using LanceDB OSS, you don't need to supply this key. +- The `embedding_model_provider` specifies the embedding provider used for generating embeddings. The default is `cohere`. +- The `embedding_model` specifies the model used by the embedding provider for generating embeddings.
+ Check with the embedding provider which options are available. + Reference https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/. +- The `embedding_model_provider_api_key` is the API key for the embedding model provider used to generate embeddings. If you're using a provider that doesn't need authentication, say ollama, you don't need to supply this key. + +:::info Available Model Providers +- "gemini-text" +- "bedrock-text" +- "cohere" +- "gte-text" +- "imagebind" +- "instructor" +- "open-clip" +- "openai" +- "sentence-transformers" +- "huggingface" +- "colbert" +::: + +### Define your data source + +For example: + +```py +import dlt +from dlt.destinations.adapters import lancedb_adapter + + +movies = [ + { + "id": 1, + "title": "Blade Runner", + "year": 1982, + }, + { + "id": 2, + "title": "Ghost in the Shell", + "year": 1995, + }, + { + "id": 3, + "title": "The Matrix", + "year": 1999, + }, +] +``` + +### Create a pipeline: + +```py +pipeline = dlt.pipeline( + pipeline_name="movies", + destination="lancedb", + dataset_name="MoviesDataset", +) +``` + +### Run the pipeline: + +```py +info = pipeline.run( + lancedb_adapter( + movies, + embed="title", + ) +) +``` + +The data is now loaded into LanceDB. + +To use **vector search** after loading, you **must specify which fields LanceDB should generate embeddings for**. Do this by wrapping the data (or dlt resource) with the **`lancedb_adapter`** +function. + +## Using an Adapter to Specify Columns to Vectorise + +Out of the box, LanceDB will act as a normal database. To use LanceDB's embedding facilities, you'll need to specify which fields you'd like to embed in your dlt resource. + +The `lancedb_adapter` is a helper function that configures the resource for the LanceDB destination: + +```py +lancedb_adapter(data, embed) +``` + +It accepts the following arguments: + +- `data`: a dlt resource object, or a Python data structure (e.g. a list of dictionaries). 
+- `embed`: a name of the field or a list of names to generate embeddings for. + +Returns: [dlt resource](../../general-usage/resource.md) object that you can pass to the `pipeline.run()`. + +Example: + +```py +lancedb_adapter( + resource, + embed=["title", "description"], +) +``` + +Bear in mind that you can't use an adapter on a [dlt source](../../general-usage/source.md), only a [dlt resource](../../general-usage/resource.md). + +## Write disposition + +All [write dispositions](../../general-usage/incremental-loading.md#choosing-a-write-disposition) are supported by the LanceDB destination. + +### Replace + +The [replace](../../general-usage/full-loading.md) disposition replaces the data in the destination with the data from the resource. + +```py +info = pipeline.run( + lancedb_adapter( + movies, + embed="title", + ), + write_disposition="replace", +) +``` + +### Merge + +The [merge](../../general-usage/incremental-loading.md) write disposition merges the data from the resource with the data at the destination based on a unique identifier. + +```py +pipeline.run( + lancedb_adapter( + movies, + embed="title", + ), + write_disposition="merge", + primary_key="id", +) +``` + +### Append + +This is the default disposition. It will append the data to the existing data in the destination. + +## Additional Destination Options + +- `dataset_separator`: The character used to separate the dataset name from table names. Defaults to "___". +- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector__". +- `id_field_name`: The name of the special field used for deduplication and merging. Defaults to "id__". +- `max_retries`: The maximum number of retries for embedding operations. Set to 0 to disable retries. Defaults to 3. + + +## dbt support + +The LanceDB destination doesn't support dbt integration. + +## Syncing of `dlt` state + +The LanceDB destination supports syncing of the `dlt` state. 
+ +## Current Limitations + +Adding new fields to an existing LanceDB table requires loading the entire table data into memory as a PyArrow table. +This is because PyArrow tables are immutable, so adding fields requires creating a new table with the updated schema. + +For huge tables, this may impact performance and memory usage since the full table must be loaded into memory to add the new fields. +Keep these considerations in mind when working with large datasets and monitor memory usage if adding fields to sizable existing tables. + + + diff --git a/docs/website/docs/dlt-ecosystem/destinations/postgres.md b/docs/website/docs/dlt-ecosystem/destinations/postgres.md index ae504728c3..49b3c06208 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/postgres.md +++ b/docs/website/docs/dlt-ecosystem/destinations/postgres.md @@ -105,6 +105,28 @@ The Postgres destination creates UNIQUE indexes by default on columns with the ` create_indexes=false ``` +### Setting up `csv` format +You can provide [non-default](../file-formats/csv.md#default-settings) csv settings via configuration file or explicitly. +```toml +[destination.postgres.csv_format] +delimiter="|" +include_header=false +``` +or +```py +from dlt.destinations import postgres +from dlt.common.data_writers.configuration import CsvFormatConfiguration + +csv_format = CsvFormatConfiguration(delimiter="|", include_header=False) + +dest_ = postgres(csv_format=csv_format) +``` +Above we configure the `csv` format without a header and with **|** as a separator. + +:::tip +You'll need these settings when [importing external files](../../general-usage/resource.md#import-external-files) +::: + ### dbt support This destination [integrates with dbt](../transformations/dbt/dbt.md) via dbt-postgres.
diff --git a/docs/website/docs/dlt-ecosystem/destinations/redshift.md b/docs/website/docs/dlt-ecosystem/destinations/redshift.md index 7e0679ec6b..ab193c755d 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/redshift.md +++ b/docs/website/docs/dlt-ecosystem/destinations/redshift.md @@ -97,6 +97,12 @@ Amazon Redshift supports the following column hints: Redshift supports s3 as a file staging destination. dlt will upload files in the parquet format to s3 and ask Redshift to copy their data directly into the db. Please refer to the [S3 documentation](./filesystem.md#aws-s3) to learn how to set up your s3 bucket with the bucket_url and credentials. The `dlt` Redshift loader will use the AWS credentials provided for s3 to access the s3 bucket if not specified otherwise (see config options below). Alternatively to parquet files, you can also specify jsonl as the staging file format. For this, set the `loader_file_format` argument of the `run` command of the pipeline to `jsonl`. +## Identifier names and case sensitivity +* Up to 127 characters +* Case insensitive +* Stores identifiers in lower case +* Has case sensitive mode, if enabled you must [enable case sensitivity in destination factory](../../general-usage/destination.md#control-how-dlt-creates-table-column-and-other-identifiers) + ### Authentication IAM Role If you would like to load from s3 without forwarding the AWS staging credentials but authorize with an IAM role connected to Redshift, follow the [Redshift documentation](https://docs.aws.amazon.com/redshift/latest/mgmt/authorizing-redshift-service.html) to create a role with access to s3 linked to your Redshift cluster and change your destination settings to use the IAM role: diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index 513c951f78..b92d242c8a 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ 
b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -141,12 +141,30 @@ The data is loaded using an internal Snowflake stage. We use the `PUT` command a * [insert-values](../file-formats/insert-format.md) is used by default * [parquet](../file-formats/parquet.md) is supported * [jsonl](../file-formats/jsonl.md) is supported +* [csv](../file-formats/csv.md) is supported When staging is enabled: * [jsonl](../file-formats/jsonl.md) is used by default * [parquet](../file-formats/parquet.md) is supported +* [csv](../file-formats/csv.md) is supported -> ❗ When loading from `parquet`, Snowflake will store `complex` types (JSON) in `VARIANT` as a string. Use the `jsonl` format instead or use `PARSE_JSON` to update the `VARIANT` field after loading. +:::caution +When loading from `parquet`, Snowflake will store `complex` types (JSON) in `VARIANT` as a string. Use the `jsonl` format instead or use `PARSE_JSON` to update the `VARIANT` field after loading. +::: + +### Custom csv formats +By default we support the csv format [produced by our writers](../file-formats/csv.md#default-settings), which is comma delimited, with a header and optionally quoted. + +You can configure your own formatting, e.g. when [importing](../../general-usage/resource.md#import-external-files) external `csv` files. +```toml +[destination.snowflake.csv_format] +delimiter="|" +include_header=false +on_error_continue=true +``` +This will read a `|` delimited file without a header and will continue on errors. + +Note that we ignore missing columns (`ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE`) and we will insert NULL into them. ## Supported column hints Snowflake supports the following [column hints](https://dlthub.com/docs/general-usage/schema#tables-and-columns): @@ -265,6 +283,29 @@ stage_name="DLT_STAGE" keep_staged_files=true ``` +### Setting up `csv` format +You can provide [non-default](../file-formats/csv.md#default-settings) csv settings via configuration file or explicitly.
+```toml +[destination.snowflake.csv_format] +delimiter="|" +include_header=false +on_error_continue=true +``` +or +```py +from dlt.destinations import snowflake +from dlt.common.data_writers.configuration import CsvFormatConfiguration + +csv_format = CsvFormatConfiguration(delimiter="|", include_header=False, on_error_continue=True) + +dest_ = snowflake(csv_format=csv_format) +``` +Above we configure the `csv` format without a header, with **|** as a separator, and we request to ignore lines with errors. + +:::tip +You'll need these settings when [importing external files](../../general-usage/resource.md#import-external-files) +::: + ### dbt support This destination [integrates with dbt](../transformations/dbt/dbt.md) via [dbt-snowflake](https://github.com/dbt-labs/dbt-snowflake). Both password and key pair authentication are supported and shared with dbt runners. diff --git a/docs/website/docs/dlt-ecosystem/destinations/synapse.md b/docs/website/docs/dlt-ecosystem/destinations/synapse.md index 2e936f193e..6cfcb1ef8f 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/synapse.md +++ b/docs/website/docs/dlt-ecosystem/destinations/synapse.md @@ -148,6 +148,8 @@ Data is loaded via `INSERT` statements by default.
The [table index type](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index) of the created tables can be configured at the resource level with the `synapse_adapter`: ```py +from dlt.destinations.adapters import synapse_adapter + info = pipeline.run( synapse_adapter( data=your_resource, diff --git a/docs/website/docs/dlt-ecosystem/destinations/weaviate.md b/docs/website/docs/dlt-ecosystem/destinations/weaviate.md index 11d1276ceb..c6597fadce 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/weaviate.md +++ b/docs/website/docs/dlt-ecosystem/destinations/weaviate.md @@ -252,7 +252,7 @@ it will be normalized to: so your best course of action is to clean up the data yourself before loading and use the default naming convention. Nevertheless, you can configure the alternative in `config.toml`: ```toml [schema] -naming="dlt.destinations.weaviate.impl.ci_naming" +naming="dlt.destinations.impl.weaviate.ci_naming" ``` ## Additional destination options diff --git a/docs/website/docs/dlt-ecosystem/file-formats/csv.md b/docs/website/docs/dlt-ecosystem/file-formats/csv.md index 4a57a0e2d6..02a7e81def 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/csv.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/csv.md @@ -16,7 +16,7 @@ Internally we use two implementations: ## Supported Destinations -Supported by: **Postgres**, **Filesystem** +Supported by: **Postgres**, **Filesystem**, **Snowflake** By setting the `loader_file_format` argument to `csv` in the run command, the pipeline will store your data in the csv format at the destination: @@ -28,11 +28,23 @@ info = pipeline.run(some_source(), loader_file_format="csv") `dlt` attempts to make both writers to generate similarly looking files * separators are commas * quotes are **"** and are escaped as **""** -* `NULL` values are empty strings +* `NULL` values are represented both as empty strings and as empty tokens, as in the example below * UNIX new lines are used * dates are
represented as ISO 8601 * quoting style is "when needed" +Example of NULLs: +```sh +text1,text2,text3 +A,B,C +A,,"" +``` + +In the last row both `text2` and `text3` values are NULL. Python `csv` writer +is not able to write unquoted `None` values so we had to settle for `""` + +Note: all destinations capable of writing csvs must support it. + ### Change settings You can change basic **csv** settings, this may be handy when working with **filesystem** destination. Other destinations are tested with standard settings: @@ -59,6 +71,15 @@ NORMALIZE__DATA_WRITER__INCLUDE_HEADER=False NORMALIZE__DATA_WRITER__QUOTING=quote_all ``` +### Destination settings +A few additional settings are available when copying `csv` to destination tables: +* **on_error_continue** - skip lines with errors (only Snowflake) +* **encoding** - encoding of the `csv` file + +:::tip +You'll need these settings when [importing external files](../../general-usage/resource.md#import-external-files) +::: + ## Limitations **arrow writer** diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/google_sheets.md b/docs/website/docs/dlt-ecosystem/verified-sources/google_sheets.md index 7b957e98ea..9cd6ad8079 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/google_sheets.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/google_sheets.md @@ -355,7 +355,7 @@ To read more about tables, columns, and datatypes, please refer to [our document `dlt` will **not modify** tables after they are created. So if you changed data types with hints, then you need to **delete the dataset** -or set `full_refresh=True`.
::: ## Sources and resources diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/mongodb.md b/docs/website/docs/dlt-ecosystem/verified-sources/mongodb.md index 6fda0f8fe9..f6d57a5ba2 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/mongodb.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/mongodb.md @@ -317,16 +317,28 @@ verified source. 1. To load a selected collection and rename it in the destination: ```py - # Create the MongoDB source and select the "collection_1" collection - source = mongodb().with_resources("collection_1") + # Create the MongoDB source and select the "collection_1" collection + source = mongodb().with_resources("collection_1") - # Apply the hint to rename the table in the destination - source.resources["collection_1"].apply_hints(table_name="loaded_data_1") + # Apply the hint to rename the table in the destination + source.resources["collection_1"].apply_hints(table_name="loaded_data_1") - # Run the pipeline - info = pipeline.run(source, write_disposition="replace") - print(info) + # Run the pipeline + info = pipeline.run(source, write_disposition="replace") + print(info) ``` +1. To load a selected collection, using Apache Arrow for data conversion: + ```py + # Load collection "movies", using Apache Arrow for conversion + movies = mongodb_collection( + collection="movies", + data_item_format="arrow", + ) + + # Run the pipeline + info = pipeline.run(movies) + print(info) + ``` diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md index 11d09c89f7..96cbe3b87d 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md @@ -9,6 +9,61 @@ import Header from './_source-info-header.md'; This is a generic dlt source you can use to extract data from any REST API.
It uses [declarative configuration](#source-configuration) to define the API endpoints, their [relationships](#define-resource-relationships), how to handle [pagination](#pagination), and [authentication](#authentication). +### Quick example + +Here's an example of how to configure the REST API source to load posts and related comments from a hypothetical blog API: + +```py +import dlt +from rest_api import rest_api_source + +source = rest_api_source({ + "client": { + "base_url": "https://api.example.com/", + "auth": { + "token": dlt.secrets["your_api_token"], + }, + "paginator": { + "type": "json_response", + "next_url_path": "paging.next", + }, + }, + "resources": [ + # "posts" will be used as the endpoint path, the resource name, + # and the table name in the destination. The HTTP client will send + # a request to "https://api.example.com/posts". + "posts", + + # The explicit configuration allows you to link resources + # and define parameters. + { + "name": "comments", + "endpoint": { + "path": "posts/{post_id}/comments", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + }, + "sort": "created_at", + }, + }, + }, + ], +}) + +pipeline = dlt.pipeline( + pipeline_name="rest_api_example", + destination="duckdb", + dataset_name="rest_api_data", +) + +load_info = pipeline.run(source) +``` + +Running this pipeline will create two tables in the DuckDB: `posts` and `comments` with the data from the respective API endpoints. The `comments` resource will fetch comments for each post by using the `id` field from the `posts` resource. 
+ ## Setup guide ### Initialize the verified source diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md index fde7a64144..36a8569a4a 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md @@ -271,7 +271,7 @@ pipeline = dlt.pipeline( pipeline_name="unsw_download", destination=filesystem(os.path.abspath("../_storage/unsw")), progress="log", - full_refresh=True, + dev_mode=True, ) info = pipeline.run( diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md b/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md index 5844844cca..8c39a5090e 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md @@ -175,24 +175,7 @@ def incremental_stripe_source( After each run, 'initial_start_date' updates to the last loaded date. Subsequent runs then retrieve only new data using append mode, streamlining the process and preventing redundant data downloads. -For more information, read the [General Usage: Incremental loading](../../general-usage/incremental-loading). - -### Resource `metrics_resource` - -This function loads a dictionary with calculated metrics, including MRR and Churn rate, along with the current timestamp. - -```py -@dlt.resource(name="Metrics", write_disposition="append", primary_key="created") -def metrics_resource() -> Iterable[TDataItem]: - ... -``` - -Abrevations MRR and Churn rate are as follows: -- Monthly Recurring Revenue (MRR): - - Measures the predictable monthly revenue from all active subscriptions. It's the sum of the monthly-normalized subscription amounts. -- Churn rate: - - Indicates the rate subscribers leave a service over a specific period. 
Calculated by dividing the number of recent cancellations by the total subscribers from 30 days ago, adjusted for new subscribers. - +For more information, read the [Incremental loading](../../general-usage/incremental-loading). ## Customization ### Create your own pipeline @@ -236,7 +219,7 @@ verified source. ``` > For subsequent runs, the dlt module sets the previous "end_date" as "initial_start_date", ensuring incremental data retrieval. -1. To load data created after December 31, 2022, adjust the data range for stripe_source to prevent redundant loading. For incremental_stripe_source, the initial_start_date will auto-update to the last loaded date from the previous run. +1. To load data created after December 31, 2022, adjust the data range for stripe_source to prevent redundant loading. For `incremental_stripe_source`, the `initial_start_date` will auto-update to the last loaded date from the previous run. ```py source_single = stripe_source( @@ -249,21 +232,6 @@ verified source. load_info = pipeline.run(data=[source_single, source_incremental]) print(load_info) ``` - > To load data, maintain the pipeline name and destination dataset name. The pipeline name is vital for accessing the last run's [state](https://dlthub.com/docs/general-usage/state), which determines the incremental data load's end date. Altering these names can trigger a [“full_refresh”](https://dlthub.com/docs/general-usage/pipeline#do-experiments-with-full-refresh), disrupting the metadata (state) tracking for [incremental data loading](https://dlthub.com/docs/general-usage/incremental-loading). - -1. To load important metrics and store them in database: - - ```py - # Event is an endpoint with uneditable data, so we can use 'incremental_stripe_source'. - source_event = incremental_stripe_source(endpoints=("Event",)) - # Subscription is an endpoint with editable data, use stripe_source. 
- source_subs = stripe_source(endpoints=("Subscription",)) - load_info = pipeline.run(data=[source_subs, source_event]) - print(load_info) - resource = metrics_resource() - print(list(resource)) - load_info = pipeline.run(resource) - print(load_info) - ``` + > To load data, maintain the pipeline name and destination dataset name. The pipeline name is vital for accessing the last run's [state](../../general-usage/state), which determines the incremental data load's end date. Altering these names can trigger a [“full_refresh”](../../general-usage/pipeline#do-experiments-with-full-refresh), disrupting the metadata (state) tracking for [incremental data loading](../../general-usage/incremental-loading). diff --git a/docs/website/docs/general-usage/destination.md b/docs/website/docs/general-usage/destination.md index 760daa2fee..b30403d349 100644 --- a/docs/website/docs/general-usage/destination.md +++ b/docs/website/docs/general-usage/destination.md @@ -18,26 +18,27 @@ We recommend that you declare the destination type when creating a pipeline inst Above we want to use **filesystem** built-in destination. You can use shorthand types only for built-ins. -* Use full **destination class type** +* Use full **destination factory type** -Above we use built in **filesystem** destination by providing a class type `filesystem` from module `dlt.destinations`. You can pass [destinations from external modules](#declare-external-destination) as well. +Above we use built in **filesystem** destination by providing a factory type `filesystem` from module `dlt.destinations`. You can pass [destinations from external modules](#declare-external-destination) as well. -* Import **destination class** +* Import **destination factory** -Above we import destination class for **filesystem** and pass it to the pipeline. +Above we import destination factory for **filesystem** and pass it to the pipeline. 
-All examples above will create the same destination class with default parameters and pull required config and secret values from [configuration](credentials/configuration.md) - they are equivalent. +All examples above will create the same destination factory with default parameters and pull required config and secret values from [configuration](credentials/configuration.md) - they are equivalent. ### Pass explicit parameters and a name to a destination -You can instantiate **destination class** yourself to configure it explicitly. When doing this you work with destinations the same way you work with [sources](source.md) +You can instantiate **destination factory** yourself to configure it explicitly. When doing this you work with destinations the same way you work with [sources](source.md) -Above we import and instantiate the `filesystem` destination class. We pass explicit url of the bucket and name the destination to `production_az_bucket`. +Above we import and instantiate the `filesystem` destination factory. We pass explicit url of the bucket and name the destination to `production_az_bucket`. + +If destination is not named, its shorthand type (the Python factory name) serves as a destination name. Name your destination explicitly if you need several separate configurations of destinations of the same type (i.e. you wish to maintain credentials for development, staging and production storage buckets in the same config file). Destination name is also stored in the [load info](../running-in-production/running.md#inspect-and-save-the-load-info-and-trace) and pipeline traces so use them also when you need more descriptive names (other than, for example, `filesystem`). -If destination is not named, its shorthand type (the Python class name) serves as a destination name. Name your destination explicitly if you need several separate configurations of destinations of the same type (i.e. 
you wish to maintain credentials for development, staging and production storage buckets in the same config file). Destination name is also stored in the [load info](../running-in-production/running.md#inspect-and-save-the-load-info-and-trace) and pipeline traces so use them also when you need more descriptive names (other than, for example, `filesystem`). ## Configure a destination We recommend to pass the credentials and other required parameters to configuration via TOML files, environment variables or other [config providers](credentials/config_providers.md). This allows you, for example, to easily switch to production destinations after deployment. @@ -59,7 +60,7 @@ For named destinations you use their names in the config section Note that when you use [`dlt init` command](../walkthroughs/add-a-verified-source.md) to create or add a data source, `dlt` creates a sample configuration for selected destination. ### Pass explicit credentials -You can pass credentials explicitly when creating destination class instance. This replaces the `credentials` argument in `dlt.pipeline` and `pipeline.load` methods - which is now deprecated. You can pass the required credentials object, its dictionary representation or the supported native form like below: +You can pass credentials explicitly when creating destination factory instance. This replaces the `credentials` argument in `dlt.pipeline` and `pipeline.load` methods - which is now deprecated. You can pass the required credentials object, its dictionary representation or the supported native form like below: @@ -74,6 +75,23 @@ You can create and pass partial credentials and `dlt` will fill the missing data Please read how to use [various built in credentials types](credentials/config_specs.md). ::: +### Inspect destination capabilities +[Destination capabilities](../walkthroughs/create-new-destination.md#3-set-the-destination-capabilities) tell `dlt` what given destination can and cannot do. 
For example it tells which file formats it can load, what is maximum query or identifier length. Inspect destination capabilities as follows: +```py +import dlt +pipeline = dlt.pipeline("snowflake_test", destination="snowflake") +print(dict(pipeline.destination.capabilities())) +``` + +### Pass additional parameters and change destination capabilities +Destination factory accepts additional parameters that will be used to pre-configure it and change destination capabilities. +```py +import dlt +duck_ = dlt.destinations.duckdb(naming_convention="duck_case", recommended_file_size=120000) +print(dict(duck_.capabilities())) +``` +Example above is overriding `naming_convention` and `recommended_file_size` in the destination capabilities. + ### Configure multiple destinations in a pipeline To configure multiple destinations within a pipeline, you need to provide the credentials for each destination in the "secrets.toml" file. This example demonstrates how to configure a BigQuery destination named `destination_one`: @@ -86,7 +104,7 @@ private_key = "please set me up!" client_email = "please set me up!" ``` -You can then use this destination in your pipeline as follows: +You can then use this destination in your pipeline as follows: ```py import dlt from dlt.common.destination import Destination @@ -117,6 +135,56 @@ Obviously, dlt will access the destination when you instantiate [sql_client](../ ::: +## Control how `dlt` creates table, column and other identifiers +`dlt` maps identifiers found in the source data into destination identifiers (ie. table and columns names) using [naming conventions](naming-convention.md) which ensure that +character set, identifier length and other properties fit into what given destination can handle. For example our [default naming convention (**snake case**)](naming-convention.md#default-naming-convention-snake_case) converts all names in the source (ie. JSON document fields) into snake case, case insensitive identifiers. 
+ +Each destination declares its preferred naming convention, support for case sensitive identifiers and case folding function that case insensitive identifiers follow. For example: +1. Redshift - by default does not support case sensitive identifiers and converts all of them to lower case. +2. Snowflake - supports case sensitive identifiers and considers upper cased identifiers as case insensitive (which is the default case folding) +3. DuckDB - does not support case sensitive identifiers but does not case fold them so it preserves the original casing in the information schema. +4. Athena - does not support case sensitive identifiers and converts all of them to lower case. +5. BigQuery - all identifiers are case sensitive, there's no case insensitive mode available via case folding (but it can be enabled at the dataset level). + +You can change the naming convention used in [many different ways](naming-convention.md#configure-naming-convention), below we set the preferred naming convention on the Snowflake destination to `sql_cs_v1` to switch Snowflake to case sensitive mode: +```py +import dlt +snow_ = dlt.destinations.snowflake(naming_convention="sql_cs_v1") +``` +Setting naming convention will impact all new schemas being created (i.e. on the first pipeline run) and will re-normalize all existing identifiers. + +:::caution +`dlt` prevents re-normalization of identifiers in tables that were already created at the destination. Use [refresh](pipeline.md#refresh-pipeline-data-and-state) mode to drop the data. You can also disable this behavior via [configuration](naming-convention.md#avoid-identifier-collisions) +::: + +:::note +Destinations that support case sensitive identifiers but use case folding convention to enable case insensitive identifiers are configured in case insensitive mode by default. Examples: Postgres, Snowflake, Oracle. +::: + +:::caution +If you use case sensitive naming convention with case insensitive destination, `dlt` will: +1.
Fail the load if it detects identifier collision due to case folding +2. Warn if any case folding is applied by the destination. +::: + +### Enable case sensitive identifiers support +Selected destinations may be configured so they start accepting case sensitive identifiers. For example, it is possible to set case sensitive collation on **mssql** database and then tell `dlt` about it. +```py +from dlt.destinations import mssql +dest_ = mssql(has_case_sensitive_identifiers=True, naming_convention="sql_cs_v1") +``` +Above we can safely use case sensitive naming convention without worrying about name collisions. + +You can configure the case sensitivity, **but configuring destination capabilities is not currently supported**. +```toml +[destination.mssql] +has_case_sensitive_identifiers=true +``` + +:::note +In most cases setting the flag above just indicates to `dlt` that you switched the case sensitive option on a destination. `dlt` will not do that for you. Refer to destination documentation for details. +::: + ## Create new destination You have two ways to implement a new destination: 1. You can use `@dlt.destination` decorator and [implement a sink function](../dlt-ecosystem/destinations/destination.md). This is perfect way to implement reverse ETL destinations that push data back to REST APIs. diff --git a/docs/website/docs/general-usage/naming-convention.md b/docs/website/docs/general-usage/naming-convention.md new file mode 100644 index 0000000000..72db7bf5f3 --- /dev/null +++ b/docs/website/docs/general-usage/naming-convention.md @@ -0,0 +1,128 @@ +--- +title: Naming Convention +description: Control how dlt creates table, column and other identifiers +keywords: [identifiers, snake case, case sensitive, case insensitive, naming] +--- + +# Naming Convention +`dlt` creates table and column identifiers from the data. The data source, ie. a stream of JSON documents, may have identifiers (i.e. 
key names in a dictionary) with any Unicode characters, of any length and naming style. On the other hand, destinations require that you follow strict rules when you name tables, columns or collections. +A good example is [Redshift](../dlt-ecosystem/destinations/redshift.md#naming-convention) that accepts case-insensitive alphanumeric identifiers with maximum 127 characters. + +`dlt` groups tables from a single [source](source.md) in a [schema](schema.md). + +Each schema defines **naming convention** that tells `dlt` how to translate identifiers to the +namespace that the destination understands. Naming conventions are in essence functions that map strings from the source identifier format into destination identifier format. For example our **snake_case** (default) naming convention will translate `DealFlow` into `deal_flow` identifier. + +You can pick which naming convention to use. `dlt` provides a few to [choose from](#available-naming-conventions) or you can [easily add your own](#write-your-own-naming-convention). + +:::tip +* Standard behavior of `dlt` is to **use the same naming convention for all destinations** so users always see the same table and column names in their databases. +* Use simple, short, lower case identifiers for everything so no normalization is needed +::: + +### Use default naming convention (snake_case) +Case insensitive naming convention, converting source identifiers into lower case snake case with reduced alphabet. + +- Spaces around identifier are trimmed +- Keeps ascii alphanumerics and underscores, replaces all other characters with underscores (with the exceptions below) +- Replaces `+` and `*` with `x`, `-` with `_`, `@` with `a` and `|` with `l` +- Prepends `_` if name starts with number. +- Multiples of `_` are converted into single `_`. +- Replaces all trailing `_` with `x` + +Uses `__` as parent-child separator for tables and flattened column names. 
+ +:::tip +If you do not like **snake_case** your next safe option is **sql_ci** which generates SQL-safe, lower-case, case-insensitive identifiers without any +other transformations. To permanently change the default naming convention on a given machine: +1. set an environment variable `SCHEMA__NAMING` to `sql_ci_v1` OR +2. add the following line to your global `config.toml` (the one in your home dir ie. `~/.dlt/config.toml`) +```toml +[schema] +naming="sql_ci_v1" +``` +::: + +## Source identifiers vs destination identifiers +### Pick the right identifier form when defining resources +`dlt` keeps source (not normalized) identifiers during data [extraction](../reference/explainers/how-dlt-works.md#extract) and translates them during [normalization](../reference/explainers/how-dlt-works.md#normalize). For you it means: +1. If you write a [transformer](resource.md#process-resources-with-dlttransformer) or a [mapping/filtering function](resource.md#filter-transform-and-pivot-data), you will see the original data, without any normalization. Use the source key names to access the dicts! +2. If you define a `primary_key` or `cursor` that participate in [cursor field incremental loading](incremental-loading.md#incremental-loading-with-a-cursor-field) use the source identifiers (`dlt` uses them to inspect source data, `Incremental` class is a filtering function). +3. When defining any other hints ie. `columns` or `merge_key` you can pick source or destination identifiers. `dlt` normalizes all hints together with your data. +4. `Schema` object (ie. obtained from the pipeline or from `dlt` source via `discover_schema`) **always contains destination (normalized) identifiers**. + +In the snippet below, we define a resource with various "illegal" unicode characters in table name and other hint and demonstrate how they get normalized in the schema object. 
+```py +``` + +### Understand the identifier normalization +Identifiers are translated from source to destination form in **normalize** step. Here's how `dlt` picks the right naming convention: + +* Each destination may define a preferred naming convention (ie. Weaviate), otherwise **snake case** will be used +* This naming convention is used when new schemas are created. This happens when pipeline is run for the first time. +* Schemas preserve naming convention when saved. Your running pipelines will maintain existing naming conventions if not requested otherwise +* `dlt` applies final naming convention in `normalize` step. Naming convention comes from (1) explicit configuration or (2) destination capabilities. +* Naming convention will be used to put the destination in case sensitive/insensitive mode and apply the right case folding function. + +:::caution +If you change naming convention and `dlt` detects that it changes the destination identifiers for tables/collections/files that already exist and store data, +the normalize process will fail. +::: + +### Case sensitive and insensitive destinations +Naming conventions come in two types. +* **case sensitive** +* **case insensitive** + +Case sensitive naming convention will put a destination in [case sensitive mode](destination.md#control-how-dlt-creates-table-column-and-other-identifiers). Identifiers that +differ only in casing will not [collide](#avoid-identifier-collisions). Note that many destinations are exclusively case insensitive, of which some preserve casing of identifiers (ie. **duckdb**) and some will case-fold identifiers when creating tables (ie. **Redshift**, **Athena** do lower case on the names). + +## Identifier shortening +Identifier shortening happens during normalization. `dlt` takes the maximum length of the identifier from the destination capabilities and will trim the identifiers that are +too long. 
The default shortening behavior generates short deterministic hashes of the source identifiers and places them in the middle of the destination identifier. This +(with a high probability) avoids shortened identifier collisions. + + +## Pick your own naming convention + +### Configure naming convention +tbd. + + +### Available naming conventions + +* snake_case +* duck_case - case sensitive, allows all unicode characters like emoji 💥 +* direct - case sensitive, allows all unicode characters, does not contract underscores +* `sql_cs_v1` - case sensitive, generates sql-safe identifiers +* `sql_ci_v1` - case insensitive, generates sql-safe lower case identifiers + +### Set and adjust naming convention explicitly +tbd. + +## Avoid identifier collisions +`dlt` detects various types of collisions and ignores the others. +1. `dlt` detects collisions if case sensitive naming convention is used on case insensitive destination +2. `dlt` detects collisions if change of naming convention changes the identifiers of tables already created in the destination +3. `dlt` detects collisions when naming convention is applied to column names of arrow tables + +`dlt` will not detect collision when normalizing source data. If you have a dictionary, keys will be merged if they collide after being normalized. +You can use a naming convention that does not generate collisions, see examples below. + + +## Write your own naming convention +Custom naming conventions are classes that derive from `NamingConvention` that you can import from `dlt.common.normalizers.naming`. We recommend the following module layout: +1. Each naming convention resides in a separate Python module (file) +2. The class is always named `NamingConvention` + +In that case you can use a fully qualified module name in [schema configuration](#configure-naming-convention) or pass module [explicitly](#set-and-adjust-naming-convention-explicitly). 
+ +We include [two examples](../examples/custom_naming) of naming conventions that you may find useful: + +1. A variant of `sql_ci` that generates identifier collisions with a low (user defined) probability by appending a deterministic tag to each name. +2. A variant of `sql_cs` that allows for LATIN (ie. umlaut) characters + +:::note +Note that a fully qualified name of your custom naming convention will be stored in the `Schema` and `dlt` will attempt to import it when schema is loaded from storage. +You should distribute your custom naming conventions with your pipeline code via an installable package with a defined namespace. +::: diff --git a/docs/website/docs/general-usage/resource.md b/docs/website/docs/general-usage/resource.md index ac7f7e6b38..14f8d73b58 100644 --- a/docs/website/docs/general-usage/resource.md +++ b/docs/website/docs/general-usage/resource.md @@ -488,6 +488,59 @@ be adjusted after the `batch` is processed in the extract pipeline but before an You can emit columns as Pydantic model and use dynamic hints (ie. lambda for table name) as well. You should avoid redefining `Incremental` this way. ::: +### Import external files +You can import external files ie. `csv`, `parquet` and `jsonl` by yielding items marked with `with_file_import`, optionally passing table schema corresponding +to the imported file. `dlt` will not read, parse or normalize any names (ie. `csv` or `arrow` headers) and will attempt to copy the file into the destination as is. 
+```py +import os +import dlt + +from filesystem import filesystem + +columns: List[TColumnSchema] = [ + {"name": "id", "data_type": "bigint"}, + {"name": "name", "data_type": "text"}, + {"name": "description", "data_type": "text"}, + {"name": "ordered_at", "data_type": "date"}, + {"name": "price", "data_type": "decimal"}, +] + +import_folder = "/tmp/import" + +@dlt.transformer(columns=columns) +def orders(items: Iterator[FileItemDict]): + for item in items: + # copy file locally + dest_file = os.path.join(import_folder, item["file_name"]) + # download file + item.fsspec.download(item["file_url"], dest_file) + # tell dlt to import the dest_file as `csv` + yield dlt.mark.with_file_import(dest_file, "csv") + + +# use filesystem verified source to glob a bucket +downloader = filesystem( + bucket_url="s3://my_bucket/csv", + file_glob="today/*.csv.gz") | orders + +info = pipeline.run(orders, destination="snowflake") +``` +In the example above, we glob all zipped csv files present on **my_bucket/csv/today** (using `filesystem` verified source) and send file descriptors to `orders` transformer. Transformer downloads and imports the files into extract package. At the end, `dlt` sends them to snowflake (the table will be created because we use `columns` hints to define the schema). + +If imported `csv` files are not in `dlt` [default format](../dlt-ecosystem/file-formats/csv.md#default-settings), you may need to pass additional configuration. +```toml +[destination.snowflake.csv_format] +delimiter="|" +include_header=false +on_error_continue=true +``` + +You can sniff the schema from the data ie. using `duckdb` to infer the table schema from `csv` file. `dlt.mark.with_file_import` accepts additional arguments that you can use to pass hints at run time. + +:::note +* If you do not define any columns, the table will not be created in the destination. `dlt` will still attempt to load data into it, so if you create a fitting table upfront, the load process will succeed. 
+* Files are imported using hard links if possible to avoid copying and duplicating storage space needed. +::: ### Duplicate and rename resources There are cases when you your resources are generic (ie. bucket filesystem) and you want to load several instances of it (ie. files from different folders) to separate tables. In example below we use `filesystem` source to load csvs from two different folders into separate tables: @@ -538,12 +591,30 @@ pipeline.run(generate_rows(10)) # load a list of resources pipeline.run([generate_rows(10), generate_rows(20)]) ``` + +### Pick loader file format for a particular resource +You can request a particular loader file format to be used for a resource. +```py +@dlt.resource(file_format="parquet") +def generate_rows(nr): + for i in range(nr): + yield {'id':i, 'example_string':'abc'} +``` +Resource above will be saved and loaded from a `parquet` file (if destination supports it). + +:::note +A special `file_format`: **preferred** will load resource using a format that is preferred by a destination. This setting supersedes the `loader_file_format` passed to `run` method. +::: + ### Do a full refresh -To do a full refresh of an `append` or `merge` resources you temporarily change the write -disposition to replace. You can use `apply_hints` method of a resource or just provide alternative -write disposition when loading: +To do a full refresh of an `append` or `merge` resources you set the `refresh` argument on `run` method to `drop_data`. This will truncate the tables without dropping them. 
+ +```py +p.run(merge_source(), refresh="drop_data") +``` +You can also [fully drop the tables](pipeline.md#refresh-pipeline-data-and-state) in the `merge_source`: ```py -p.run(merge_source(), write_disposition="replace") +p.run(merge_source(), refresh="drop_sources") ``` diff --git a/docs/website/docs/general-usage/schema.md b/docs/website/docs/general-usage/schema.md index 989b023b01..0e3e3bba1f 100644 --- a/docs/website/docs/general-usage/schema.md +++ b/docs/website/docs/general-usage/schema.md @@ -42,8 +42,9 @@ characters, any lengths and naming styles. On the other hand the destinations ac namespaces for their identifiers. Like Redshift that accepts case-insensitive alphanumeric identifiers with maximum 127 characters. -Each schema contains `naming convention` that tells `dlt` how to translate identifiers to the -namespace that the destination understands. +Each schema contains [naming convention](naming-convention.md) that tells `dlt` how to translate identifiers to the +namespace that the destination understands. This convention can be configured, changed in code or enforced via +destination. The default naming convention: @@ -214,7 +215,7 @@ The precision for **bigint** is mapped to available integer types ie. TINYINT, I ## Schema settings The `settings` section of schema file lets you define various global rules that impact how tables -and columns are inferred from data. +and columns are inferred from data. For example you can assign **primary_key** hint to all columns with name `id` or force **timestamp** data type on all columns containing `timestamp` with an use of regex pattern. > 💡 It is the best practice to use those instead of providing the exact column schemas via `columns` > argument or by pasting them in `yaml`. @@ -222,8 +223,9 @@ and columns are inferred from data. ### Data type autodetectors You can define a set of functions that will be used to infer the data type of the column from a -value. 
The functions are run from top to bottom on the lists. Look in [`detections.py`](https://github.com/dlt-hub/dlt/blob/devel/dlt/common/schema/detections.py) to see what is -available. +value. The functions are run from top to bottom on the lists. Look in `detections.py` to see what is +available. The **iso_timestamp** detector, which looks for ISO 8601 strings and converts them to **timestamp**, +is enabled by default. ```yaml settings: @@ -236,12 +238,24 @@ settings: - wei_to_double ``` +Alternatively you can add and remove detections from code: +```py + source = data_source() + # remove iso time detector + source.schema.remove_type_detection("iso_timestamp") + # convert UNIX timestamp (float, within a year from NOW) into timestamp + source.schema.add_type_detection("timestamp") +``` +Above we modify a schema that comes with a source to detect UNIX timestamps with **timestamp** detector. + ### Column hint rules You can define a global rules that will apply hints of a newly inferred columns. Those rules apply -to normalized column names. You can use column names directly or with regular expressions. +to normalized column names. You can use column names directly or with regular expressions. `dlt` is matching +the column names **after they got normalized with naming convention**. -Example from ethereum schema: +By default, schema adopts hints rules from json(relational) normalizer to support correct hinting +of columns added by normalizer: ```yaml settings: @@ -249,36 +263,59 @@ settings: foreign_key: - _dlt_parent_id not_null: - - re:^_dlt_id$ + - _dlt_id - _dlt_root_id - _dlt_parent_id - _dlt_list_idx + - _dlt_load_id unique: - _dlt_id - cluster: - - block_hash + root_key: + - _dlt_root_id +``` +Above we require exact column name match for a hint to apply. 
You can also use regular expression (which we call `SimpleRegex`) as follows: +```yaml +settings: partition: - - block_timestamp + - re:_timestamp$ +``` +Above we add `partition` hint to all columns ending with `_timestamp`. You can do same thing in the code +```py + source = data_source() + # this will update existing hints with the hints passed + source.schema.merge_hints({"partition": ["re:_timestamp$"]}) ``` ### Preferred data types You can define rules that will set the data type for newly created columns. Put the rules under `preferred_types` key of `settings`. On the left side there's a rule on a column name, on the right -side is the data type. - -> ❗See the column hint rules for naming convention! +side is the data type. You can use column names directly or with regular expressions. +`dlt` is matching the column names **after they got normalized with naming convention**. Example: ```yaml settings: preferred_types: - timestamp: timestamp - re:^inserted_at$: timestamp - re:^created_at$: timestamp - re:^updated_at$: timestamp - re:^_dlt_list_idx$: bigint + re:timestamp: timestamp + inserted_at: timestamp + created_at: timestamp + updated_at: timestamp +``` + +Above we prefer `timestamp` data type for all columns containing **timestamp** substring and define a few exact matches ie. **created_at**. +Here's same thing in code +```py + source = data_source() + source.schema.update_preferred_types( + { + "re:timestamp": "timestamp", + "inserted_at": "timestamp", + "created_at": "timestamp", + "updated_at": "timestamp", + } + ) ``` ### Applying data types directly with `@dlt.resource` and `apply_hints` `dlt` offers the flexibility to directly apply data types and hints in your code, bypassing the need for importing and adjusting schemas. This approach is ideal for rapid prototyping and handling data sources with dynamic schema requirements. 
@@ -364,7 +401,6 @@ def textual(nesting_level: int): schema.remove_type_detection("iso_timestamp") # convert UNIX timestamp (float, withing a year from NOW) into timestamp schema.add_type_detection("timestamp") - schema._compile_settings() return dlt.resource([]) ``` diff --git a/docs/website/docs/reference/performance_snippets/toml-snippets.toml b/docs/website/docs/reference/performance_snippets/toml-snippets.toml index 5e700c4e31..e1a640e7cf 100644 --- a/docs/website/docs/reference/performance_snippets/toml-snippets.toml +++ b/docs/website/docs/reference/performance_snippets/toml-snippets.toml @@ -71,7 +71,7 @@ max_parallel_items=10 # @@@DLT_SNIPPET_START normalize_workers_toml - [extract.data_writer] +[extract.data_writer] # force extract file rotation if size exceeds 1MiB file_max_bytes=1000000 diff --git a/docs/website/docs/walkthroughs/create-new-destination.md b/docs/website/docs/walkthroughs/create-new-destination.md index 1b72b81e3e..69e7b2fcc1 100644 --- a/docs/website/docs/walkthroughs/create-new-destination.md +++ b/docs/website/docs/walkthroughs/create-new-destination.md @@ -88,6 +88,10 @@ The default `escape_identifier` function identifier escapes `"` and '\' and quot You should avoid providing a custom `escape_literal` function by not enabling `insert-values` for your destination. +### Enable / disable case sensitive identifiers +Specify if destination supports case sensitive identifiers by setting `has_case_sensitive_identifiers` to `True` (or `False` if otherwise). Some case sensitive destinations (ie. **Snowflake** or **Postgres**) support case insensitive identifiers via. case folding ie. **Snowflake** considers all upper case identifiers as case insensitive (set `casefold_identifier` to `str.upper`), **Postgres** does the same with lower case identifiers (`str.lower`). +Some case insensitive destinations (ie. **Athena** or **Redshift**) case-fold (ie. lower case) all identifiers and store them as such. 
In that case set `casefold_identifier` to `str.lower` as well. + ## 4. Adjust the SQL client **sql client** is a wrapper over `dbapi` and its main role is to provide consistent interface for executing SQL statements, managing transactions and (probably the most important) to help handling errors via classifying exceptions. Here's a few things you should pay attention to: diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index d3d7def8fc..4fa1c58eae 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -116,6 +116,7 @@ const sidebars = { 'dlt-ecosystem/destinations/snowflake', 'dlt-ecosystem/destinations/athena', 'dlt-ecosystem/destinations/weaviate', + 'dlt-ecosystem/destinations/lancedb', 'dlt-ecosystem/destinations/qdrant', 'dlt-ecosystem/destinations/dremio', 'dlt-ecosystem/destinations/destination', @@ -157,6 +158,7 @@ const sidebars = { 'general-usage/incremental-loading', 'general-usage/full-loading', 'general-usage/schema', + 'general-usage/naming-convention', 'general-usage/schema-contracts', 'general-usage/schema-evolution', { diff --git a/poetry.lock b/poetry.lock index 5a94993c80..2cef57424d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "about-time" @@ -2683,21 +2683,27 @@ test = ["pytest (>=6)"] [[package]] name = "fastembed" -version = "0.1.1" +version = "0.2.6" description = "Fast, light, accurate library built for retrieval embedding generation" optional = true -python-versions = ">=3.8.0,<3.12" +python-versions = "<3.13,>=3.8.0" files = [ - {file = "fastembed-0.1.1-py3-none-any.whl", hash = "sha256:131413ae52cd72f4c8cced7a675f8269dbfd1a852abade3c815e265114bcc05a"}, - {file = "fastembed-0.1.1.tar.gz", hash = "sha256:f7e524ee4f74bb8aad16be5b687d1f77f608d40e96e292c87881dc36baf8f4c7"}, + {file = "fastembed-0.2.6-py3-none-any.whl", hash = "sha256:3e18633291722087abebccccd7fcdffafef643cb22d203370d7fad4fa83c10fb"}, + {file = "fastembed-0.2.6.tar.gz", hash = "sha256:adaed5b46e19cc1bbe5f98f2b3ffecfc4d2a48d27512e28ff5bfe92a42649a66"}, ] [package.dependencies] -onnx = ">=1.11,<2.0" -onnxruntime = ">=1.15,<2.0" +huggingface-hub = ">=0.20,<0.21" +loguru = ">=0.7.2,<0.8.0" +numpy = [ + {version = ">=1.21", markers = "python_version < \"3.12\""}, + {version = ">=1.26", markers = "python_version >= \"3.12\""}, +] +onnx = ">=1.15.0,<2.0.0" +onnxruntime = ">=1.17.0,<2.0.0" requests = ">=2.31,<3.0" -tokenizers = ">=0.13,<0.14" -tqdm = ">=4.65,<5.0" +tokenizers = ">=0.15.1,<0.16.0" +tqdm = ">=4.66,<5.0" [[package]] name = "filelock" @@ -3546,6 +3552,164 @@ files = [ {file = "google_re2-1.1-1-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c6c9f64b9724ec38da8e514f404ac64e9a6a5e8b1d7031c2dadd05c1f4c16fd"}, {file = "google_re2-1.1-1-cp39-cp39-win32.whl", hash = "sha256:d1b751b9ab9f8e2ab2a36d72b909281ce65f328c9115a1685acae1a2d1afd7a4"}, {file = "google_re2-1.1-1-cp39-cp39-win_amd64.whl", hash = "sha256:ac775c75cec7069351d201da4e0fb0cae4c1c5ebecd08fa34e1be89740c1d80b"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5eaefe4705b75ca5f78178a50104b689e9282f868e12f119b26b4cffc0c7ee6e"}, + {file = 
"google_re2-1.1-2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:e35f2c8aabfaaa4ce6420b3cae86c0c29042b1b4f9937254347e9b985694a171"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:35fd189cbaaaa39c9a6a8a00164c8d9c709bacd0c231c694936879609beff516"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:60475d222cebd066c80414831c8a42aa2449aab252084102ee05440896586e6a"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:871cb85b9b0e1784c983b5c148156b3c5314cb29ca70432dff0d163c5c08d7e5"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:94f4e66e34bdb8de91ec6cdf20ba4fa9fea1dfdcfb77ff1f59700d01a0243664"}, + {file = "google_re2-1.1-2-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1563577e2b720d267c4cffacc0f6a2b5c8480ea966ebdb1844fbea6602c7496f"}, + {file = "google_re2-1.1-2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:49b7964532a801b96062d78c0222d155873968f823a546a3dbe63d73f25bb56f"}, + {file = "google_re2-1.1-2-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2362fd70eb639a75fd0187d28b4ba7b20b3088833d8ad7ffd8693d0ba159e1c2"}, + {file = "google_re2-1.1-2-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:86b80719636a4e21391e20a9adf18173ee6ae2ec956726fe2ff587417b5e8ba6"}, + {file = "google_re2-1.1-2-cp310-cp310-win32.whl", hash = "sha256:5456fba09df951fe8d1714474ed1ecda102a68ddffab0113e6c117d2e64e6f2b"}, + {file = "google_re2-1.1-2-cp310-cp310-win_amd64.whl", hash = "sha256:2ac6936a3a60d8d9de9563e90227b3aea27068f597274ca192c999a12d8baa8f"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d5a87b436028ec9b0f02fe19d4cbc19ef30441085cdfcdf1cce8fbe5c4bd5e9a"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:fc0d4163de9ed2155a77e7a2d59d94c348a6bbab3cff88922fab9e0d3d24faec"}, + {file 
= "google_re2-1.1-2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:48b12d953bc796736e7831d67b36892fb6419a4cc44cb16521fe291e594bfe23"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:62c780c927cff98c1538439f0ff616f48a9b2e8837c676f53170d8ae5b9e83cb"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:04b2aefd768aa4edeef8b273327806c9cb0b82e90ff52eacf5d11003ac7a0db2"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:9c90175992346519ee7546d9af9a64541c05b6b70346b0ddc54a48aa0d3b6554"}, + {file = "google_re2-1.1-2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22ad9ad9d125249d6386a2e80efb9de7af8260b703b6be7fa0ab069c1cf56ced"}, + {file = "google_re2-1.1-2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f70971f6ffe5254e476e71d449089917f50ebf9cf60f9cec80975ab1693777e2"}, + {file = "google_re2-1.1-2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f267499529e64a4abed24c588f355ebe4700189d434d84a7367725f5a186e48d"}, + {file = "google_re2-1.1-2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b632eff5e4cd44545a9c0e52f2e1becd55831e25f4dd4e0d7ec8ee6ca50858c1"}, + {file = "google_re2-1.1-2-cp311-cp311-win32.whl", hash = "sha256:a42c733036e8f242ee4e5f0e27153ad4ca44ced9e4ce82f3972938ddee528db0"}, + {file = "google_re2-1.1-2-cp311-cp311-win_amd64.whl", hash = "sha256:64f8eed4ca96905d99b5286b3d14b5ca4f6a025ff3c1351626a7df2f93ad1ddd"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5541efcca5b5faf7e0d882334a04fa479bad4e7433f94870f46272eec0672c4a"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:92309af35b6eb2d3b3dc57045cdd83a76370958ab3e0edd2cc4638f6d23f5b32"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:197cd9bcaba96d18c5bf84d0c32fca7a26c234ea83b1d3083366f4392cb99f78"}, + {file = 
"google_re2-1.1-2-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:1b896f171d29b541256cf26e10dccc9103ac1894683914ed88828ca6facf8dca"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:e022d3239b945014e916ca7120fee659b246ec26c301f9e0542f1a19b38a8744"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:2c73f8a9440873b68bee1198094377501065e85aaf6fcc0d2512c7589ffa06ca"}, + {file = "google_re2-1.1-2-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:901d86555bd7725506d651afaba7d71cd4abd13260aed6cfd7c641a45f76d4f6"}, + {file = "google_re2-1.1-2-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ce4710ff636701cfb56eb91c19b775d53b03749a23b7d2a5071bbbf4342a9067"}, + {file = "google_re2-1.1-2-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:76a20e5ebdf5bc5d430530197e42a2eeb562f729d3a3fb51f39168283d676e66"}, + {file = "google_re2-1.1-2-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:77c9f4d4bb1c8de9d2642d3c4b8b615858ba764df025b3b4f1310266f8def269"}, + {file = "google_re2-1.1-2-cp38-cp38-win32.whl", hash = "sha256:94bd60785bf37ef130a1613738e3c39465a67eae3f3be44bb918540d39b68da3"}, + {file = "google_re2-1.1-2-cp38-cp38-win_amd64.whl", hash = "sha256:59efeb77c0dcdbe37794c61f29c5b1f34bc06e8ec309a111ccdd29d380644d70"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:221e38c27e1dd9ccb8e911e9c7aed6439f68ce81e7bb74001076830b0d6e931d"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d9145879e6c2e1b814445300b31f88a675e1f06c57564670d95a1442e8370c27"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:c8a12f0740e2a52826bdbf95569a4b0abdf413b4012fa71e94ad25dd4715c6e5"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:9c9998f71466f4db7bda752aa7c348b2881ff688e361108fe500caad1d8b9cb2"}, + {file = 
"google_re2-1.1-2-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:0c39f69b702005963a3d3bf78743e1733ad73efd7e6e8465d76e3009e4694ceb"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:6d0ce762dee8d6617d0b1788a9653e805e83a23046c441d0ea65f1e27bf84114"}, + {file = "google_re2-1.1-2-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ecf3619d98c9b4a7844ab52552ad32597cdbc9a5bdbc7e3435391c653600d1e2"}, + {file = "google_re2-1.1-2-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9a1426a8cbd1fa004974574708d496005bd379310c4b1c7012be4bc75efde7a8"}, + {file = "google_re2-1.1-2-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a1a30626ba48b4070f3eab272d860ef1952e710b088792c4d68dddb155be6bfc"}, + {file = "google_re2-1.1-2-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1b9c1ffcfbc3095b6ff601ec2d2bf662988f6ea6763bc1c9d52bec55881f8fde"}, + {file = "google_re2-1.1-2-cp39-cp39-win32.whl", hash = "sha256:32ecf995a252c0548404c1065ba4b36f1e524f1f4a86b6367a1a6c3da3801e30"}, + {file = "google_re2-1.1-2-cp39-cp39-win_amd64.whl", hash = "sha256:e7865410f3b112a3609739283ec3f4f6f25aae827ff59c6bfdf806fd394d753e"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3b21f83f0a201009c56f06fcc7294a33555ede97130e8a91b3f4cae01aed1d73"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b38194b91354a38db1f86f25d09cdc6ac85d63aee4c67b43da3048ce637adf45"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:e7da3da8d6b5a18d6c3b61b11cc5b66b8564eaedce99d2312b15b6487730fc76"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:aeca656fb10d8638f245331aabab59c9e7e051ca974b366dd79e6a9efb12e401"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:2069d6dc94f5fa14a159bf99cad2f11e9c0f8ec3b7f44a4dde9e59afe5d1c786"}, + {file = 
"google_re2-1.1-3-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:2319a39305a4931cb5251451f2582713418a19bef2af7adf9e2a7a0edd939b99"}, + {file = "google_re2-1.1-3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:eb98fc131699756c6d86246f670a5e1c1cc1ba85413c425ad344cb30479b246c"}, + {file = "google_re2-1.1-3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a6e038986d8ffe4e269f8532f03009f229d1f6018d4ac0dabc8aff876338f6e0"}, + {file = "google_re2-1.1-3-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8618343ee658310e0f53bf586fab7409de43ce82bf8d9f7eb119536adc9783fd"}, + {file = "google_re2-1.1-3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d8140ca861cfe00602319cefe2c7b8737b379eb07fb328b51dc44584f47a2718"}, + {file = "google_re2-1.1-3-cp310-cp310-win32.whl", hash = "sha256:41f439c5c54e8a3a0a1fa2dbd1e809d3f643f862df7b16dd790f36a1238a272e"}, + {file = "google_re2-1.1-3-cp310-cp310-win_amd64.whl", hash = "sha256:fe20e97a33176d96d3e4b5b401de35182b9505823abea51425ec011f53ef5e56"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c39ff52b1765db039f690ee5b7b23919d8535aae94db7996079fbde0098c4d7"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:5420be674fd164041639ba4c825450f3d4bd635572acdde16b3dcd697f8aa3ef"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:ff53881cf1ce040f102a42d39db93c3f835f522337ae9c79839a842f26d97733"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:8d04600b0b53523118df2e413a71417c408f20dee640bf07dfab601c96a18a77"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:c4835d4849faa34a7fa1074098d81c420ed6c0707a3772482b02ce14f2a7c007"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:3309a9b81251d35fee15974d0ae0581a9a375266deeafdc3a3ac0d172a742357"}, + {file 
= "google_re2-1.1-3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e2b51cafee7e0bc72d0a4a454547bd8f257cde412ac9f1a2dc46a203b5e42cf4"}, + {file = "google_re2-1.1-3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:83f5f1cb52f832c2297d271ee8c56cf5e9053448162e5d2223d513f729bad908"}, + {file = "google_re2-1.1-3-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:55865a1ace92be3f7953b2e2b38b901d8074a367aa491daee43260a53a7fc6f0"}, + {file = "google_re2-1.1-3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cec2167dd142e583e98c783bd0d28b8cf5a9cdbe1f7407ba4163fe3ccb613cb9"}, + {file = "google_re2-1.1-3-cp311-cp311-win32.whl", hash = "sha256:a0bc1fe96849e4eb8b726d0bba493f5b989372243b32fe20729cace02e5a214d"}, + {file = "google_re2-1.1-3-cp311-cp311-win_amd64.whl", hash = "sha256:e6310a156db96fc5957cb007dd2feb18476898654530683897469447df73a7cd"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8e63cd10ea006088b320e8c5d308da1f6c87aa95138a71c60dd7ca1c8e91927e"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:12b566830a334178733a85e416b1e0507dbc0ceb322827616fe51ef56c5154f1"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:442e18c9d46b225c1496919c16eafe8f8d9bb4091b00b4d3440da03c55bbf4ed"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:c54c00263a9c39b2dacd93e9636319af51e3cf885c080b9680a9631708326460"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:15a3caeeb327bc22e0c9f95eb76890fec8874cacccd2b01ff5c080ab4819bbec"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:59ec0d2cced77f715d41f6eafd901f6b15c11e28ba25fe0effdc1de554d78e75"}, + {file = "google_re2-1.1-3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:185bf0e3441aed3840590f8e42f916e2920d235eb14df2cbc2049526803d3e71"}, + {file = "google_re2-1.1-3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:586d3f2014eea5be14d8de53374d9b79fa99689160e00efa64b5fe93af326087"}, + {file = "google_re2-1.1-3-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cc2575082de4ffd234d9607f3ae67ca22b15a1a88793240e2045f3b3a36a5795"}, + {file = "google_re2-1.1-3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:59c5ad438eddb3630def394456091284d7bbc5b89351987f94f3792d296d1f96"}, + {file = "google_re2-1.1-3-cp312-cp312-win32.whl", hash = "sha256:5b9878c53f2bf16f75bf71d4ddd57f6611351408d5821040e91c53ebdf82c373"}, + {file = "google_re2-1.1-3-cp312-cp312-win_amd64.whl", hash = "sha256:4fdecfeb213110d0a85bad335a8e7cdb59fea7de81a4fe659233f487171980f9"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2dd87bacab32b709c28d0145fe75a956b6a39e28f0726d867375dba5721c76c1"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:55d24c61fe35dddc1bb484593a57c9f60f9e66d7f31f091ef9608ed0b6dde79f"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a0cf1180d908622df648c26b0cd09281f92129805ccc56a39227fdbfeab95cb4"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:09586f07f3f88d432265c75976da1c619ab7192cd7ebdf53f4ae0776c19e4b56"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:539f1b053402203576e919a06749198da4ae415931ee28948a1898131ae932ce"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:abf0bcb5365b0e27a5a23f3da403dffdbbac2c0e3a3f1535a8b10cc121b5d5fb"}, + {file = "google_re2-1.1-3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:19c83e5bbed7958213eeac3aa71c506525ce54faf03e07d0b96cd0a764890511"}, + {file = 
"google_re2-1.1-3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3348e77330ff672dc44ec01894fa5d93c409a532b6d688feac55e714e9059920"}, + {file = "google_re2-1.1-3-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:06b63edb57c5ce5a13eabfd71155e346b9477dc8906dec7c580d4f70c16a7e0d"}, + {file = "google_re2-1.1-3-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:12fe57ba2914092b83338d61d8def9ebd5a2bd0fd8679eceb5d4c2748105d5c0"}, + {file = "google_re2-1.1-3-cp38-cp38-win32.whl", hash = "sha256:80796e08d24e606e675019fe8de4eb5c94bb765be13c384f2695247d54a6df75"}, + {file = "google_re2-1.1-3-cp38-cp38-win_amd64.whl", hash = "sha256:3c2257dedfe7cc5deb6791e563af9e071a9d414dad89e37ac7ad22f91be171a9"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:43a0cd77c87c894f28969ac622f94b2e6d1571261dfdd785026848a25cfdc9b9"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:1038990b77fd66f279bd66a0832b67435ea925e15bb59eafc7b60fdec812b616"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fb5dda6875d18dd45f0f24ebced6d1f7388867c8fb04a235d1deab7ea479ce38"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:bb1d164965c6d57a351b421d2f77c051403766a8b75aaa602324ee2451fff77f"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:a072ebfa495051d07ffecbf6ce21eb84793568d5c3c678c00ed8ff6b8066ab31"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:4eb66c8398c8a510adc97978d944b3b29c91181237218841ea1a91dc39ec0e54"}, + {file = "google_re2-1.1-3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f7c8b57b1f559553248d1757b7fa5b2e0cc845666738d155dff1987c2618264e"}, + {file = "google_re2-1.1-3-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9162f6aa4f25453c682eb176f21b8e2f40205be9f667e98a54b3e1ff10d6ee75"}, + {file = 
"google_re2-1.1-3-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2d65ddf67fd7bf94705626871d463057d3d9a3538d41022f95b9d8f01df36e1"}, + {file = "google_re2-1.1-3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d140c7b9395b4d1e654127aa1c99bcc603ed01000b7bc7e28c52562f1894ec12"}, + {file = "google_re2-1.1-3-cp39-cp39-win32.whl", hash = "sha256:80c5fc200f64b2d903eeb07b8d6cefc620a872a0240c7caaa9aca05b20f5568f"}, + {file = "google_re2-1.1-3-cp39-cp39-win_amd64.whl", hash = "sha256:9eb6dbcee9b5dc4069bbc0634f2eb039ca524a14bed5868fdf6560aaafcbca06"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0db114d7e1aa96dbcea452a40136d7d747d60cbb61394965774688ef59cccd4e"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:82133958e003a1344e5b7a791b9a9dd7560b5c8f96936dbe16f294604524a633"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:9e74fd441d1f3d917d3303e319f61b82cdbd96b9a5ba919377a6eef1504a1e2b"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:734a2e7a4541c57253b5ebee24f3f3366ba3658bcad01da25fb623c78723471a"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:d88d5eecbc908abe16132456fae13690d0508f3ac5777f320ef95cb6cab9a961"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:b91db80b171ecec435a07977a227757dd487356701a32f556fa6fca5d0a40522"}, + {file = "google_re2-1.1-4-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b23129887a64bb9948af14c84705273ed1a40054e99433b4acccab4dcf6a226"}, + {file = "google_re2-1.1-4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5dc1a0cc7cd19261dcaf76763e2499305dbb7e51dc69555167cdb8af98782698"}, + {file = "google_re2-1.1-4-cp310-cp310-win32.whl", hash = "sha256:3b2ab1e2420b5dd9743a2d6bc61b64e5f708563702a75b6db86637837eaeaf2f"}, + {file = 
"google_re2-1.1-4-cp310-cp310-win_amd64.whl", hash = "sha256:92efca1a7ef83b6df012d432a1cbc71d10ff42200640c0f9a5ff5b343a48e633"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:854818fd4ce79787aca5ba459d6e5abe4ca9be2c684a5b06a7f1757452ca3708"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:4ceef51174b6f653b6659a8fdaa9c38960c5228b44b25be2a3bcd8566827554f"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:ee49087c3db7e6f5238105ab5299c09e9b77516fe8cfb0a37e5f1e813d76ecb8"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:dc2312854bdc01410acc5d935f1906a49cb1f28980341c20a68797ad89d8e178"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:0dc0d2e42296fa84a3cb3e1bd667c6969389cd5cdf0786e6b1f911ae2d75375b"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6bf04ced98453b035f84320f348f67578024f44d2997498def149054eb860ae8"}, + {file = "google_re2-1.1-4-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1d6b6ef11dc4ab322fa66c2f3561925f2b5372a879c3ed764d20e939e2fd3e5f"}, + {file = "google_re2-1.1-4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0dcde6646fa9a97fd3692b3f6ae7daf7f3277d7500b6c253badeefa11db8956a"}, + {file = "google_re2-1.1-4-cp311-cp311-win32.whl", hash = "sha256:5f4f0229deb057348893574d5b0a96d055abebac6debf29d95b0c0e26524c9f6"}, + {file = "google_re2-1.1-4-cp311-cp311-win_amd64.whl", hash = "sha256:4713ddbe48a18875270b36a462b0eada5e84d6826f8df7edd328d8706b6f9d07"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:40a698300b8faddbb325662973f839489c89b960087060bd389c376828978a04"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:103d2d7ac92ba23911a151fd1fc7035cbf6dc92a7f6aea92270ebceb5cd5acd3"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_13_0_arm64.whl", hash 
= "sha256:51fb7182bccab05e8258a2b6a63dda1a6b4a9e8dfb9b03ec50e50c49c2827dd4"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:65383022abd63d7b620221eba7935132b53244b8b463d8fdce498c93cf58b7b7"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:396281fc68a9337157b3ffcd9392c6b7fcb8aab43e5bdab496262a81d56a4ecc"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:8198adcfcff1c680e052044124621730fc48d08005f90a75487f5651f1ebfce2"}, + {file = "google_re2-1.1-4-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:81f7bff07c448aec4db9ca453d2126ece8710dbd9278b8bb09642045d3402a96"}, + {file = "google_re2-1.1-4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7dacf730fd7d6ec71b11d6404b0b26e230814bfc8e9bb0d3f13bec9b5531f8d"}, + {file = "google_re2-1.1-4-cp312-cp312-win32.whl", hash = "sha256:8c764f62f4b1d89d1ef264853b6dd9fee14a89e9b86a81bc2157fe3531425eb4"}, + {file = "google_re2-1.1-4-cp312-cp312-win_amd64.whl", hash = "sha256:0be2666df4bc5381a5d693585f9bbfefb0bfd3c07530d7e403f181f5de47254a"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:5cb1b63a0bfd8dd65d39d2f3b2e5ae0a06ce4b2ce5818a1d1fc78a786a252673"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:e41751ce6b67a95230edd0772226dc94c2952a2909674cd69df9804ed0125307"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:b998cfa2d50bf4c063e777c999a7e8645ec7e5d7baf43ad71b1e2e10bb0300c3"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:226ca3b0c2e970f3fc82001ac89e845ecc7a4bb7c68583e7a76cda70b61251a7"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:9adec1f734ebad7c72e56c85f205a281d8fe9bf6583bc21020157d3f2812ce89"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_14_0_x86_64.whl", hash = 
"sha256:9c34f3c64ba566af967d29e11299560e6fdfacd8ca695120a7062b6ed993b179"}, + {file = "google_re2-1.1-4-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e1b85385fe293838e0d0b6e19e6c48ba8c6f739ea92ce2e23b718afe7b343363"}, + {file = "google_re2-1.1-4-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4694daa8a8987cfb568847aa872f9990e930c91a68c892ead876411d4b9012c3"}, + {file = "google_re2-1.1-4-cp38-cp38-win32.whl", hash = "sha256:5e671e9be1668187e2995aac378de574fa40df70bb6f04657af4d30a79274ce0"}, + {file = "google_re2-1.1-4-cp38-cp38-win_amd64.whl", hash = "sha256:f66c164d6049a8299f6dfcfa52d1580576b4b9724d6fcdad2f36f8f5da9304b6"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:25cb17ae0993a48c70596f3a3ef5d659638106401cc8193f51c0d7961b3b3eb7"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:5f101f86d14ca94ca4dcf63cceaa73d351f2be2481fcaa29d9e68eeab0dc2a88"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:4e82591e85bf262a6d74cff152867e05fc97867c68ba81d6836ff8b0e7e62365"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:1f61c09b93ffd34b1e2557e5a9565039f935407a5786dbad46f64f1a484166e6"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:12b390ad8c7e74bab068732f774e75e0680dade6469b249a721f3432f90edfc3"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:1284343eb31c2e82ed2d8159f33ba6842238a56782c881b07845a6d85613b055"}, + {file = "google_re2-1.1-4-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c7b38e0daf2c06e4d3163f4c732ab3ad2521aecfed6605b69e4482c612da303"}, + {file = "google_re2-1.1-4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f4d4f0823e8b2f6952a145295b1ff25245ce9bb136aff6fe86452e507d4c1dd"}, + {file = "google_re2-1.1-4-cp39-cp39-win32.whl", hash = 
"sha256:1afae56b2a07bb48cfcfefaa15ed85bae26a68f5dc7f9e128e6e6ea36914e847"}, + {file = "google_re2-1.1-4-cp39-cp39-win_amd64.whl", hash = "sha256:aa7d6d05911ab9c8adbf3c225a7a120ab50fd2784ac48f2f0d140c0b7afc2b55"}, ] [[package]] @@ -3949,6 +4113,38 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "huggingface-hub" +version = "0.20.3" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "huggingface_hub-0.20.3-py3-none-any.whl", hash = "sha256:d988ae4f00d3e307b0c80c6a05ca6dbb7edba8bba3079f74cda7d9c2e562a7b6"}, + {file = "huggingface_hub-0.20.3.tar.gz", hash = "sha256:94e7f8e074475fbc67d6a71957b678e1b4a74ff1b64a644fd6cbb83da962d05d"}, +] + +[package.dependencies] +filelock = "*" +fsspec = ">=2023.5.0" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +fastai = ["fastai 
(>=2.4)", "fastcore (>=1.3.27)", "toml"] +inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["torch"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] + [[package]] name = "humanfriendly" version = "10.0" @@ -4290,6 +4486,42 @@ docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] +[[package]] +name = "lancedb" +version = "0.9.0" +description = "lancedb" +optional = false +python-versions = ">=3.8" +files = [ + {file = "lancedb-0.9.0-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:b1ca08797c72c93ae512aa1078f1891756da157d910fbae8e194fac3528fc1ac"}, + {file = "lancedb-0.9.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:15129791f03c2c04b95f914ced2c1556b43d73a24710207b9af77b6e4008bdeb"}, + {file = "lancedb-0.9.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f093d89447a2039b820d2540a0b64df3024e4549b6808ebd26b44fbe0345cc6"}, + {file = "lancedb-0.9.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:a8c1f6777e217d2277451038866d280fa5fb38bd161795e51703b043c26dd345"}, + {file = "lancedb-0.9.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = 
"sha256:78dd5800a1148f89d33b7e98d1c8b1c42dee146f03580abc1ca83cb05273ff7f"}, + {file = "lancedb-0.9.0-cp38-abi3-win_amd64.whl", hash = "sha256:ba5bdc727d3bc131f17414f42372acde5817073feeb553793a3d20003caa1658"}, +] + +[package.dependencies] +attrs = ">=21.3.0" +cachetools = "*" +deprecation = "*" +overrides = ">=0.7" +packaging = "*" +pydantic = ">=1.10" +pylance = "0.13.0" +ratelimiter = ">=1.0,<2.0" +requests = ">=2.31.0" +retry = ">=0.9.2" +tqdm = ">=4.27.0" + +[package.extras] +azure = ["adlfs (>=2024.2.0)"] +clip = ["open-clip", "pillow", "torch"] +dev = ["pre-commit", "ruff"] +docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] +embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "ollama", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] +tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] + [[package]] name = "lazy-object-proxy" version = "1.9.0" @@ -4438,6 +4670,24 @@ sqlalchemy = ["sqlalchemy"] test = ["mock", "pytest", "pytest-cov (<2.6)"] zmq = ["pyzmq"] +[[package]] +name = "loguru" +version = "0.7.2" +description = "Python logging made (stupidly) simple" +optional = true +python-versions = ">=3.5" +files = [ + {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, + {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", 
"mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] + [[package]] name = "lxml" version = "4.9.3" @@ -4448,10 +4698,13 @@ files = [ {file = "lxml-4.9.3-cp27-cp27m-macosx_11_0_x86_64.whl", hash = "sha256:b0a545b46b526d418eb91754565ba5b63b1c0b12f9bd2f808c852d9b4b2f9b5c"}, {file = "lxml-4.9.3-cp27-cp27m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:075b731ddd9e7f68ad24c635374211376aa05a281673ede86cbe1d1b3455279d"}, {file = "lxml-4.9.3-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1e224d5755dba2f4a9498e150c43792392ac9b5380aa1b845f98a1618c94eeef"}, + {file = "lxml-4.9.3-cp27-cp27m-win32.whl", hash = "sha256:2c74524e179f2ad6d2a4f7caf70e2d96639c0954c943ad601a9e146c76408ed7"}, + {file = "lxml-4.9.3-cp27-cp27m-win_amd64.whl", hash = "sha256:4f1026bc732b6a7f96369f7bfe1a4f2290fb34dce00d8644bc3036fb351a4ca1"}, {file = "lxml-4.9.3-cp27-cp27mu-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c0781a98ff5e6586926293e59480b64ddd46282953203c76ae15dbbbf302e8bb"}, {file = "lxml-4.9.3-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cef2502e7e8a96fe5ad686d60b49e1ab03e438bd9123987994528febd569868e"}, {file = "lxml-4.9.3-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b86164d2cff4d3aaa1f04a14685cbc072efd0b4f99ca5708b2ad1b9b5988a991"}, {file = "lxml-4.9.3-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:42871176e7896d5d45138f6d28751053c711ed4d48d8e30b498da155af39aebd"}, + {file = "lxml-4.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:ae8b9c6deb1e634ba4f1930eb67ef6e6bf6a44b6eb5ad605642b2d6d5ed9ce3c"}, {file = 
"lxml-4.9.3-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:411007c0d88188d9f621b11d252cce90c4a2d1a49db6c068e3c16422f306eab8"}, {file = "lxml-4.9.3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:cd47b4a0d41d2afa3e58e5bf1f62069255aa2fd6ff5ee41604418ca925911d76"}, {file = "lxml-4.9.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e2cb47860da1f7e9a5256254b74ae331687b9672dfa780eed355c4c9c3dbd23"}, @@ -4460,6 +4713,7 @@ files = [ {file = "lxml-4.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:97047f0d25cd4bcae81f9ec9dc290ca3e15927c192df17331b53bebe0e3ff96d"}, {file = "lxml-4.9.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:1f447ea5429b54f9582d4b955f5f1985f278ce5cf169f72eea8afd9502973dd5"}, {file = "lxml-4.9.3-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:57d6ba0ca2b0c462f339640d22882acc711de224d769edf29962b09f77129cbf"}, + {file = "lxml-4.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:9767e79108424fb6c3edf8f81e6730666a50feb01a328f4a016464a5893f835a"}, {file = "lxml-4.9.3-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:71c52db65e4b56b8ddc5bb89fb2e66c558ed9d1a74a45ceb7dcb20c191c3df2f"}, {file = "lxml-4.9.3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:d73d8ecf8ecf10a3bd007f2192725a34bd62898e8da27eb9d32a58084f93962b"}, {file = "lxml-4.9.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0a3d3487f07c1d7f150894c238299934a2a074ef590b583103a45002035be120"}, @@ -4479,6 +4733,7 @@ files = [ {file = "lxml-4.9.3-cp36-cp36m-macosx_11_0_x86_64.whl", hash = "sha256:64f479d719dc9f4c813ad9bb6b28f8390360660b73b2e4beb4cb0ae7104f1c12"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:dd708cf4ee4408cf46a48b108fb9427bfa00b9b85812a9262b5c668af2533ea5"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:5c31c7462abdf8f2ac0577d9f05279727e698f97ecbb02f17939ea99ae8daa98"}, + {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:e3cd95e10c2610c360154afdc2f1480aea394f4a4f1ea0a5eacce49640c9b190"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:4930be26af26ac545c3dffb662521d4e6268352866956672231887d18f0eaab2"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4aec80cde9197340bc353d2768e2a75f5f60bacda2bab72ab1dc499589b3878c"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:14e019fd83b831b2e61baed40cab76222139926b1fb5ed0e79225bc0cae14584"}, @@ -4488,6 +4743,7 @@ files = [ {file = "lxml-4.9.3-cp36-cp36m-win_amd64.whl", hash = "sha256:bef4e656f7d98aaa3486d2627e7d2df1157d7e88e7efd43a65aa5dd4714916cf"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:46f409a2d60f634fe550f7133ed30ad5321ae2e6630f13657fb9479506b00601"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:4c28a9144688aef80d6ea666c809b4b0e50010a2aca784c97f5e6bf143d9f129"}, + {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:141f1d1a9b663c679dc524af3ea1773e618907e96075262726c7612c02b149a4"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:53ace1c1fd5a74ef662f844a0413446c0629d151055340e9893da958a374f70d"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:17a753023436a18e27dd7769e798ce302963c236bc4114ceee5b25c18c52c693"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7d298a1bd60c067ea75d9f684f5f3992c9d6766fadbc0bcedd39750bf344c2f4"}, @@ -4497,6 +4753,7 @@ files = [ {file = "lxml-4.9.3-cp37-cp37m-win_amd64.whl", hash = 
"sha256:120fa9349a24c7043854c53cae8cec227e1f79195a7493e09e0c12e29f918e52"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:4d2d1edbca80b510443f51afd8496be95529db04a509bc8faee49c7b0fb6d2cc"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:8d7e43bd40f65f7d97ad8ef5c9b1778943d02f04febef12def25f7583d19baac"}, + {file = "lxml-4.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:71d66ee82e7417828af6ecd7db817913cb0cf9d4e61aa0ac1fde0583d84358db"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:6fc3c450eaa0b56f815c7b62f2b7fba7266c4779adcf1cece9e6deb1de7305ce"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:65299ea57d82fb91c7f019300d24050c4ddeb7c5a190e076b5f48a2b43d19c42"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:eadfbbbfb41b44034a4c757fd5d70baccd43296fb894dba0295606a7cf3124aa"}, @@ -4506,6 +4763,7 @@ files = [ {file = "lxml-4.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:92af161ecbdb2883c4593d5ed4815ea71b31fafd7fd05789b23100d081ecac96"}, {file = "lxml-4.9.3-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:9bb6ad405121241e99a86efff22d3ef469024ce22875a7ae045896ad23ba2340"}, {file = "lxml-4.9.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:8ed74706b26ad100433da4b9d807eae371efaa266ffc3e9191ea436087a9d6a7"}, + {file = "lxml-4.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fbf521479bcac1e25a663df882c46a641a9bff6b56dc8b0fafaebd2f66fb231b"}, {file = "lxml-4.9.3-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:303bf1edce6ced16bf67a18a1cf8339d0db79577eec5d9a6d4a80f0fb10aa2da"}, {file = "lxml-4.9.3-cp39-cp39-manylinux_2_28_x86_64.whl", hash = 
"sha256:5515edd2a6d1a5a70bfcdee23b42ec33425e405c5b351478ab7dc9347228f96e"}, {file = "lxml-4.9.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:690dafd0b187ed38583a648076865d8c229661ed20e48f2335d68e2cf7dc829d"}, @@ -4516,13 +4774,16 @@ files = [ {file = "lxml-4.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:4dd9a263e845a72eacb60d12401e37c616438ea2e5442885f65082c276dfb2b2"}, {file = "lxml-4.9.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6689a3d7fd13dc687e9102a27e98ef33730ac4fe37795d5036d18b4d527abd35"}, {file = "lxml-4.9.3-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:f6bdac493b949141b733c5345b6ba8f87a226029cbabc7e9e121a413e49441e0"}, + {file = "lxml-4.9.3-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:05186a0f1346ae12553d66df1cfce6f251589fea3ad3da4f3ef4e34b2d58c6a3"}, {file = "lxml-4.9.3-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c2006f5c8d28dee289f7020f721354362fa304acbaaf9745751ac4006650254b"}, {file = "lxml-4.9.3-pp38-pypy38_pp73-macosx_11_0_x86_64.whl", hash = "sha256:5c245b783db29c4e4fbbbfc9c5a78be496c9fea25517f90606aa1f6b2b3d5f7b"}, {file = "lxml-4.9.3-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:4fb960a632a49f2f089d522f70496640fdf1218f1243889da3822e0a9f5f3ba7"}, + {file = "lxml-4.9.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:50670615eaf97227d5dc60de2dc99fb134a7130d310d783314e7724bf163f75d"}, {file = "lxml-4.9.3-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:9719fe17307a9e814580af1f5c6e05ca593b12fb7e44fe62450a5384dbf61b4b"}, {file = "lxml-4.9.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:3331bece23c9ee066e0fb3f96c61322b9e0f54d775fccefff4c38ca488de283a"}, {file = "lxml-4.9.3-pp39-pypy39_pp73-macosx_11_0_x86_64.whl", hash = 
"sha256:ed667f49b11360951e201453fc3967344d0d0263aa415e1619e85ae7fd17b4e0"}, {file = "lxml-4.9.3-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:8b77946fd508cbf0fccd8e400a7f71d4ac0e1595812e66025bac475a8e811694"}, + {file = "lxml-4.9.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:e4da8ca0c0c0aea88fd46be8e44bd49716772358d648cce45fe387f7b92374a7"}, {file = "lxml-4.9.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:fe4bda6bd4340caa6e5cf95e73f8fea5c4bfc55763dd42f1b50a94c1b4a2fbd4"}, {file = "lxml-4.9.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:f3df3db1d336b9356dd3112eae5f5c2b8b377f3bc826848567f10bfddfee77e9"}, {file = "lxml-4.9.3.tar.gz", hash = "sha256:48628bd53a426c9eb9bc066a923acaa0878d1e86129fd5359aee99285f4eed9c"}, @@ -4683,6 +4944,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = 
"MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -5468,35 +5739,36 @@ reference = ["Pillow", "google-re2"] [[package]] name = "onnxruntime" -version = "1.16.1" +version = "1.18.0" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" optional = true python-versions = "*" files = [ - {file = "onnxruntime-1.16.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:28b2c7f444b4119950b69370801cd66067f403d19cbaf2a444735d7c269cce4a"}, - {file = "onnxruntime-1.16.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c24e04f33e7899f6aebb03ed51e51d346c1f906b05c5569d58ac9a12d38a2f58"}, - {file = 
"onnxruntime-1.16.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fa93b166f2d97063dc9f33c5118c5729a4a5dd5617296b6dbef42f9047b3e81"}, - {file = "onnxruntime-1.16.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:042dd9201b3016ee18f8f8bc4609baf11ff34ca1ff489c0a46bcd30919bf883d"}, - {file = "onnxruntime-1.16.1-cp310-cp310-win32.whl", hash = "sha256:c20aa0591f305012f1b21aad607ed96917c86ae7aede4a4dd95824b3d124ceb7"}, - {file = "onnxruntime-1.16.1-cp310-cp310-win_amd64.whl", hash = "sha256:5581873e578917bea76d6434ee7337e28195d03488dcf72d161d08e9398c6249"}, - {file = "onnxruntime-1.16.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:ef8c0c8abf5f309aa1caf35941380839dc5f7a2fa53da533be4a3f254993f120"}, - {file = "onnxruntime-1.16.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e680380bea35a137cbc3efd67a17486e96972901192ad3026ee79c8d8fe264f7"}, - {file = "onnxruntime-1.16.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e62cc38ce1a669013d0a596d984762dc9c67c56f60ecfeee0d5ad36da5863f6"}, - {file = "onnxruntime-1.16.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:025c7a4d57bd2e63b8a0f84ad3df53e419e3df1cc72d63184f2aae807b17c13c"}, - {file = "onnxruntime-1.16.1-cp311-cp311-win32.whl", hash = "sha256:9ad074057fa8d028df248b5668514088cb0937b6ac5954073b7fb9b2891ffc8c"}, - {file = "onnxruntime-1.16.1-cp311-cp311-win_amd64.whl", hash = "sha256:d5e43a3478bffc01f817ecf826de7b25a2ca1bca8547d70888594ab80a77ad24"}, - {file = "onnxruntime-1.16.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:3aef4d70b0930e29a8943eab248cd1565664458d3a62b2276bd11181f28fd0a3"}, - {file = "onnxruntime-1.16.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:55a7b843a57c8ca0c8ff169428137958146081d5d76f1a6dd444c4ffcd37c3c2"}, - {file = "onnxruntime-1.16.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:62c631af1941bf3b5f7d063d24c04aacce8cff0794e157c497e315e89ac5ad7b"}, - {file = "onnxruntime-1.16.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5671f296c3d5c233f601e97a10ab5a1dd8e65ba35c7b7b0c253332aba9dff330"}, - {file = "onnxruntime-1.16.1-cp38-cp38-win32.whl", hash = "sha256:eb3802305023dd05e16848d4e22b41f8147247894309c0c27122aaa08793b3d2"}, - {file = "onnxruntime-1.16.1-cp38-cp38-win_amd64.whl", hash = "sha256:fecfb07443d09d271b1487f401fbdf1ba0c829af6fd4fe8f6af25f71190e7eb9"}, - {file = "onnxruntime-1.16.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:de3e12094234db6545c67adbf801874b4eb91e9f299bda34c62967ef0050960f"}, - {file = "onnxruntime-1.16.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ff723c2a5621b5e7103f3be84d5aae1e03a20621e72219dddceae81f65f240af"}, - {file = "onnxruntime-1.16.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14a7fb3073aaf6b462e3d7fb433320f7700558a8892e5021780522dc4574292a"}, - {file = "onnxruntime-1.16.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:963159f1f699b0454cd72fcef3276c8a1aab9389a7b301bcd8e320fb9d9e8597"}, - {file = "onnxruntime-1.16.1-cp39-cp39-win32.whl", hash = "sha256:85771adb75190db9364b25ddec353ebf07635b83eb94b64ed014f1f6d57a3857"}, - {file = "onnxruntime-1.16.1-cp39-cp39-win_amd64.whl", hash = "sha256:d32d2b30799c1f950123c60ae8390818381fd5f88bdf3627eeca10071c155dc5"}, + {file = "onnxruntime-1.18.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:5a3b7993a5ecf4a90f35542a4757e29b2d653da3efe06cdd3164b91167bbe10d"}, + {file = "onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:15b944623b2cdfe7f7945690bfb71c10a4531b51997c8320b84e7b0bb59af902"}, + {file = "onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2e61ce5005118064b1a0ed73ebe936bc773a102f067db34108ea6c64dd62a179"}, + {file = "onnxruntime-1.18.0-cp310-cp310-win32.whl", 
hash = "sha256:a4fc8a2a526eb442317d280610936a9f73deece06c7d5a91e51570860802b93f"}, + {file = "onnxruntime-1.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:71ed219b768cab004e5cd83e702590734f968679bf93aa488c1a7ffbe6e220c3"}, + {file = "onnxruntime-1.18.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:3d24bd623872a72a7fe2f51c103e20fcca2acfa35d48f2accd6be1ec8633d960"}, + {file = "onnxruntime-1.18.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f15e41ca9b307a12550bfd2ec93f88905d9fba12bab7e578f05138ad0ae10d7b"}, + {file = "onnxruntime-1.18.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f45ca2887f62a7b847d526965686b2923efa72538c89b7703c7b3fe970afd59"}, + {file = "onnxruntime-1.18.0-cp311-cp311-win32.whl", hash = "sha256:9e24d9ecc8781323d9e2eeda019b4b24babc4d624e7d53f61b1fe1a929b0511a"}, + {file = "onnxruntime-1.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:f8608398976ed18aef450d83777ff6f77d0b64eced1ed07a985e1a7db8ea3771"}, + {file = "onnxruntime-1.18.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:f1d79941f15fc40b1ee67738b2ca26b23e0181bf0070b5fb2984f0988734698f"}, + {file = "onnxruntime-1.18.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99e8caf3a8565c853a22d323a3eebc2a81e3de7591981f085a4f74f7a60aab2d"}, + {file = "onnxruntime-1.18.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:498d2b8380635f5e6ebc50ec1b45f181588927280f32390fb910301d234f97b8"}, + {file = "onnxruntime-1.18.0-cp312-cp312-win32.whl", hash = "sha256:ba7cc0ce2798a386c082aaa6289ff7e9bedc3dee622eef10e74830cff200a72e"}, + {file = "onnxruntime-1.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:1fa175bd43f610465d5787ae06050c81f7ce09da2bf3e914eb282cb8eab363ef"}, + {file = "onnxruntime-1.18.0-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:0284c579c20ec8b1b472dd190290a040cc68b6caec790edb960f065d15cf164a"}, + {file = 
"onnxruntime-1.18.0-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d47353d036d8c380558a5643ea5f7964d9d259d31c86865bad9162c3e916d1f6"}, + {file = "onnxruntime-1.18.0-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:885509d2b9ba4b01f08f7fa28d31ee54b6477953451c7ccf124a84625f07c803"}, + {file = "onnxruntime-1.18.0-cp38-cp38-win32.whl", hash = "sha256:8614733de3695656411d71fc2f39333170df5da6c7efd6072a59962c0bc7055c"}, + {file = "onnxruntime-1.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:47af3f803752fce23ea790fd8d130a47b2b940629f03193f780818622e856e7a"}, + {file = "onnxruntime-1.18.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:9153eb2b4d5bbab764d0aea17adadffcfc18d89b957ad191b1c3650b9930c59f"}, + {file = "onnxruntime-1.18.0-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c7fd86eca727c989bb8d9c5104f3c45f7ee45f445cc75579ebe55d6b99dfd7c"}, + {file = "onnxruntime-1.18.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ac67a4de9c1326c4d87bcbfb652c923039b8a2446bb28516219236bec3b494f5"}, + {file = "onnxruntime-1.18.0-cp39-cp39-win32.whl", hash = "sha256:6ffb445816d06497df7a6dd424b20e0b2c39639e01e7fe210e247b82d15a23b9"}, + {file = "onnxruntime-1.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:46de6031cb6745f33f7eca9e51ab73e8c66037fb7a3b6b4560887c5b55ab5d5d"}, ] [package.dependencies] @@ -6680,6 +6952,32 @@ ray = ["ray[data]"] tests = ["boto3", "datasets", "duckdb", "h5py (<3.11)", "ml-dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"] torch = ["torch"] +[[package]] +name = "pylance" +version = "0.13.0" +description = "python wrapper for Lance columnar format" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pylance-0.13.0-cp39-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2f3d6f9eec1f59f45dccb01075ba79868b8d37c8371d6210bcf6418217a0dd8b"}, + {file = "pylance-0.13.0-cp39-abi3-macosx_11_0_arm64.whl", hash = 
"sha256:f4861ab466c94b0f9a4b4e6de6e1dfa02f40e7242d8db87447bc7bb7d89606ac"}, + {file = "pylance-0.13.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3cb92547e145f5bfb0ea7d6f483953913b9bdd44c45bea84fc95a18da9f5853"}, + {file = "pylance-0.13.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:d1ddd7700924bc6b6b0774ea63d2aa23f9210a86cd6d6af0cdfa987df776d50d"}, + {file = "pylance-0.13.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:c51d4b6e59cf4dc97c11a35b299f11e80dbdf392e2d8dc498573c26474a3c19e"}, + {file = "pylance-0.13.0-cp39-abi3-win_amd64.whl", hash = "sha256:4018ba016f1445874960a4ba2ad5c80cb380f3116683282ee8beabd38fa8989d"}, +] + +[package.dependencies] +numpy = ">=1.22" +pyarrow = ">=12,<15.0.1" + +[package.extras] +benchmarks = ["pytest-benchmark"] +dev = ["ruff (==0.4.1)"] +ray = ["ray[data]"] +tests = ["boto3", "datasets", "duckdb", "h5py (<3.11)", "ml-dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"] +torch = ["torch"] + [[package]] name = "pymongo" version = "4.6.0" @@ -6717,6 +7015,7 @@ files = [ {file = "pymongo-4.6.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ab6bcc8e424e07c1d4ba6df96f7fb963bcb48f590b9456de9ebd03b88084fe8"}, {file = "pymongo-4.6.0-cp312-cp312-win32.whl", hash = "sha256:47aa128be2e66abd9d1a9b0437c62499d812d291f17b55185cb4aa33a5f710a4"}, {file = "pymongo-4.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:014e7049dd019a6663747ca7dae328943e14f7261f7c1381045dfc26a04fa330"}, + {file = "pymongo-4.6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e24025625bad66895b1bc3ae1647f48f0a92dd014108fb1be404c77f0b69ca67"}, {file = "pymongo-4.6.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:288c21ab9531b037f7efa4e467b33176bc73a0c27223c141b822ab4a0e66ff2a"}, {file = "pymongo-4.6.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:747c84f4e690fbe6999c90ac97246c95d31460d890510e4a3fa61b7d2b87aa34"}, {file = 
"pymongo-4.6.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:055f5c266e2767a88bb585d01137d9c7f778b0195d3dbf4a487ef0638be9b651"}, @@ -7157,6 +7456,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -7164,8 +7464,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = 
"sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -7182,6 +7490,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash 
= "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -7189,6 +7498,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -7196,30 +7506,30 @@ files = [ [[package]] name = "qdrant-client" -version = "1.6.4" +version = "1.9.1" description = "Client library for the Qdrant vector search engine" 
optional = true -python-versions = ">=3.8,<3.13" +python-versions = ">=3.8" files = [ - {file = "qdrant_client-1.6.4-py3-none-any.whl", hash = "sha256:db4696978d6a62d78ff60f70b912383f1e467bda3053f732b01ddb5f93281b10"}, - {file = "qdrant_client-1.6.4.tar.gz", hash = "sha256:bbd65f383b6a55a9ccf4e301250fa925179340dd90cfde9b93ce4230fd68867b"}, + {file = "qdrant_client-1.9.1-py3-none-any.whl", hash = "sha256:b9b7e0e5c1a51410d8bb5106a869a51e12f92ab45a99030f27aba790553bd2c8"}, + {file = "qdrant_client-1.9.1.tar.gz", hash = "sha256:186b9c31d95aefe8f2db84b7746402d7365bd63b305550e530e31bde2002ce79"}, ] [package.dependencies] -fastembed = {version = "0.1.1", optional = true, markers = "python_version < \"3.12\" and extra == \"fastembed\""} +fastembed = {version = "0.2.6", optional = true, markers = "python_version < \"3.13\" and extra == \"fastembed\""} grpcio = ">=1.41.0" grpcio-tools = ">=1.41.0" -httpx = {version = ">=0.14.0", extras = ["http2"]} +httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = [ {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, {version = ">=1.26", markers = "python_version >= \"3.12\""}, ] portalocker = ">=2.7.0,<3.0.0" pydantic = ">=1.10.8" -urllib3 = ">=1.26.14,<2.0.0" +urllib3 = ">=1.26.14,<3" [package.extras] -fastembed = ["fastembed (==0.1.1)"] +fastembed = ["fastembed (==0.2.6)"] [[package]] name = "ratelimiter" @@ -8112,6 +8422,7 @@ files = [ {file = "SQLAlchemy-1.4.49-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:03db81b89fe7ef3857b4a00b63dedd632d6183d4ea5a31c5d8a92e000a41fc71"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:95b9df9afd680b7a3b13b38adf6e3a38995da5e162cc7524ef08e3be4e5ed3e1"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a63e43bf3f668c11bb0444ce6e809c1227b8f067ca1068898f3008a273f52b09"}, + {file = 
"SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca46de16650d143a928d10842939dab208e8d8c3a9a8757600cae9b7c579c5cd"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f835c050ebaa4e48b18403bed2c0fda986525896efd76c245bdd4db995e51a4c"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c21b172dfb22e0db303ff6419451f0cac891d2e911bb9fbf8003d717f1bcf91"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-win32.whl", hash = "sha256:5fb1ebdfc8373b5a291485757bd6431de8d7ed42c27439f543c81f6c8febd729"}, @@ -8121,26 +8432,35 @@ files = [ {file = "SQLAlchemy-1.4.49-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5debe7d49b8acf1f3035317e63d9ec8d5e4d904c6e75a2a9246a119f5f2fdf3d"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win32.whl", hash = "sha256:82b08e82da3756765c2e75f327b9bf6b0f043c9c3925fb95fb51e1567fa4ee87"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win_amd64.whl", hash = "sha256:171e04eeb5d1c0d96a544caf982621a1711d078dbc5c96f11d6469169bd003f1"}, + {file = "SQLAlchemy-1.4.49-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f23755c384c2969ca2f7667a83f7c5648fcf8b62a3f2bbd883d805454964a800"}, + {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8396e896e08e37032e87e7fbf4a15f431aa878c286dc7f79e616c2feacdb366c"}, + {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66da9627cfcc43bbdebd47bfe0145bb662041472393c03b7802253993b6b7c90"}, + {file = "SQLAlchemy-1.4.49-cp312-cp312-win32.whl", hash = "sha256:9a06e046ffeb8a484279e54bda0a5abfd9675f594a2e38ef3133d7e4d75b6214"}, + {file 
= "SQLAlchemy-1.4.49-cp312-cp312-win_amd64.whl", hash = "sha256:7cf8b90ad84ad3a45098b1c9f56f2b161601e4670827d6b892ea0e884569bd1d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:36e58f8c4fe43984384e3fbe6341ac99b6b4e083de2fe838f0fdb91cebe9e9cb"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b31e67ff419013f99ad6f8fc73ee19ea31585e1e9fe773744c0f3ce58c039c30"}, + {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc22807a7e161c0d8f3da34018ab7c97ef6223578fcdd99b1d3e7ed1100a5db"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c14b29d9e1529f99efd550cd04dbb6db6ba5d690abb96d52de2bff4ed518bc95"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c40f3470e084d31247aea228aa1c39bbc0904c2b9ccbf5d3cfa2ea2dac06f26d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-win32.whl", hash = "sha256:706bfa02157b97c136547c406f263e4c6274a7b061b3eb9742915dd774bbc264"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-win_amd64.whl", hash = "sha256:a7f7b5c07ae5c0cfd24c2db86071fb2a3d947da7bd487e359cc91e67ac1c6d2e"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:4afbbf5ef41ac18e02c8dc1f86c04b22b7a2125f2a030e25bbb4aff31abb224b"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24e300c0c2147484a002b175f4e1361f102e82c345bf263242f0449672a4bccf"}, + {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:393cd06c3b00b57f5421e2133e088df9cabcececcea180327e43b937b5a7caa5"}, {file = 
"SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:201de072b818f8ad55c80d18d1a788729cccf9be6d9dc3b9d8613b053cd4836d"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7653ed6817c710d0c95558232aba799307d14ae084cc9b1f4c389157ec50df5c"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-win32.whl", hash = "sha256:647e0b309cb4512b1f1b78471fdaf72921b6fa6e750b9f891e09c6e2f0e5326f"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-win_amd64.whl", hash = "sha256:ab73ed1a05ff539afc4a7f8cf371764cdf79768ecb7d2ec691e3ff89abbc541e"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:37ce517c011560d68f1ffb28af65d7e06f873f191eb3a73af5671e9c3fada08a"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1878ce508edea4a879015ab5215546c444233881301e97ca16fe251e89f1c55"}, + {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95ab792ca493891d7a45a077e35b418f68435efb3e1706cb8155e20e86a9013c"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0e8e608983e6f85d0852ca61f97e521b62e67969e6e640fe6c6b575d4db68557"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccf956da45290df6e809ea12c54c02ace7f8ff4d765d6d3dfb3655ee876ce58d"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win32.whl", hash = "sha256:f167c8175ab908ce48bd6550679cc6ea20ae169379e73c7720a28f89e53aa532"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win_amd64.whl", hash = "sha256:45806315aae81a0c202752558f0df52b42d11dd7ba0097bf71e253b4215f34f4"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-macosx_11_0_x86_64.whl", hash 
= "sha256:b6d0c4b15d65087738a6e22e0ff461b407533ff65a73b818089efc8eb2b3e1de"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a843e34abfd4c797018fd8d00ffffa99fd5184c421f190b6ca99def4087689bd"}, + {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:738d7321212941ab19ba2acf02a68b8ee64987b248ffa2101630e8fccb549e0d"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1c890421651b45a681181301b3497e4d57c0d01dc001e10438a40e9a9c25ee77"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d26f280b8f0a8f497bc10573849ad6dc62e671d2468826e5c748d04ed9e670d5"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-win32.whl", hash = "sha256:ec2268de67f73b43320383947e74700e95c6770d0c68c4e615e9897e46296294"}, @@ -8370,56 +8690,129 @@ twisted = ["twisted"] [[package]] name = "tokenizers" -version = "0.13.3" -description = "Fast and Customizable Tokenizers" +version = "0.15.2" +description = "" optional = true -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, - {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"}, - {file = 
"tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"}, - {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"}, - {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"}, - {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"}, - {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"}, - {file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"}, - {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"}, - {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = 
"sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"}, - {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"}, - {file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"}, - {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"}, - {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"}, - {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"}, - {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"}, - {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"}, - {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"}, - {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"}, - {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, - {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, -] - -[package.extras] -dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] -docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] + {file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = 
"sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"}, + {file = "tokenizers-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:054c1cc9c6d68f7ffa4e810b3d5131e0ba511b6e4be34157aa08ee54c2f8d9ee"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9b9b070fdad06e347563b88c278995735292ded1132f8657084989a4c84a6d5"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea621a7eef4b70e1f7a4e84dd989ae3f0eeb50fc8690254eacc08acb623e82f1"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf7fd9a5141634fa3aa8d6b7be362e6ae1b4cda60da81388fa533e0b552c98fd"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44f2a832cd0825295f7179eaf173381dc45230f9227ec4b44378322d900447c9"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b9ec69247a23747669ec4b0ca10f8e3dfb3545d550258129bd62291aabe8605"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b6a4c78da863ff26dbd5ad9a8ecc33d8a8d97b535172601cf00aee9d7ce9ce"}, + {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5ab2a4d21dcf76af60e05af8063138849eb1d6553a0d059f6534357bce8ba364"}, + {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a47acfac7e511f6bbfcf2d3fb8c26979c780a91e06fb5b9a43831b2c0153d024"}, + {file = "tokenizers-0.15.2-cp310-none-win32.whl", hash = "sha256:064ff87bb6acdbd693666de9a4b692add41308a2c0ec0770d6385737117215f2"}, + {file = "tokenizers-0.15.2-cp310-none-win_amd64.whl", hash = "sha256:3b919afe4df7eb6ac7cafd2bd14fb507d3f408db7a68c43117f579c984a73843"}, + {file = "tokenizers-0.15.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:89cd1cb93e4b12ff39bb2d626ad77e35209de9309a71e4d3d4672667b4b256e7"}, + {file = 
"tokenizers-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cfed5c64e5be23d7ee0f0e98081a25c2a46b0b77ce99a4f0605b1ec43dd481fa"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a907d76dcfda37023ba203ab4ceeb21bc5683436ebefbd895a0841fd52f6f6f2"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20ea60479de6fc7b8ae756b4b097572372d7e4032e2521c1bbf3d90c90a99ff0"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48e2b9335be2bc0171df9281385c2ed06a15f5cf121c44094338306ab7b33f2c"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:112a1dd436d2cc06e6ffdc0b06d55ac019a35a63afd26475205cb4b1bf0bfbff"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4620cca5c2817177ee8706f860364cc3a8845bc1e291aaf661fb899e5d1c45b0"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccd73a82751c523b3fc31ff8194702e4af4db21dc20e55b30ecc2079c5d43cb7"}, + {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:107089f135b4ae7817affe6264f8c7a5c5b4fd9a90f9439ed495f54fcea56fb4"}, + {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0ff110ecc57b7aa4a594396525a3451ad70988e517237fe91c540997c4e50e29"}, + {file = "tokenizers-0.15.2-cp311-none-win32.whl", hash = "sha256:6d76f00f5c32da36c61f41c58346a4fa7f0a61be02f4301fd30ad59834977cc3"}, + {file = "tokenizers-0.15.2-cp311-none-win_amd64.whl", hash = "sha256:cc90102ed17271cf0a1262babe5939e0134b3890345d11a19c3145184b706055"}, + {file = "tokenizers-0.15.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f86593c18d2e6248e72fb91c77d413a815153b8ea4e31f7cd443bdf28e467670"}, + {file = "tokenizers-0.15.2-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:0774bccc6608eca23eb9d620196687c8b2360624619623cf4ba9dc9bd53e8b51"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d0222c5b7c9b26c0b4822a82f6a7011de0a9d3060e1da176f66274b70f846b98"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3835738be1de66624fff2f4f6f6684775da4e9c00bde053be7564cbf3545cc66"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0143e7d9dcd811855c1ce1ab9bf5d96d29bf5e528fd6c7824d0465741e8c10fd"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db35825f6d54215f6b6009a7ff3eedee0848c99a6271c870d2826fbbedf31a38"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f5e64b0389a2be47091d8cc53c87859783b837ea1a06edd9d8e04004df55a5c"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e0480c452217edd35eca56fafe2029fb4d368b7c0475f8dfa3c5c9c400a7456"}, + {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a33ab881c8fe70474980577e033d0bc9a27b7ab8272896e500708b212995d834"}, + {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a308a607ca9de2c64c1b9ba79ec9a403969715a1b8ba5f998a676826f1a7039d"}, + {file = "tokenizers-0.15.2-cp312-none-win32.whl", hash = "sha256:b8fcfa81bcb9447df582c5bc96a031e6df4da2a774b8080d4f02c0c16b42be0b"}, + {file = "tokenizers-0.15.2-cp312-none-win_amd64.whl", hash = "sha256:38d7ab43c6825abfc0b661d95f39c7f8af2449364f01d331f3b51c94dcff7221"}, + {file = "tokenizers-0.15.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:38bfb0204ff3246ca4d5e726e8cc8403bfc931090151e6eede54d0e0cf162ef0"}, + {file = "tokenizers-0.15.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c861d35e8286a53e06e9e28d030b5a05bcbf5ac9d7229e561e53c352a85b1fc"}, + {file = 
"tokenizers-0.15.2-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:936bf3842db5b2048eaa53dade907b1160f318e7c90c74bfab86f1e47720bdd6"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:620beacc3373277700d0e27718aa8b25f7b383eb8001fba94ee00aeea1459d89"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2735ecbbf37e52db4ea970e539fd2d450d213517b77745114f92867f3fc246eb"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:473c83c5e2359bb81b0b6fde870b41b2764fcdd36d997485e07e72cc3a62264a"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968fa1fb3c27398b28a4eca1cbd1e19355c4d3a6007f7398d48826bbe3a0f728"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:865c60ae6eaebdde7da66191ee9b7db52e542ed8ee9d2c653b6d190a9351b980"}, + {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7c0d8b52664ab2d4a8d6686eb5effc68b78608a9008f086a122a7b2996befbab"}, + {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:f33dfbdec3784093a9aebb3680d1f91336c56d86cc70ddf88708251da1fe9064"}, + {file = "tokenizers-0.15.2-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:d44ba80988ff9424e33e0a49445072ac7029d8c0e1601ad25a0ca5f41ed0c1d6"}, + {file = "tokenizers-0.15.2-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:dce74266919b892f82b1b86025a613956ea0ea62a4843d4c4237be2c5498ed3a"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0ef06b9707baeb98b316577acb04f4852239d856b93e9ec3a299622f6084e4be"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c73e2e74bbb07910da0d37c326869f34113137b23eadad3fc00856e6b3d9930c"}, + {file = 
"tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4eeb12daf02a59e29f578a865f55d87cd103ce62bd8a3a5874f8fdeaa82e336b"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ba9f6895af58487ca4f54e8a664a322f16c26bbb442effd01087eba391a719e"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ccec77aa7150e38eec6878a493bf8c263ff1fa8a62404e16c6203c64c1f16a26"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f40604f5042ff210ba82743dda2b6aa3e55aa12df4e9f2378ee01a17e2855e"}, + {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5645938a42d78c4885086767c70923abad047163d809c16da75d6b290cb30bbe"}, + {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:05a77cbfebe28a61ab5c3891f9939cc24798b63fa236d84e5f29f3a85a200c00"}, + {file = "tokenizers-0.15.2-cp37-none-win32.whl", hash = "sha256:361abdc068e8afe9c5b818769a48624687fb6aaed49636ee39bec4e95e1a215b"}, + {file = "tokenizers-0.15.2-cp37-none-win_amd64.whl", hash = "sha256:7ef789f83eb0f9baeb4d09a86cd639c0a5518528f9992f38b28e819df397eb06"}, + {file = "tokenizers-0.15.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4fe1f74a902bee74a3b25aff180fbfbf4f8b444ab37c4d496af7afd13a784ed2"}, + {file = "tokenizers-0.15.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c4b89038a684f40a6b15d6b09f49650ac64d951ad0f2a3ea9169687bbf2a8ba"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d05a1b06f986d41aed5f2de464c003004b2df8aaf66f2b7628254bcbfb72a438"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508711a108684111ec8af89d3a9e9e08755247eda27d0ba5e3c50e9da1600f6d"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:daa348f02d15160cb35439098ac96e3a53bacf35885072611cd9e5be7d333daa"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:494fdbe5932d3416de2a85fc2470b797e6f3226c12845cadf054dd906afd0442"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2d60f5246f4da9373f75ff18d64c69cbf60c3bca597290cea01059c336d2470"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93268e788825f52de4c7bdcb6ebc1fcd4a5442c02e730faa9b6b08f23ead0e24"}, + {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6fc7083ab404019fc9acafe78662c192673c1e696bd598d16dc005bd663a5cf9"}, + {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:41e39b41e5531d6b2122a77532dbea60e171ef87a3820b5a3888daa847df4153"}, + {file = "tokenizers-0.15.2-cp38-none-win32.whl", hash = "sha256:06cd0487b1cbfabefb2cc52fbd6b1f8d4c37799bd6c6e1641281adaa6b2504a7"}, + {file = "tokenizers-0.15.2-cp38-none-win_amd64.whl", hash = "sha256:5179c271aa5de9c71712e31cb5a79e436ecd0d7532a408fa42a8dbfa4bc23fd9"}, + {file = "tokenizers-0.15.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:82f8652a74cc107052328b87ea8b34291c0f55b96d8fb261b3880216a9f9e48e"}, + {file = "tokenizers-0.15.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:02458bee6f5f3139f1ebbb6d042b283af712c0981f5bc50edf771d6b762d5e4f"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c9a09cd26cca2e1c349f91aa665309ddb48d71636370749414fbf67bc83c5343"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:158be8ea8554e5ed69acc1ce3fbb23a06060bd4bbb09029431ad6b9a466a7121"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ddba9a2b0c8c81633eca0bb2e1aa5b3a15362b1277f1ae64176d0f6eba78ab1"}, + {file = 
"tokenizers-0.15.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ef5dd1d39797044642dbe53eb2bc56435308432e9c7907728da74c69ee2adca"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:454c203164e07a860dbeb3b1f4a733be52b0edbb4dd2e5bd75023ffa8b49403a"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cf6b7f1d4dc59af960e6ffdc4faffe6460bbfa8dce27a58bf75755ffdb2526d"}, + {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2ef09bbc16519f6c25d0c7fc0c6a33a6f62923e263c9d7cca4e58b8c61572afb"}, + {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c9a2ebdd2ad4ec7a68e7615086e633857c85e2f18025bd05d2a4399e6c5f7169"}, + {file = "tokenizers-0.15.2-cp39-none-win32.whl", hash = "sha256:918fbb0eab96fe08e72a8c2b5461e9cce95585d82a58688e7f01c2bd546c79d0"}, + {file = "tokenizers-0.15.2-cp39-none-win_amd64.whl", hash = "sha256:524e60da0135e106b254bd71f0659be9f89d83f006ea9093ce4d1fab498c6d0d"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9b648a58281c4672212fab04e60648fde574877d0139cd4b4f93fe28ca8944"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7c7d18b733be6bbca8a55084027f7be428c947ddf871c500ee603e375013ffba"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:13ca3611de8d9ddfbc4dc39ef54ab1d2d4aaa114ac8727dfdc6a6ec4be017378"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:237d1bf3361cf2e6463e6c140628e6406766e8b27274f5fcc62c747ae3c6f094"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67a0fe1e49e60c664915e9fb6b0cb19bac082ab1f309188230e4b2920230edb3"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", 
hash = "sha256:4e022fe65e99230b8fd89ebdfea138c24421f91c1a4f4781a8f5016fd5cdfb4d"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d857be2df69763362ac699f8b251a8cd3fac9d21893de129bc788f8baaef2693"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:708bb3e4283177236309e698da5fcd0879ce8fd37457d7c266d16b550bcbbd18"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c35e09e9899b72a76e762f9854e8750213f67567787d45f37ce06daf57ca78"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1257f4394be0d3b00de8c9e840ca5601d0a4a8438361ce9c2b05c7d25f6057b"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02272fe48280e0293a04245ca5d919b2c94a48b408b55e858feae9618138aeda"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dc3ad9ebc76eabe8b1d7c04d38be884b8f9d60c0cdc09b0aa4e3bcf746de0388"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:32e16bdeffa7c4f46bf2152172ca511808b952701d13e7c18833c0b73cb5c23f"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fb16ba563d59003028b678d2361a27f7e4ae0ab29c7a80690efa20d829c81fdb"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:2277c36d2d6cdb7876c274547921a42425b6810d38354327dd65a8009acf870c"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1cf75d32e8d250781940d07f7eece253f2fe9ecdb1dc7ba6e3833fa17b82fcbc"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b3b31884dc8e9b21508bb76da80ebf7308fdb947a17affce815665d5c4d028"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:b10122d8d8e30afb43bb1fe21a3619f62c3e2574bff2699cf8af8b0b6c5dc4a3"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d88b96ff0fe8e91f6ef01ba50b0d71db5017fa4e3b1d99681cec89a85faf7bf7"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:37aaec5a52e959892870a7c47cef80c53797c0db9149d458460f4f31e2fb250e"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e2ea752f2b0fe96eb6e2f3adbbf4d72aaa1272079b0dfa1145507bd6a5d537e6"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4b19a808d8799fda23504a5cd31d2f58e6f52f140380082b352f877017d6342b"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c86e5e068ac8b19204419ed8ca90f9d25db20578f5881e337d203b314f4104"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de19c4dc503c612847edf833c82e9f73cd79926a384af9d801dcf93f110cea4e"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea09acd2fe3324174063d61ad620dec3bcf042b495515f27f638270a7d466e8b"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cf27fd43472e07b57cf420eee1e814549203d56de00b5af8659cb99885472f1f"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7ca22bd897537a0080521445d91a58886c8c04084a6a19e6c78c586e0cfa92a5"}, + {file = "tokenizers-0.15.2.tar.gz", hash = "sha256:e6e9c6e019dd5484be5beafc775ae6c925f4c69a3487040ed09b45e13df2cb91"}, +] + +[package.dependencies] +huggingface_hub = ">=0.16.4,<1.0" + +[package.extras] +dev = ["tokenizers[testing]"] +docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] [[package]] @@ -8610,6 +9003,17 @@ files = [ {file = 
"types_PyYAML-6.0.12.11-py3-none-any.whl", hash = "sha256:a461508f3096d1d5810ec5ab95d7eeecb651f3a15b71959999988942063bf01d"}, ] +[[package]] +name = "types-regex" +version = "2024.5.15.20240519" +description = "Typing stubs for regex" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-regex-2024.5.15.20240519.tar.gz", hash = "sha256:ef3f594a95a95d6b9b5704a1facf3511a73e4731209ddb8868461db4c42dc12b"}, + {file = "types_regex-2024.5.15.20240519-py3-none-any.whl", hash = "sha256:d5895079cc66f91ae8818aeef14e9337c492ceb87ad0ff3df8c1c04d418cb9dd"}, +] + [[package]] name = "types-requests" version = "2.31.0.2" @@ -8933,6 +9337,20 @@ files = [ {file = "win_precise_time-1.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f510fa92d9c39ea533c983e1d62c7bc66fdf0a3e3c3bdda48d4ebb634ff7034"}, ] +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +optional = true +python-versions = ">=3.5" +files = [ + {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, + {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, +] + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + [[package]] name = "wrapt" version = "1.15.0" @@ -9224,6 +9642,7 @@ duckdb = ["duckdb"] filesystem = ["botocore", "s3fs"] gcp = ["gcsfs", "google-cloud-bigquery", "grpcio"] gs = ["gcsfs"] +lancedb = ["lancedb", "pyarrow"] motherduck = ["duckdb", "pyarrow"] mssql = ["pyodbc"] parquet = ["pyarrow"] @@ -9238,4 +9657,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "4ca5f4a7955437d6da09be909a729172b9a663cc0649227e6088dc1c2cd27e57" +content-hash = "e517168f2ff67c46f3b37d7dcde88b73a1e2ae0d6890243b4c6d1e0aa504eff7" diff --git a/pyproject.toml b/pyproject.toml index b99c9e4051..849626314a 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.4.13a0" +version = "0.5.1a0" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. "] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Anton Burnashev ", "David Scharf " ] @@ -73,11 +73,12 @@ pipdeptree = {version = ">=2.9.0,<2.10", optional = true} pyathena = {version = ">=2.9.6", optional = true} weaviate-client = {version = ">=3.22", optional = true} adlfs = {version = ">=2022.4.0", optional = true} -pyodbc = {version = "^4.0.39", optional = true} -qdrant-client = {version = "^1.6.4", optional = true, extras = ["fastembed"]} -databricks-sql-connector = {version = ">=3", optional = true} +pyodbc = {version = ">=4.0.39", optional = true} +qdrant-client = {version = ">=1.8", optional = true, extras = ["fastembed"]} +databricks-sql-connector = {version = ">=2.9.3", optional = true} clickhouse-driver = { version = ">=0.2.7", optional = true } clickhouse-connect = { version = ">=0.7.7", optional = true } +lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'" } deltalake = { version = ">=0.17.4", optional = true } [tool.poetry.extras] @@ -103,8 +104,10 @@ qdrant = ["qdrant-client"] databricks = ["databricks-sql-connector"] clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"] dremio = ["pyarrow"] +lancedb = ["lancedb", "pyarrow"] deltalake = ["deltalake", "pyarrow"] + [tool.poetry.scripts] dlt = "dlt.cli._dlt:_main" @@ -149,6 +152,7 @@ types-pytz = ">=2024.1.0.20240203" ruff = "^0.3.2" pyjwt = "^2.8.0" pytest-mock = "^3.14.0" +types-regex = "^2024.5.15.20240519" [tool.poetry.group.pipeline] optional = true diff --git a/tests/common/cases/destinations/null.py b/tests/common/cases/destinations/null.py index b2054cd7e8..37e87d89cf 100644 --- a/tests/common/cases/destinations/null.py +++ 
b/tests/common/cases/destinations/null.py @@ -14,7 +14,7 @@ def __init__(self, **kwargs: Any) -> None: spec = DestinationClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: + def _raw_capabilities(self) -> DestinationCapabilitiesContext: return DestinationCapabilitiesContext.generic_capabilities() @property diff --git a/tests/common/cases/normalizers/__init__.py b/tests/common/cases/normalizers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/common/cases/normalizers/snake_no_x.py b/tests/common/cases/normalizers/snake_no_x.py new file mode 100644 index 0000000000..af3a53cbce --- /dev/null +++ b/tests/common/cases/normalizers/snake_no_x.py @@ -0,0 +1,11 @@ +from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention + + +class NamingConvention(SnakeCaseNamingConvention): + """Test naming convention: snake case in which a trailing "x" of the normalized identifier is replaced with "_".""" + + def normalize_identifier(self, identifier: str) -> str: + identifier = super().normalize_identifier(identifier) + if identifier.endswith("x"): + return identifier[:-1] + "_" + return identifier diff --git a/tests/common/cases/normalizers/sql_upper.py b/tests/common/cases/normalizers/sql_upper.py new file mode 100644 index 0000000000..f2175f06ad --- /dev/null +++ b/tests/common/cases/normalizers/sql_upper.py @@ -0,0 +1,20 @@ +from typing import Any, Sequence + +from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention + + +class NamingConvention(BaseNamingConvention): + """Test case sensitive naming convention that upper-cases identifiers and replaces the characters listed in `_CLEANUP_TABLE` with underscores""" + + PATH_SEPARATOR = "__" + + _CLEANUP_TABLE = str.maketrans(".\n\r'\"▶", "______") + + @property + def is_case_sensitive(self) -> bool: + return True + + def normalize_identifier(self, identifier: str) -> str: + identifier = super().normalize_identifier(identifier) + norm_identifier = identifier.translate(self._CLEANUP_TABLE).upper() + return self.shorten_identifier(norm_identifier, identifier, self.max_length) diff --git a/tests/common/cases/normalizers/title_case.py
b/tests/common/cases/normalizers/title_case.py new file mode 100644 index 0000000000..2b93b476c8 --- /dev/null +++ b/tests/common/cases/normalizers/title_case.py @@ -0,0 +1,18 @@ +from typing import ClassVar +from dlt.common.normalizers.naming.direct import NamingConvention as DirectNamingConvention + + +class NamingConvention(DirectNamingConvention): + """Test case sensitive naming that capitalizes first and last letter and leaves the rest intact""" + + PATH_SEPARATOR: ClassVar[str] = "__" + + def normalize_identifier(self, identifier: str) -> str: + # keep prefix + if identifier == "_dlt": + return "_dlt" + identifier = super().normalize_identifier(identifier) + # a single character is both the first and the last letter: upper-case it once instead of emitting it twice + if len(identifier) < 2: + return identifier.upper() + return identifier[0].upper() + identifier[1:-1] + identifier[-1].upper() diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 48993971c2..7c3138ea73 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -621,7 +621,7 @@ class _SecretCredentials(RunConfiguration): "dlthub_telemetry": True, "dlthub_telemetry_endpoint": "https://telemetry-tracker.services4758.workers.dev", "dlthub_telemetry_segment_write_key": None, - "log_format": "{asctime}|[{levelname:<21}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}", + "log_format": "{asctime}|[{levelname}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}", "log_level": "WARNING", "request_timeout": 60, "request_max_attempts": 5, diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index f0494e9898..13d68b53e9 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -570,7 +570,19 @@ def get_cf(aux: str = dlt.config.value, last_config: AuxTest = None): def test_inject_spec_into_argument_with_spec_type() -> None: # if signature contains argument with type of SPEC, it gets injected there - from
dlt.destinations.impl.dummy import _configure, DummyClientConfiguration + import dlt + from dlt.common.configuration import known_sections + from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration + + @with_config( + spec=DummyClientConfiguration, + sections=( + known_sections.DESTINATION, + "dummy", + ), + ) + def _configure(config: DummyClientConfiguration = dlt.config.value) -> DummyClientConfiguration: + return config # _configure has argument of type DummyClientConfiguration that it returns # this type holds resolved configuration diff --git a/tests/common/data_writers/test_data_writers.py b/tests/common/data_writers/test_data_writers.py index 6cc7cb55ab..9b4e61a2f7 100644 --- a/tests/common/data_writers/test_data_writers.py +++ b/tests/common/data_writers/test_data_writers.py @@ -7,11 +7,9 @@ from dlt.common.data_writers.exceptions import DataWriterNotFound, SpecLookupFailed from dlt.common.typing import AnyFun -# from dlt.destinations.postgres import capabilities -from dlt.destinations.impl.redshift import capabilities as redshift_caps from dlt.common.data_writers.escape import ( escape_redshift_identifier, - escape_bigquery_identifier, + escape_hive_identifier, escape_redshift_literal, escape_postgres_literal, escape_duckdb_literal, @@ -29,8 +27,10 @@ DataWriter, DataWriterMetrics, EMPTY_DATA_WRITER_METRICS, + ImportFileWriter, InsertValuesWriter, JsonlWriter, + create_import_spec, get_best_writer_spec, resolve_best_writer_spec, is_native_writer, @@ -51,8 +51,10 @@ class _BytesIOWriter(DataWriter): @pytest.fixture def insert_writer() -> Iterator[DataWriter]: + from dlt.destinations import redshift + with io.StringIO() as f: - yield InsertValuesWriter(f, caps=redshift_caps()) + yield InsertValuesWriter(f, caps=redshift().capabilities()) @pytest.fixture @@ -154,7 +156,7 @@ def test_identifier_escape() -> None: def test_identifier_escape_bigquery() -> None: assert ( - escape_bigquery_identifier(", NULL'); DROP TABLE\"` -\\-") + 
escape_hive_identifier(", NULL'); DROP TABLE\"` -\\-") == "`, NULL'); DROP TABLE\"\\` -\\\\-`" ) @@ -259,3 +261,14 @@ def test_get_best_writer() -> None: assert WRITER_SPECS[get_best_writer_spec("arrow", "insert_values")] == ArrowToInsertValuesWriter with pytest.raises(DataWriterNotFound): get_best_writer_spec("arrow", "tsv") # type: ignore + + +def test_import_file_writer() -> None: + spec = create_import_spec("jsonl", ["jsonl"]) + assert spec.data_item_format == "file" + assert spec.file_format == "jsonl" + writer = DataWriter.writer_class_from_spec(spec) + assert writer is ImportFileWriter + w_ = writer(None) + with pytest.raises(NotImplementedError): + w_.write_header(None) diff --git a/tests/common/normalizers/custom_normalizers.py b/tests/common/normalizers/custom_normalizers.py index 3ae65c8b53..4a0f456eef 100644 --- a/tests/common/normalizers/custom_normalizers.py +++ b/tests/common/normalizers/custom_normalizers.py @@ -11,6 +11,13 @@ def normalize_identifier(self, identifier: str) -> str: return "column_" + identifier.lower() +class ColumnNamingConvention(SnakeCaseNamingConvention): + def normalize_identifier(self, identifier: str) -> str: + if identifier.startswith("column_"): + return identifier + return "column_" + identifier.lower() + + class DataItemNormalizer(RelationalNormalizer): def extend_schema(self) -> None: json_config = self.schema._normalizers_config["json"]["config"] diff --git a/tests/common/normalizers/test_import_normalizers.py b/tests/common/normalizers/test_import_normalizers.py index df6b973943..fe356de327 100644 --- a/tests/common/normalizers/test_import_normalizers.py +++ b/tests/common/normalizers/test_import_normalizers.py @@ -1,14 +1,23 @@ import os - import pytest from dlt.common.configuration.container import Container from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.normalizers import explicit_normalizers, import_normalizers +from dlt.common.normalizers.typing import TNormalizersConfig +from 
dlt.common.normalizers.utils import ( + DEFAULT_NAMING_NAMESPACE, + explicit_normalizers, + import_normalizers, + naming_from_reference, + serialize_reference, +) from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer -from dlt.common.normalizers.naming import snake_case -from dlt.common.normalizers.naming import direct -from dlt.common.normalizers.naming.exceptions import InvalidNamingModule, UnknownNamingModule +from dlt.common.normalizers.naming import snake_case, direct +from dlt.common.normalizers.naming.exceptions import ( + InvalidNamingType, + NamingTypeNotFound, + UnknownNamingModule, +) from tests.common.normalizers.custom_normalizers import ( DataItemNormalizer as CustomRelationalNormalizer, @@ -16,7 +25,7 @@ from tests.utils import preserve_environ -def test_default_normalizers() -> None: +def test_explicit_normalizers() -> None: config = explicit_normalizers() assert config["names"] is None assert config["json"] is None @@ -26,6 +35,12 @@ def test_default_normalizers() -> None: assert config["names"] == "direct" assert config["json"] == {"module": "custom"} + # pass modules and types, make sure normalizer config is serialized + config = explicit_normalizers(direct) + assert config["names"] == f"{DEFAULT_NAMING_NAMESPACE}.direct.NamingConvention" + config = explicit_normalizers(direct.NamingConvention) + assert config["names"] == f"{DEFAULT_NAMING_NAMESPACE}.direct.NamingConvention" + # use environ os.environ["SCHEMA__NAMING"] = "direct" os.environ["SCHEMA__JSON_NORMALIZER"] = '{"module": "custom"}' @@ -34,13 +49,75 @@ def test_default_normalizers() -> None: assert config["json"] == {"module": "custom"} -def test_default_normalizers_with_caps() -> None: +def test_explicit_normalizers_caps_ignored() -> None: # gets the naming convention from capabilities destination_caps = DestinationCapabilitiesContext.generic_capabilities() destination_caps.naming_convention = "direct" with 
Container().injectable_context(destination_caps): config = explicit_normalizers() - assert config["names"] == "direct" + assert config["names"] is None + + +def test_serialize_reference() -> None: + assert serialize_reference(None) is None + assert serialize_reference("module") == "module" + assert ( + serialize_reference(snake_case) == f"{DEFAULT_NAMING_NAMESPACE}.snake_case.NamingConvention" + ) + assert ( + serialize_reference(snake_case.NamingConvention) + == f"{DEFAULT_NAMING_NAMESPACE}.snake_case.NamingConvention" + ) + # test a wrong module and type + with pytest.raises(NamingTypeNotFound): + serialize_reference(pytest) + with pytest.raises(ValueError): + serialize_reference(Container) # type: ignore[arg-type] + + +def test_naming_from_reference() -> None: + assert naming_from_reference("snake_case").name() == "snake_case" + assert naming_from_reference("snake_case.NamingConvention").name() == "snake_case" + + # now not visible + with pytest.raises(UnknownNamingModule): + naming_from_reference("custom_normalizers") + + # temporarily add current file dir to paths and import module that clash with dlt predefined (no path) + import sys + + try: + sys.path.insert(0, os.path.dirname(__file__)) + assert naming_from_reference("custom_normalizers").name() == "custom_normalizers" + assert ( + naming_from_reference("custom_normalizers.NamingConvention").name() + == "custom_normalizers" + ) + assert ( + naming_from_reference("custom_normalizers.ColumnNamingConvention").name() + == "custom_normalizers" + ) + finally: + sys.path.pop(0) + + # non standard location + assert ( + naming_from_reference("dlt.destinations.impl.weaviate.naming").name() + == "dlt.destinations.impl.weaviate.naming" + ) + + # import module + assert naming_from_reference(snake_case).name() == "snake_case" + assert naming_from_reference(snake_case.NamingConvention).name() == "snake_case" + + with pytest.raises(ValueError): + naming_from_reference(snake_case.NamingConvention()) # type: 
ignore[arg-type] + + # with capabilities + caps = DestinationCapabilitiesContext.generic_capabilities() + caps.max_identifier_length = 120 + naming = naming_from_reference(snake_case.NamingConvention, caps) + assert naming.max_length == 120 def test_import_normalizers() -> None: @@ -64,6 +141,40 @@ def test_import_normalizers() -> None: assert json_normalizer is CustomRelationalNormalizer +def test_import_normalizers_with_defaults() -> None: + explicit = explicit_normalizers() + default_: TNormalizersConfig = { + "names": "dlt.destinations.impl.weaviate.naming", + "json": {"module": "tests.common.normalizers.custom_normalizers"}, + } + config, naming, json_normalizer = import_normalizers(explicit, default_) + + assert config["names"] == "dlt.destinations.impl.weaviate.naming" + assert config["json"] == {"module": "tests.common.normalizers.custom_normalizers"} + assert naming.name() == "dlt.destinations.impl.weaviate.naming" + assert json_normalizer is CustomRelationalNormalizer + + # correctly overrides + explicit["names"] = "sql_cs_v1" + explicit["json"] = {"module": "dlt.common.normalizers.json.relational"} + config, naming, json_normalizer = import_normalizers(explicit, default_) + assert config["names"] == "sql_cs_v1" + assert config["json"] == {"module": "dlt.common.normalizers.json.relational"} + assert naming.name() == "sql_cs_v1" + assert json_normalizer is RelationalNormalizer + + +@pytest.mark.parametrize("sections", ("", "SOURCES__", "SOURCES__TEST_SCHEMA__")) +def test_config_sections(sections: str) -> None: + os.environ[f"{sections}SCHEMA__NAMING"] = "direct" + os.environ[f"{sections}SCHEMA__JSON_NORMALIZER"] = ( + '{"module": "tests.common.normalizers.custom_normalizers"}' + ) + config, _, _ = import_normalizers(explicit_normalizers(schema_name="test_schema")) + assert config["names"] == "direct" + assert config["json"] == {"module": "tests.common.normalizers.custom_normalizers"} + + def test_import_normalizers_with_caps() -> None: # gets the naming 
convention from capabilities destination_caps = DestinationCapabilitiesContext.generic_capabilities() @@ -74,6 +185,25 @@ def test_import_normalizers_with_caps() -> None: assert isinstance(naming, direct.NamingConvention) assert naming.max_length == 127 + _, naming, _ = import_normalizers(explicit_normalizers(snake_case)) + assert isinstance(naming, snake_case.NamingConvention) + assert naming.max_length == 127 + + # max table nesting generates relational normalizer + default_: TNormalizersConfig = { + "names": "dlt.destinations.impl.weaviate.naming", + "json": {"module": "tests.common.normalizers.custom_normalizers"}, + } + destination_caps.max_table_nesting = 0 + with Container().injectable_context(destination_caps): + config, _, relational = import_normalizers(explicit_normalizers()) + assert config["json"]["config"]["max_nesting"] == 0 + assert relational is RelationalNormalizer + + # wrong normalizer + config, _, relational = import_normalizers(explicit_normalizers(), default_) + assert "config" not in config["json"] + def test_import_invalid_naming_module() -> None: with pytest.raises(UnknownNamingModule) as py_ex: @@ -82,6 +212,7 @@ def test_import_invalid_naming_module() -> None: with pytest.raises(UnknownNamingModule) as py_ex: import_normalizers(explicit_normalizers("dlt.common.tests")) assert py_ex.value.naming_module == "dlt.common.tests" - with pytest.raises(InvalidNamingModule) as py_ex2: - import_normalizers(explicit_normalizers("dlt.pipeline")) + with pytest.raises(InvalidNamingType) as py_ex2: + import_normalizers(explicit_normalizers("dlt.pipeline.helpers")) assert py_ex2.value.naming_module == "dlt.pipeline" + assert py_ex2.value.naming_class == "helpers" diff --git a/tests/common/normalizers/test_json_relational.py b/tests/common/normalizers/test_json_relational.py index 502ce619dd..159e33da4d 100644 --- a/tests/common/normalizers/test_json_relational.py +++ b/tests/common/normalizers/test_json_relational.py @@ -2,16 +2,15 @@ from 
dlt.common.typing import StrAny, DictStrAny from dlt.common.normalizers.naming import NamingConvention -from dlt.common.schema.typing import TSimpleRegex +from dlt.common.schema.typing import TColumnName, TSimpleRegex from dlt.common.utils import digest128, uniq_id -from dlt.common.schema import Schema, TTableSchema +from dlt.common.schema import Schema from dlt.common.schema.utils import new_table from dlt.common.normalizers.json.relational import ( RelationalNormalizerConfigPropagation, DataItemNormalizer as RelationalNormalizer, DLT_ID_LENGTH_BYTES, - TDataItemRow, ) # _flatten, _get_child_row_hash, _normalize_row, normalize_data_item, @@ -30,7 +29,7 @@ def test_flatten_fix_field_name(norm: RelationalNormalizer) -> None: "f 2": [], "f!3": {"f4": "a", "f-5": "b", "f*6": {"c": 7, "c v": 8, "c x": []}}, } - flattened_row, lists = norm._flatten("mock_table", row, 0) # type: ignore[arg-type] + flattened_row, lists = norm._flatten("mock_table", row, 0) assert "f_1" in flattened_row # assert "f_2" in flattened_row assert "f_3__f4" in flattened_row @@ -63,12 +62,12 @@ def test_preserve_complex_value(norm: RelationalNormalizer) -> None: ) ) row_1 = {"value": 1} - flattened_row, _ = norm._flatten("with_complex", row_1, 0) # type: ignore[arg-type] - assert flattened_row["value"] == 1 # type: ignore[typeddict-item] + flattened_row, _ = norm._flatten("with_complex", row_1, 0) + assert flattened_row["value"] == 1 row_2 = {"value": {"complex": True}} - flattened_row, _ = norm._flatten("with_complex", row_2, 0) # type: ignore[arg-type] - assert flattened_row["value"] == row_2["value"] # type: ignore[typeddict-item] + flattened_row, _ = norm._flatten("with_complex", row_2, 0) + assert flattened_row["value"] == row_2["value"] # complex value is not flattened assert "value__complex" not in flattened_row @@ -79,12 +78,12 @@ def test_preserve_complex_value_with_hint(norm: RelationalNormalizer) -> None: norm.schema._compile_settings() row_1 = {"value": 1} - flattened_row, _ = 
norm._flatten("any_table", row_1, 0) # type: ignore[arg-type] - assert flattened_row["value"] == 1 # type: ignore[typeddict-item] + flattened_row, _ = norm._flatten("any_table", row_1, 0) + assert flattened_row["value"] == 1 row_2 = {"value": {"complex": True}} - flattened_row, _ = norm._flatten("any_table", row_2, 0) # type: ignore[arg-type] - assert flattened_row["value"] == row_2["value"] # type: ignore[typeddict-item] + flattened_row, _ = norm._flatten("any_table", row_2, 0) + assert flattened_row["value"] == row_2["value"] # complex value is not flattened assert "value__complex" not in flattened_row @@ -94,7 +93,7 @@ def test_child_table_linking(norm: RelationalNormalizer) -> None: # request _dlt_root_id propagation add_dlt_root_id_propagation(norm) - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) # should have 7 entries (root + level 1 + 3 * list + 2 * object) assert len(rows) == 7 # root elem will not have a root hash if not explicitly added, "extend" is added only to child @@ -142,7 +141,7 @@ def test_child_table_linking_primary_key(norm: RelationalNormalizer) -> None: norm.schema.merge_hints({"primary_key": [TSimpleRegex("id")]}) norm.schema._compile_settings() - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) root = next(t for t in rows if t[0][0] == "table")[1] # record hash is random for primary keys, not based on their content # this is a change introduced in dlt 0.2.0a30 @@ -172,7 +171,7 @@ def test_yields_parents_first(norm: RelationalNormalizer) -> None: "f": [{"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}], "g": [{"id": "level2_g", "l": ["a"]}], } - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) tables = list(r[0][0] for r in rows) # child tables are always 
yielded before parent tables expected_tables = [ @@ -218,7 +217,7 @@ def test_yields_parent_relation(norm: RelationalNormalizer) -> None: } ], } - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) # normalizer must return parent table first and move in order of the list elements when yielding child tables # the yielding order if fully defined expected_parents = [ @@ -276,10 +275,10 @@ def test_yields_parent_relation(norm: RelationalNormalizer) -> None: def test_list_position(norm: RelationalNormalizer) -> None: - row: StrAny = { + row: DictStrAny = { "f": [{"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]}] } - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) # root has no pos root = [t for t in rows if t[0][0] == "table"][0][1] assert "_dlt_list_idx" not in root @@ -290,13 +289,13 @@ def test_list_position(norm: RelationalNormalizer) -> None: # f_l must be ordered as it appears in the list for pos, elem in enumerate(["a", "b", "c"]): - row = next(t[1] for t in rows if t[0][0] == "table__f__l" and t[1]["value"] == elem) - assert row["_dlt_list_idx"] == pos + row_1 = next(t[1] for t in rows if t[0][0] == "table__f__l" and t[1]["value"] == elem) + assert row_1["_dlt_list_idx"] == pos # f_lo must be ordered - list of objects for pos, elem in enumerate(["a", "b", "c"]): - row = next(t[1] for t in rows if t[0][0] == "table__f__lo" and t[1]["e"] == elem) - assert row["_dlt_list_idx"] == pos + row_2 = next(t[1] for t in rows if t[0][0] == "table__f__lo" and t[1]["e"] == elem) + assert row_2["_dlt_list_idx"] == pos # def test_list_of_lists(norm: RelationalNormalizer) -> None: @@ -430,7 +429,7 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: "_dlt_id": row_id, "f": [{"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]}], } - rows 
= list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) children = [t for t in rows if t[0][0] != "table"] # all hashes must be different distinct_hashes = set([ch[1]["_dlt_id"] for ch in children]) @@ -449,19 +448,19 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: assert f_lo_p2["_dlt_id"] == digest128(f"{el_f['_dlt_id']}_table__f__lo_2", DLT_ID_LENGTH_BYTES) # same data with same table and row_id - rows_2 = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows_2 = list(norm._normalize_row(row, {}, ("table",))) children_2 = [t for t in rows_2 if t[0][0] != "table"] # corresponding hashes must be identical assert all(ch[0][1]["_dlt_id"] == ch[1][1]["_dlt_id"] for ch in zip(children, children_2)) # change parent table and all child hashes must be different - rows_4 = list(norm._normalize_row(row, {}, ("other_table",))) # type: ignore[arg-type] + rows_4 = list(norm._normalize_row(row, {}, ("other_table",))) children_4 = [t for t in rows_4 if t[0][0] != "other_table"] assert all(ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_4)) # change parent hash and all child hashes must be different row["_dlt_id"] = uniq_id() - rows_3 = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows_3 = list(norm._normalize_row(row, {}, ("table",))) children_3 = [t for t in rows_3 if t[0][0] != "table"] assert all(ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_3)) @@ -469,14 +468,16 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: def test_keeps_dlt_id(norm: RelationalNormalizer) -> None: h = uniq_id() row = {"a": "b", "_dlt_id": h} - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) root = [t for t in rows if t[0][0] == "table"][0][1] assert root["_dlt_id"] == h def 
test_propagate_hardcoded_context(norm: RelationalNormalizer) -> None: row = {"level": 1, "list": ["a", "b", "c"], "comp": [{"_timestamp": "a"}]} - rows = list(norm._normalize_row(row, {"_timestamp": 1238.9, "_dist_key": "SENDER_3000"}, ("table",))) # type: ignore[arg-type] + rows = list( + norm._normalize_row(row, {"_timestamp": 1238.9, "_dist_key": "SENDER_3000"}, ("table",)) + ) # context is not added to root element root = next(t for t in rows if t[0][0] == "table")[1] assert "_timestamp" in root @@ -506,7 +507,7 @@ def test_propagates_root_context(norm: RelationalNormalizer) -> None: "dependent_list": [1, 2, 3], "dependent_objects": [{"vx": "ax"}], } - normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # all non-root rows must have: non_root = [r for r in normalized_rows if r[0][1] is not None] assert all(r[1]["_dlt_root_id"] == "###" for r in non_root) @@ -522,12 +523,12 @@ def test_propagates_table_context( prop_config: RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"][ "config" ]["propagation"] - prop_config["root"]["timestamp"] = "_partition_ts" # type: ignore[index] + prop_config["root"][TColumnName("timestamp")] = TColumnName("_partition_ts") # for table "table__lvl1" request to propagate "vx" and "partition_ovr" as "_partition_ts" (should overwrite root) - prop_config["tables"]["table__lvl1"] = { # type: ignore[index] - "vx": "__vx", - "partition_ovr": "_partition_ts", - "__not_found": "__not_found", + prop_config["tables"]["table__lvl1"] = { + TColumnName("vx"): TColumnName("__vx"), + TColumnName("partition_ovr"): TColumnName("_partition_ts"), + TColumnName("__not_found"): TColumnName("__not_found"), } if add_pk: @@ -545,7 +546,7 @@ def test_propagates_table_context( # to reproduce a bug where rows with _dlt_id set were not extended row["lvl1"][0]["_dlt_id"] = "row_id_lvl1" # type: ignore[index] - normalized_rows = 
list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("table",))) non_root = [r for r in normalized_rows if r[0][1] is not None] # _dlt_root_id in all non root assert all(r[1]["_dlt_root_id"] == "###" for r in non_root) @@ -574,10 +575,10 @@ def test_propagates_table_context_to_lists(norm: RelationalNormalizer) -> None: prop_config: RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"][ "config" ]["propagation"] - prop_config["root"]["timestamp"] = "_partition_ts" # type: ignore[index] + prop_config["root"][TColumnName("timestamp")] = TColumnName("_partition_ts") row = {"_dlt_id": "###", "timestamp": 12918291.1212, "lvl1": [1, 2, 3, [4, 5, 6]]} - normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # _partition_ts == timestamp on all child tables non_root = [r for r in normalized_rows if r[0][1] is not None] assert all(r[1]["_partition_ts"] == 12918291.1212 for r in non_root) @@ -590,7 +591,7 @@ def test_removes_normalized_list(norm: RelationalNormalizer) -> None: # after normalizing the list that got normalized into child table must be deleted row = {"comp": [{"_timestamp": "a"}]} # get iterator - normalized_rows_i = norm._normalize_row(row, {}, ("table",)) # type: ignore[arg-type] + normalized_rows_i = norm._normalize_row(row, {}, ("table",)) # yield just one item root_row = next(normalized_rows_i) # root_row = next(r for r in normalized_rows if r[0][1] is None) @@ -614,7 +615,7 @@ def test_preserves_complex_types_list(norm: RelationalNormalizer) -> None: ) ) row = {"value": ["from", {"complex": True}]} - normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # make sure only 1 row is emitted, the list is not normalized assert len(normalized_rows) == 1 # 
value is kept in root row -> market as complex @@ -623,7 +624,7 @@ def test_preserves_complex_types_list(norm: RelationalNormalizer) -> None: # same should work for a list row = {"value": ["from", ["complex", True]]} # type: ignore[list-item] - normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # make sure only 1 row is emitted, the list is not normalized assert len(normalized_rows) == 1 # value is kept in root row -> market as complex @@ -735,7 +736,7 @@ def test_table_name_meta_normalized() -> None: def test_parse_with_primary_key() -> None: schema = create_schema_with_name("discord") - schema.merge_hints({"primary_key": ["id"]}) # type: ignore[list-item] + schema._merge_hints({"primary_key": ["id"]}) # type: ignore[list-item] schema._compile_settings() add_dlt_root_id_propagation(schema.data_item_normalizer) # type: ignore[arg-type] diff --git a/tests/common/normalizers/test_naming.py b/tests/common/normalizers/test_naming.py index 3bf4762c35..84d36537e6 100644 --- a/tests/common/normalizers/test_naming.py +++ b/tests/common/normalizers/test_naming.py @@ -2,13 +2,29 @@ import string from typing import List, Type -from dlt.common.normalizers.naming import NamingConvention -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention -from dlt.common.normalizers.naming.direct import NamingConvention as DirectNamingConvention +from dlt.common.normalizers.naming import ( + NamingConvention, + snake_case, + direct, + duck_case, + sql_ci_v1, + sql_cs_v1, +) from dlt.common.typing import DictStrStr from dlt.common.utils import uniq_id +ALL_NAMING_CONVENTIONS = { + snake_case.NamingConvention, + direct.NamingConvention, + duck_case.NamingConvention, + sql_ci_v1.NamingConvention, + sql_cs_v1.NamingConvention, +} + +ALL_UNDERSCORE_PATH_CONVENTIONS = ALL_NAMING_CONVENTIONS - {direct.NamingConvention} + + LONG_PATH = 
"prospects_external_data__data365_member__member__feed_activities_created_post__items__comments__items__comments__items__author_details__educations" DENSE_PATH = "__".join(string.ascii_lowercase) LONG_IDENT = 10 * string.printable @@ -139,7 +155,7 @@ def test_shorten_identifier() -> None: assert len(norm_ident) == 20 -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) def test_normalize_with_shorten_identifier(convention: Type[NamingConvention]) -> None: naming = convention() # None/empty ident raises @@ -164,7 +180,7 @@ def test_normalize_with_shorten_identifier(convention: Type[NamingConvention]) - assert tag in naming.normalize_identifier(RAW_IDENT) -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) def test_normalize_path_shorting(convention: Type[NamingConvention]) -> None: naming = convention() path = naming.make_path(*LONG_PATH.split("__")) @@ -207,10 +223,11 @@ def test_normalize_path_shorting(convention: Type[NamingConvention]) -> None: assert len(naming.break_path(norm_path)) == 1 -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) def test_normalize_path(convention: Type[NamingConvention]) -> None: naming = convention() raw_path_str = naming.make_path(*RAW_PATH) + assert convention.PATH_SEPARATOR in raw_path_str # count separators norm_path_str = naming.normalize_path(raw_path_str) assert len(naming.break_path(norm_path_str)) == len(RAW_PATH) @@ -248,7 +265,7 @@ def test_normalize_path(convention: Type[NamingConvention]) -> None: assert tag in tagged_raw_path_str -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) def 
test_shorten_fragments(convention: Type[NamingConvention]) -> None: # max length around the length of the path naming = convention() @@ -266,8 +283,30 @@ def test_shorten_fragments(convention: Type[NamingConvention]) -> None: assert naming.shorten_fragments(*RAW_PATH_WITH_EMPTY_IDENT) == norm_path -# 'event__parse_data__response_selector__default__response__response_templates' -# E 'event__parse_data__response_selector__default__response__responses' +@pytest.mark.parametrize("convention", ALL_UNDERSCORE_PATH_CONVENTIONS) +def test_normalize_break_path(convention: Type[NamingConvention]) -> None: + naming_unlimited = convention() + assert naming_unlimited.break_path("A__B__C") == ["A", "B", "C"] + # what if path has _a and _b which valid normalized idents + assert naming_unlimited.break_path("_a___b__C___D") == ["_a", "_b", "C", "_D"] + # skip empty identifiers from path + assert naming_unlimited.break_path("_a_____b") == ["_a", "_b"] + assert naming_unlimited.break_path("_a____b") == ["_a", "b"] + assert naming_unlimited.break_path("_a__ \t\r__b") == ["_a", "b"] + + +@pytest.mark.parametrize("convention", ALL_UNDERSCORE_PATH_CONVENTIONS) +def test_normalize_make_path(convention: Type[NamingConvention]) -> None: + naming_unlimited = convention() + assert naming_unlimited.make_path("A", "B") == "A__B" + assert naming_unlimited.make_path("_A", "_B") == "_A___B" + assert naming_unlimited.make_path("_A", "", "_B") == "_A___B" + assert naming_unlimited.make_path("_A", "\t\n ", "_B") == "_A___B" + + +def test_naming_convention_name() -> None: + assert snake_case.NamingConvention.name() == "snake_case" + assert direct.NamingConvention.name() == "direct" def assert_short_path(norm_path: str, naming: NamingConvention) -> None: diff --git a/tests/common/normalizers/test_naming_snake_case.py b/tests/common/normalizers/test_naming_snake_case.py index 6d619b5257..ee4f43e7f0 100644 --- a/tests/common/normalizers/test_naming_snake_case.py +++ 
b/tests/common/normalizers/test_naming_snake_case.py @@ -1,9 +1,7 @@ -from typing import Type import pytest from dlt.common.normalizers.naming import NamingConvention from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention -from dlt.common.normalizers.naming.duck_case import NamingConvention as DuckCaseNamingConvention @pytest.fixture @@ -54,30 +52,9 @@ def test_normalize_path(naming_unlimited: NamingConvention) -> None: def test_normalize_non_alpha_single_underscore() -> None: - assert SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "-=!*") == "_" - assert SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "1-=!0*-") == "1_0_" - assert SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "1-=!_0*-") == "1__0_" - - -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DuckCaseNamingConvention)) -def test_normalize_break_path(convention: Type[NamingConvention]) -> None: - naming_unlimited = convention() - assert naming_unlimited.break_path("A__B__C") == ["A", "B", "C"] - # what if path has _a and _b which valid normalized idents - assert naming_unlimited.break_path("_a___b__C___D") == ["_a", "_b", "C", "_D"] - # skip empty identifiers from path - assert naming_unlimited.break_path("_a_____b") == ["_a", "_b"] - assert naming_unlimited.break_path("_a____b") == ["_a", "b"] - assert naming_unlimited.break_path("_a__ \t\r__b") == ["_a", "b"] - - -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DuckCaseNamingConvention)) -def test_normalize_make_path(convention: Type[NamingConvention]) -> None: - naming_unlimited = convention() - assert naming_unlimited.make_path("A", "B") == "A__B" - assert naming_unlimited.make_path("_A", "_B") == "_A___B" - assert naming_unlimited.make_path("_A", "", "_B") == "_A___B" - assert naming_unlimited.make_path("_A", "\t\n ", "_B") == "_A___B" + assert SnakeCaseNamingConvention.RE_NON_ALPHANUMERIC.sub("_", "-=!*") == "_" + assert 
SnakeCaseNamingConvention.RE_NON_ALPHANUMERIC.sub("_", "1-=!0*-") == "1_0_" + assert SnakeCaseNamingConvention.RE_NON_ALPHANUMERIC.sub("_", "1-=!_0*-") == "1__0_" def test_normalizes_underscores(naming_unlimited: NamingConvention) -> None: diff --git a/tests/common/normalizers/test_naming_sql.py b/tests/common/normalizers/test_naming_sql.py new file mode 100644 index 0000000000..c290354c6a --- /dev/null +++ b/tests/common/normalizers/test_naming_sql.py @@ -0,0 +1,55 @@ +import pytest +from typing import Type +from dlt.common.normalizers.naming import NamingConvention, sql_ci_v1, sql_cs_v1 + +ALL_NAMING_CONVENTIONS = {sql_ci_v1.NamingConvention, sql_cs_v1.NamingConvention} + + +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) +def test_normalize_identifier(convention: Type[NamingConvention]) -> None: + naming = convention() + assert naming.normalize_identifier("event_value") == "event_value" + assert naming.normalize_identifier("event value") == "event_value" + assert naming.normalize_identifier("event-.!:*<>value") == "event_value" + # prefix leading digits + assert naming.normalize_identifier("1event_n'") == "_1event_n" + # remove trailing underscores + assert naming.normalize_identifier("123event_n'") == "_123event_n" + # contract underscores + assert naming.normalize_identifier("___a___b") == "_a_b" + # trim spaces + assert naming.normalize_identifier(" small love potion ") == "small_love_potion" + + # special characters converted to _ + assert naming.normalize_identifier("+-!$*@#=|:") == "_" + # leave single underscore + assert naming.normalize_identifier("_") == "_" + # some other cases + assert naming.normalize_identifier("+1") == "_1" + assert naming.normalize_identifier("-1") == "_1" + + +def test_case_sensitive_normalize() -> None: + naming = sql_cs_v1.NamingConvention() + # all lowercase and converted to snake + assert naming.normalize_identifier("123BaNaNa") == "_123BaNaNa" + # consecutive capital letters + assert 
naming.normalize_identifier("BANANA") == "BANANA" + assert naming.normalize_identifier("BAN_ANA") == "BAN_ANA" + assert naming.normalize_identifier("BANaNA") == "BANaNA" + # handling spaces + assert naming.normalize_identifier("Small Love Potion") == "Small_Love_Potion" + assert naming.normalize_identifier(" Small Love Potion ") == "Small_Love_Potion" + + +def test_case_insensitive_normalize() -> None: + naming = sql_ci_v1.NamingConvention() + # all lowercase and converted to snake + assert naming.normalize_identifier("123BaNaNa") == "_123banana" + # consecutive capital letters + assert naming.normalize_identifier("BANANA") == "banana" + assert naming.normalize_identifier("BAN_ANA") == "ban_ana" + assert naming.normalize_identifier("BANaNA") == "banana" + # handling spaces + assert naming.normalize_identifier("Small Love Potion") == "small_love_potion" + assert naming.normalize_identifier(" Small Love Potion ") == "small_love_potion" diff --git a/tests/common/schema/conftest.py b/tests/common/schema/conftest.py new file mode 100644 index 0000000000..53d02fc663 --- /dev/null +++ b/tests/common/schema/conftest.py @@ -0,0 +1,25 @@ +import pytest + +from dlt.common.configuration import resolve_configuration +from dlt.common.schema import Schema +from dlt.common.storages import SchemaStorageConfiguration, SchemaStorage + + +from tests.utils import autouse_test_storage, preserve_environ + + +@pytest.fixture +def schema() -> Schema: + return Schema("event") + + +@pytest.fixture +def schema_storage() -> SchemaStorage: + C = resolve_configuration( + SchemaStorageConfiguration(), + explicit_value={ + "import_schema_path": "tests/common/cases/schemas/rasa", + "external_schema_format": "json", + }, + ) + return SchemaStorage(C, makedirs=True) diff --git a/tests/common/schema/test_filtering.py b/tests/common/schema/test_filtering.py index 8cfac9309f..6634a38aa6 100644 --- a/tests/common/schema/test_filtering.py +++ b/tests/common/schema/test_filtering.py @@ -10,11 +10,6 @@ from 
tests.common.utils import load_json_case -@pytest.fixture -def schema() -> Schema: - return Schema("event") - - def test_row_field_filter(schema: Schema) -> None: _add_excludes(schema) bot_case: DictStrAny = load_json_case("mod_bot_case") diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index 0a40953f53..e2821d5626 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -1,3 +1,4 @@ +import os import pytest from copy import deepcopy from typing import Any, List @@ -16,11 +17,6 @@ from tests.common.utils import load_json_case -@pytest.fixture -def schema() -> Schema: - return Schema("event") - - def test_get_preferred_type(schema: Schema) -> None: _add_preferred_types(schema) @@ -204,11 +200,10 @@ def test_shorten_variant_column(schema: Schema) -> None: } _, new_table = schema.coerce_row("event_user", None, row_1) # schema assumes that identifiers are already normalized so confidence even if it is longer than 9 chars - schema.update_table(new_table) + schema.update_table(new_table, normalize_identifiers=False) assert "confidence" in schema.tables["event_user"]["columns"] # confidence_123456 # now variant is created and this will be normalized - # TODO: we should move the handling of variants to normalizer new_row_2, new_table = schema.coerce_row("event_user", None, {"confidence": False}) tag = schema.naming._compute_tag( "confidence__v_bool", collision_prob=schema.naming._DEFAULT_COLLISION_PROB @@ -219,6 +214,9 @@ def test_shorten_variant_column(schema: Schema) -> None: def test_coerce_complex_variant(schema: Schema) -> None: + # for this test use case sensitive naming convention + os.environ["SCHEMA__NAMING"] = "direct" + schema.update_normalizers() # create two columns to which complex type cannot be coerced row = {"floatX": 78172.128, "confidenceX": 1.2, "strX": "STR"} new_row, new_table = schema.coerce_row("event_user", None, row) @@ -252,12 +250,12 @@ def 
test_coerce_complex_variant(schema: Schema) -> None: c_new_columns_v = list(c_new_table_v["columns"].values()) # two new variant columns added assert len(c_new_columns_v) == 2 - assert c_new_columns_v[0]["name"] == "floatX__v_complex" - assert c_new_columns_v[1]["name"] == "confidenceX__v_complex" + assert c_new_columns_v[0]["name"] == "floatX▶v_complex" + assert c_new_columns_v[1]["name"] == "confidenceX▶v_complex" assert c_new_columns_v[0]["variant"] is True assert c_new_columns_v[1]["variant"] is True - assert c_new_row_v["floatX__v_complex"] == v_list - assert c_new_row_v["confidenceX__v_complex"] == v_dict + assert c_new_row_v["floatX▶v_complex"] == v_list + assert c_new_row_v["confidenceX▶v_complex"] == v_dict assert c_new_row_v["strX"] == json.dumps(v_dict) schema.update_table(c_new_table_v) @@ -265,8 +263,8 @@ def test_coerce_complex_variant(schema: Schema) -> None: c_row_v = {"floatX": v_list, "confidenceX": v_dict, "strX": v_dict} c_new_row_v, c_new_table_v = schema.coerce_row("event_user", None, c_row_v) assert c_new_table_v is None - assert c_new_row_v["floatX__v_complex"] == v_list - assert c_new_row_v["confidenceX__v_complex"] == v_dict + assert c_new_row_v["floatX▶v_complex"] == v_list + assert c_new_row_v["confidenceX▶v_complex"] == v_dict assert c_new_row_v["strX"] == json.dumps(v_dict) @@ -539,7 +537,7 @@ def test_infer_on_incomplete_column(schema: Schema) -> None: incomplete_col["primary_key"] = True incomplete_col["x-special"] = "spec" # type: ignore[typeddict-unknown-key] table = utils.new_table("table", columns=[incomplete_col]) - schema.update_table(table) + schema.update_table(table, normalize_identifiers=False) # make sure that column is still incomplete and has no default hints assert schema.get_table("table")["columns"]["I"] == { "name": "I", diff --git a/tests/common/schema/test_merges.py b/tests/common/schema/test_merges.py index 8516414abd..893fd1db5f 100644 --- a/tests/common/schema/test_merges.py +++ 
b/tests/common/schema/test_merges.py @@ -2,10 +2,9 @@ import pytest from copy import copy, deepcopy -from dlt.common.schema import Schema, utils +from dlt.common.schema import utils from dlt.common.schema.exceptions import ( CannotCoerceColumnException, - CannotCoerceNullException, TablePropertiesConflictException, ) from dlt.common.schema.typing import TColumnSchemaBase, TStoredSchema, TTableSchema, TColumnSchema @@ -294,10 +293,10 @@ def test_diff_tables() -> None: empty = utils.new_table("table") del empty["resource"] print(empty) - partial = utils.diff_table(empty, deepcopy(table)) + partial = utils.diff_table("schema", empty, deepcopy(table)) # partial is simply table assert partial == table - partial = utils.diff_table(deepcopy(table), empty) + partial = utils.diff_table("schema", deepcopy(table), empty) # partial is empty assert partial == empty @@ -305,7 +304,7 @@ def test_diff_tables() -> None: changed = deepcopy(table) changed["description"] = "new description" changed["name"] = "new name" - partial = utils.diff_table(deepcopy(table), changed) + partial = utils.diff_table("schema", deepcopy(table), changed) print(partial) assert partial == {"name": "new name", "description": "new description", "columns": {}} @@ -313,7 +312,7 @@ def test_diff_tables() -> None: existing = deepcopy(table) changed["write_disposition"] = "append" changed["schema_contract"] = "freeze" - partial = utils.diff_table(deepcopy(existing), changed) + partial = utils.diff_table("schema", deepcopy(existing), changed) assert partial == { "name": "new name", "description": "new description", @@ -323,14 +322,14 @@ def test_diff_tables() -> None: } existing["write_disposition"] = "append" existing["schema_contract"] = "freeze" - partial = utils.diff_table(deepcopy(existing), changed) + partial = utils.diff_table("schema", deepcopy(existing), changed) assert partial == {"name": "new name", "description": "new description", "columns": {}} # detect changed column existing = deepcopy(table) 
changed = deepcopy(table) changed["columns"]["test"]["cluster"] = True - partial = utils.diff_table(existing, changed) + partial = utils.diff_table("schema", existing, changed) assert "test" in partial["columns"] assert "test_2" not in partial["columns"] assert existing["columns"]["test"] == table["columns"]["test"] != partial["columns"]["test"] @@ -339,7 +338,7 @@ def test_diff_tables() -> None: existing = deepcopy(table) changed = deepcopy(table) changed["columns"]["test"]["foreign_key"] = False - partial = utils.diff_table(existing, changed) + partial = utils.diff_table("schema", existing, changed) assert "test" in partial["columns"] # even if not present in tab_a at all @@ -347,7 +346,7 @@ def test_diff_tables() -> None: changed = deepcopy(table) changed["columns"]["test"]["foreign_key"] = False del existing["columns"]["test"]["foreign_key"] - partial = utils.diff_table(existing, changed) + partial = utils.diff_table("schema", existing, changed) assert "test" in partial["columns"] @@ -363,7 +362,7 @@ def test_diff_tables_conflicts() -> None: other = utils.new_table("table_2") with pytest.raises(TablePropertiesConflictException) as cf_ex: - utils.diff_table(table, other) + utils.diff_table("schema", table, other) assert cf_ex.value.table_name == "table" assert cf_ex.value.prop_name == "parent" @@ -371,7 +370,7 @@ def test_diff_tables_conflicts() -> None: changed = deepcopy(table) changed["columns"]["test"]["data_type"] = "bigint" with pytest.raises(CannotCoerceColumnException): - utils.diff_table(table, changed) + utils.diff_table("schema", table, changed) def test_merge_tables() -> None: @@ -391,7 +390,7 @@ def test_merge_tables() -> None: changed["new-prop-3"] = False # type: ignore[typeddict-unknown-key] # drop column so partial has it del table["columns"]["test"] - partial = utils.merge_table(table, changed) + partial = utils.merge_table("schema", table, changed) assert "test" in table["columns"] assert table["x-special"] == 129 # type: 
ignore[typeddict-item] assert table["description"] == "new description" @@ -420,7 +419,7 @@ def test_merge_tables_incomplete_columns() -> None: changed["columns"] = deepcopy({"test": COL_1_HINTS, "test_2": COL_2_HINTS}) # it is completed now changed["columns"]["test_2"]["data_type"] = "bigint" - partial = utils.merge_table(table, changed) + partial = utils.merge_table("schema", table, changed) assert list(partial["columns"].keys()) == ["test_2"] # test_2 goes to the end, it was incomplete in table so it got dropped before update assert list(table["columns"].keys()) == ["test", "test_2"] @@ -435,7 +434,7 @@ def test_merge_tables_incomplete_columns() -> None: changed["columns"] = deepcopy({"test": COL_1_HINTS, "test_2": COL_2_HINTS}) # still incomplete but changed changed["columns"]["test_2"]["nullable"] = False - partial = utils.merge_table(table, changed) + partial = utils.merge_table("schema", table, changed) assert list(partial["columns"].keys()) == ["test_2"] # incomplete -> incomplete stays in place assert list(table["columns"].keys()) == ["test_2", "test"] diff --git a/tests/common/schema/test_normalize_identifiers.py b/tests/common/schema/test_normalize_identifiers.py new file mode 100644 index 0000000000..60f8c04604 --- /dev/null +++ b/tests/common/schema/test_normalize_identifiers.py @@ -0,0 +1,419 @@ +from copy import deepcopy +import os +from typing import Callable +import pytest + +from dlt.common import json +from dlt.common.configuration import resolve_configuration +from dlt.common.configuration.container import Container +from dlt.common.normalizers.naming.naming import NamingConvention +from dlt.common.storages import SchemaStorageConfiguration +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.normalizers.naming import snake_case, direct +from dlt.common.schema import TColumnSchema, Schema, TStoredSchema, utils +from dlt.common.schema.exceptions import TableIdentifiersFrozen +from 
dlt.common.schema.typing import SIMPLE_REGEX_PREFIX +from dlt.common.storages import SchemaStorage + +from tests.common.cases.normalizers import sql_upper +from tests.common.utils import load_json_case, load_yml_case + + +@pytest.fixture +def schema_storage_no_import() -> SchemaStorage: + C = resolve_configuration(SchemaStorageConfiguration()) + return SchemaStorage(C, makedirs=True) + + +@pytest.fixture +def cn_schema() -> Schema: + return Schema( + "column_default", + { + "names": "tests.common.normalizers.custom_normalizers", + "json": { + "module": "tests.common.normalizers.custom_normalizers", + "config": {"not_null": ["fake_id"]}, + }, + }, + ) + + +def test_save_store_schema_custom_normalizers( + cn_schema: Schema, schema_storage: SchemaStorage +) -> None: + schema_storage.save_schema(cn_schema) + schema_copy = schema_storage.load_schema(cn_schema.name) + assert_new_schema_values_custom_normalizers(schema_copy) + + +def test_new_schema_custom_normalizers(cn_schema: Schema) -> None: + assert_new_schema_values_custom_normalizers(cn_schema) + + +def test_save_load_incomplete_column( + schema: Schema, schema_storage_no_import: SchemaStorage +) -> None: + # make sure that incomplete column is saved and restored without default hints + incomplete_col = utils.new_column("I", nullable=False) + incomplete_col["primary_key"] = True + incomplete_col["x-special"] = "spec" # type: ignore[typeddict-unknown-key] + table = utils.new_table("table", columns=[incomplete_col]) + schema.update_table(table, normalize_identifiers=False) + schema_storage_no_import.save_schema(schema) + schema_copy = schema_storage_no_import.load_schema("event") + assert schema_copy.get_table("table")["columns"]["I"] == { + "name": "I", + "nullable": False, + "primary_key": True, + "x-special": "spec", + } + + +def test_schema_config_normalizers(schema: Schema, schema_storage_no_import: SchemaStorage) -> None: + # save snake case schema + assert schema._normalizers_config["names"] == "snake_case" + 
schema_storage_no_import.save_schema(schema) + # config direct naming convention + os.environ["SCHEMA__NAMING"] = "direct" + # new schema has direct naming convention + schema_direct_nc = Schema("direct_naming") + schema_storage_no_import.save_schema(schema_direct_nc) + assert schema_direct_nc._normalizers_config["names"] == "direct" + # still after loading the config is "snake" + schema = schema_storage_no_import.load_schema(schema.name) + assert schema._normalizers_config["names"] == "snake_case" + # provide capabilities context + destination_caps = DestinationCapabilitiesContext.generic_capabilities() + destination_caps.naming_convention = "sql_cs_v1" + destination_caps.max_identifier_length = 127 + with Container().injectable_context(destination_caps): + # caps are ignored if schema is configured + schema_direct_nc = Schema("direct_naming") + assert schema_direct_nc._normalizers_config["names"] == "direct" + # but length is there + assert schema_direct_nc.naming.max_length == 127 + # when loading schema configuration is ignored + schema = schema_storage_no_import.load_schema(schema.name) + assert schema._normalizers_config["names"] == "snake_case" + assert schema.naming.max_length == 127 + # but if we ask to update normalizers config schema is applied + schema.update_normalizers() + assert schema._normalizers_config["names"] == "direct" + + # load schema_direct_nc (direct) + schema_direct_nc = schema_storage_no_import.load_schema(schema_direct_nc.name) + assert schema_direct_nc._normalizers_config["names"] == "direct" + + # drop config + del os.environ["SCHEMA__NAMING"] + schema_direct_nc = schema_storage_no_import.load_schema(schema_direct_nc.name) + assert schema_direct_nc._normalizers_config["names"] == "direct" + + +def test_schema_normalizers_no_config( + schema: Schema, schema_storage_no_import: SchemaStorage +) -> None: + # convert schema to direct and save + os.environ["SCHEMA__NAMING"] = "direct" + schema.update_normalizers() + assert 
schema._normalizers_config["names"] == "direct" + schema_storage_no_import.save_schema(schema) + # make sure we drop the config correctly + del os.environ["SCHEMA__NAMING"] + schema_test = Schema("test") + assert schema_test.naming.name() == "snake_case" + # use capabilities without default naming convention + destination_caps = DestinationCapabilitiesContext.generic_capabilities() + assert destination_caps.naming_convention is None + destination_caps.max_identifier_length = 66 + with Container().injectable_context(destination_caps): + schema_in_caps = Schema("schema_in_caps") + assert schema_in_caps._normalizers_config["names"] == "snake_case" + assert schema_in_caps.naming.name() == "snake_case" + assert schema_in_caps.naming.max_length == 66 + schema_in_caps.update_normalizers() + assert schema_in_caps.naming.name() == "snake_case" + # old schema preserves convention when loaded + schema = schema_storage_no_import.load_schema(schema.name) + assert schema._normalizers_config["names"] == "direct" + # update normalizer no effect + schema.update_normalizers() + assert schema._normalizers_config["names"] == "direct" + assert schema.naming.max_length == 66 + + # use caps with default naming convention + destination_caps = DestinationCapabilitiesContext.generic_capabilities() + destination_caps.naming_convention = "sql_cs_v1" + destination_caps.max_identifier_length = 127 + with Container().injectable_context(destination_caps): + schema_in_caps = Schema("schema_in_caps") + # new schema gets convention from caps + assert schema_in_caps._normalizers_config["names"] == "sql_cs_v1" + # old schema preserves convention when loaded + schema = schema_storage_no_import.load_schema(schema.name) + assert schema._normalizers_config["names"] == "direct" + # update changes to caps schema + schema.update_normalizers() + assert schema._normalizers_config["names"] == "sql_cs_v1" + assert schema.naming.max_length == 127 + + +@pytest.mark.parametrize("section", 
("SOURCES__SCHEMA__NAMING", "SOURCES__THIS__SCHEMA__NAMING")) +def test_config_with_section(section: str) -> None: + os.environ["SOURCES__OTHER__SCHEMA__NAMING"] = "direct" + os.environ[section] = "sql_cs_v1" + this_schema = Schema("this") + that_schema = Schema("that") + assert this_schema.naming.name() == "sql_cs_v1" + expected_that_schema = ( + "snake_case" if section == "SOURCES__THIS__SCHEMA__NAMING" else "sql_cs_v1" + ) + assert that_schema.naming.name() == expected_that_schema + + # test update normalizers + os.environ[section] = "direct" + expected_that_schema = "snake_case" if section == "SOURCES__THIS__SCHEMA__NAMING" else "direct" + this_schema.update_normalizers() + assert this_schema.naming.name() == "direct" + that_schema.update_normalizers() + assert that_schema.naming.name() == expected_that_schema + + +def test_normalize_table_identifiers() -> None: + # load with snake case + schema_dict: TStoredSchema = load_json_case("schemas/github/issues.schema") + schema = Schema.from_dict(schema_dict) # type: ignore[arg-type] + issues_table = schema.tables["issues"] + issues_table_str = json.dumps(issues_table) + # normalize table to upper + issues_table_norm = utils.normalize_table_identifiers( + issues_table, sql_upper.NamingConvention() + ) + # nothing got changes in issues table + assert issues_table_str == json.dumps(issues_table) + # check normalization + assert issues_table_norm["name"] == "ISSUES" + assert "REACTIONS___1" in issues_table_norm["columns"] + # subsequent normalization does not change dict + assert issues_table_norm == utils.normalize_table_identifiers( + issues_table_norm, sql_upper.NamingConvention() + ) + + +def test_normalize_table_identifiers_idempotent() -> None: + schema_dict: TStoredSchema = load_json_case("schemas/github/issues.schema") + schema = Schema.from_dict(schema_dict) # type: ignore[arg-type] + # assert column generated from "reactions/+1" and "-1", it is a valid identifier even with three underscores + assert 
"reactions___1" in schema.tables["issues"]["columns"] + issues_table = schema.tables["issues"] + # this schema is already normalized so normalization is idempotent + assert schema.tables["issues"] == utils.normalize_table_identifiers(issues_table, schema.naming) + assert schema.tables["issues"] == utils.normalize_table_identifiers( + utils.normalize_table_identifiers(issues_table, schema.naming), schema.naming + ) + + +def test_normalize_table_identifiers_merge_columns() -> None: + # create conflicting columns + table_create = [ + {"name": "case", "data_type": "bigint", "nullable": False, "x-description": "desc"}, + {"name": "Case", "data_type": "double", "nullable": True, "primary_key": True}, + ] + # schema normalizing to snake case will conflict on case and Case + table = utils.new_table("blend", columns=table_create) # type: ignore[arg-type] + table_str = json.dumps(table) + norm_table = utils.normalize_table_identifiers(table, Schema("norm").naming) + # nothing got changed in original table + assert table_str == json.dumps(table) + # only one column + assert len(norm_table["columns"]) == 1 + assert norm_table["columns"]["case"] == { + "nullable": False, # remove default, preserve non default + "primary_key": True, + "name": "case", + "data_type": "double", + "x-description": "desc", + } + + +def test_update_normalizers() -> None: + schema_dict: TStoredSchema = load_json_case("schemas/github/issues.schema") + schema = Schema.from_dict(schema_dict) # type: ignore[arg-type] + # drop seen data + del schema.tables["issues"]["x-normalizer"] + del schema.tables["issues__labels"]["x-normalizer"] + del schema.tables["issues__assignees"]["x-normalizer"] + # save default hints in original form + default_hints = schema._settings["default_hints"] + + os.environ["SCHEMA__NAMING"] = "tests.common.cases.normalizers.sql_upper" + schema.update_normalizers() + assert isinstance(schema.naming, sql_upper.NamingConvention) + # print(schema.to_pretty_yaml()) + 
assert_schema_identifiers_case(schema, str.upper) + + # resource must be old name + assert schema.tables["ISSUES"]["resource"] == "issues" + + # make sure normalizer config is replaced + assert schema._normalizers_config["names"] == "tests.common.cases.normalizers.sql_upper" + assert "allow_identifier_change_on_table_with_data" not in schema._normalizers_config + + # regexes are uppercased + new_default_hints = schema._settings["default_hints"] + for hint, regexes in default_hints.items(): + # same number of hints + assert len(regexes) == len(new_default_hints[hint]) + # but all upper cased + assert set(n.upper() for n in regexes) == set(new_default_hints[hint]) + + +def test_normalize_default_hints(schema_storage_no_import: SchemaStorage) -> None: + # use destination caps to force naming convention + from dlt.common.destination import DestinationCapabilitiesContext + from dlt.common.configuration.container import Container + + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + orig_schema = Schema.from_dict(eth_V9) + # save schema + schema_storage_no_import.save_schema(orig_schema) + + with Container().injectable_context( + DestinationCapabilitiesContext.generic_capabilities(naming_convention=sql_upper) + ) as caps: + assert caps.naming_convention is sql_upper + # creating a schema from dict keeps original normalizers + schema = Schema.from_dict(eth_V9) + assert_schema_identifiers_case(schema, str.lower) + assert schema._normalizers_config["names"].endswith("snake_case") + + # loading from storage keeps storage normalizers + storage_schema = schema_storage_no_import.load_schema("ethereum") + assert_schema_identifiers_case(storage_schema, str.lower) + assert storage_schema._normalizers_config["names"].endswith("snake_case") + + # new schema instance is created using caps/config + new_schema = Schema("new") + assert_schema_identifiers_case(new_schema, str.upper) + assert ( + new_schema._normalizers_config["names"] + == 
"tests.common.cases.normalizers.sql_upper.NamingConvention" + ) + + # attempt to update normalizers blocked by tables with data + with pytest.raises(TableIdentifiersFrozen): + schema.update_normalizers() + # also cloning with update normalizers + with pytest.raises(TableIdentifiersFrozen): + schema.clone(update_normalizers=True) + + # remove processing hints and normalize + norm_cloned = schema.clone(update_normalizers=True, remove_processing_hints=True) + assert_schema_identifiers_case(norm_cloned, str.upper) + assert ( + norm_cloned._normalizers_config["names"] + == "tests.common.cases.normalizers.sql_upper.NamingConvention" + ) + + norm_schema = Schema.from_dict( + deepcopy(eth_V9), remove_processing_hints=True, bump_version=False + ) + norm_schema.update_normalizers() + assert_schema_identifiers_case(norm_schema, str.upper) + assert ( + norm_schema._normalizers_config["names"] + == "tests.common.cases.normalizers.sql_upper.NamingConvention" + ) + + # both ways of obtaining schemas (cloning, cleaning dict) must generate identical schemas + assert norm_cloned.to_pretty_json() == norm_schema.to_pretty_json() + + # save to storage + schema_storage_no_import.save_schema(norm_cloned) + + # load schema out of caps + storage_schema = schema_storage_no_import.load_schema("ethereum") + assert_schema_identifiers_case(storage_schema, str.upper) + # the instance got converted into + assert storage_schema._normalizers_config["names"].endswith("sql_upper.NamingConvention") + assert storage_schema.stored_version_hash == storage_schema.version_hash + # cloned when bumped must have same version hash + norm_cloned._bump_version() + assert storage_schema.stored_version_hash == norm_cloned.stored_version_hash + + +def test_raise_on_change_identifier_table_with_data() -> None: + schema_dict: TStoredSchema = load_json_case("schemas/github/issues.schema") + schema = Schema.from_dict(schema_dict) # type: ignore[arg-type] + # mark issues table to seen data and change naming to sql upper 
+ issues_table = schema.tables["issues"] + issues_table["x-normalizer"] = {"seen-data": True} + os.environ["SCHEMA__NAMING"] = "tests.common.cases.normalizers.sql_upper" + with pytest.raises(TableIdentifiersFrozen) as fr_ex: + schema.update_normalizers() + assert fr_ex.value.table_name == "issues" + assert isinstance(fr_ex.value.from_naming, snake_case.NamingConvention) + assert isinstance(fr_ex.value.to_naming, sql_upper.NamingConvention) + # try again, get exception (schema was not partially modified) + with pytest.raises(TableIdentifiersFrozen) as fr_ex: + schema.update_normalizers() + + # use special naming convention that only changes column names ending with x to _ + issues_table["columns"]["columnx"] = {"name": "columnx", "data_type": "bigint"} + assert schema.tables["issues"] is issues_table + os.environ["SCHEMA__NAMING"] = "tests.common.cases.normalizers.snake_no_x" + with pytest.raises(TableIdentifiersFrozen) as fr_ex: + schema.update_normalizers() + assert fr_ex.value.table_name == "issues" + # allow to change tables with data + os.environ["SCHEMA__ALLOW_IDENTIFIER_CHANGE_ON_TABLE_WITH_DATA"] = "True" + schema.update_normalizers() + assert schema._normalizers_config["allow_identifier_change_on_table_with_data"] is True + + +def assert_schema_identifiers_case(schema: Schema, casing: Callable[[str], str]) -> None: + for table_name, table in schema.tables.items(): + assert table_name == casing(table_name) == table["name"] + if "parent" in table: + assert table["parent"] == casing(table["parent"]) + for col_name, column in table["columns"].items(): + assert col_name == casing(col_name) == column["name"] + + # make sure table prefixes are set + assert schema._dlt_tables_prefix == casing("_dlt") + assert schema.loads_table_name == casing("_dlt_loads") + assert schema.version_table_name == casing("_dlt_version") + assert schema.state_table_name == casing("_dlt_pipeline_state") + + def _case_regex(regex: str) -> str: + if regex.startswith(SIMPLE_REGEX_PREFIX): + 
return SIMPLE_REGEX_PREFIX + casing(regex[3:]) + else: + return casing(regex) + + # regexes are uppercased + new_default_hints = schema._settings["default_hints"] + for hint, regexes in new_default_hints.items(): + # but all upper cased + assert set(_case_regex(n) for n in regexes) == set(new_default_hints[hint]) + + +def assert_new_schema_values_custom_normalizers(schema: Schema) -> None: + # check normalizers config + assert schema._normalizers_config["names"] == "tests.common.normalizers.custom_normalizers" + assert ( + schema._normalizers_config["json"]["module"] + == "tests.common.normalizers.custom_normalizers" + ) + # check if schema was extended by json normalizer + assert ["fake_id"] == schema.settings["default_hints"]["not_null"] + # call normalizers + assert schema.naming.normalize_identifier("a") == "column_a" + assert schema.naming.normalize_path("a__b") == "column_a__column_b" + assert schema.naming.normalize_identifier("1A_b") == "column_1a_b" + # assumes elements are normalized + assert schema.naming.make_path("A", "B", "!C") == "A__B__!C" + assert schema.naming.break_path("A__B__!C") == ["A", "B", "!C"] + row = list(schema.normalize_data_item({"bool": True}, "load_id", "a_table")) + assert row[0] == (("a_table", None), {"bool": True}) diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 887b0aa9a0..93be165358 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -1,19 +1,17 @@ -from copy import deepcopy import os -from typing import List, Sequence, cast +from typing import Dict, List, Sequence import pytest +from copy import deepcopy from dlt.common import pendulum -from dlt.common.configuration import resolve_configuration -from dlt.common.configuration.container import Container +from dlt.common.json import json +from dlt.common.data_types.typing import TDataType from dlt.common.schema.migrations import migrate_schema -from dlt.common.storages import 
SchemaStorageConfiguration -from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.exceptions import DictValidationException -from dlt.common.normalizers.naming import snake_case, direct +from dlt.common.normalizers.naming import snake_case from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import uniq_id -from dlt.common.schema import TColumnSchema, Schema, TStoredSchema, utils, TColumnHint +from dlt.common.schema import TColumnSchema, Schema, TStoredSchema, utils from dlt.common.schema.exceptions import ( InvalidSchemaName, ParentTableNotFoundException, @@ -28,50 +26,12 @@ ) from dlt.common.storages import SchemaStorage -from tests.utils import autouse_test_storage, preserve_environ from tests.common.utils import load_json_case, load_yml_case, COMMON_TEST_CASES_PATH SCHEMA_NAME = "event" EXPECTED_FILE_NAME = f"{SCHEMA_NAME}.schema.json" -@pytest.fixture -def schema_storage() -> SchemaStorage: - C = resolve_configuration( - SchemaStorageConfiguration(), - explicit_value={ - "import_schema_path": "tests/common/cases/schemas/rasa", - "external_schema_format": "json", - }, - ) - return SchemaStorage(C, makedirs=True) - - -@pytest.fixture -def schema_storage_no_import() -> SchemaStorage: - C = resolve_configuration(SchemaStorageConfiguration()) - return SchemaStorage(C, makedirs=True) - - -@pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def cn_schema() -> Schema: - return Schema( - "column_default", - { - "names": "tests.common.normalizers.custom_normalizers", - "json": { - "module": "tests.common.normalizers.custom_normalizers", - "config": {"not_null": ["fake_id"]}, - }, - }, - ) - - def test_normalize_schema_name(schema: Schema) -> None: assert schema.naming.normalize_table_identifier("BAN_ANA") == "ban_ana" assert schema.naming.normalize_table_identifier("event-.!:value") == "event_value" @@ -102,38 +62,6 @@ def test_new_schema(schema: Schema) -> None: 
utils.validate_stored_schema(stored_schema) -def test_new_schema_custom_normalizers(cn_schema: Schema) -> None: - assert_is_new_schema(cn_schema) - assert_new_schema_props_custom_normalizers(cn_schema) - - -def test_schema_config_normalizers(schema: Schema, schema_storage_no_import: SchemaStorage) -> None: - # save snake case schema - schema_storage_no_import.save_schema(schema) - # config direct naming convention - os.environ["SCHEMA__NAMING"] = "direct" - # new schema has direct naming convention - schema_direct_nc = Schema("direct_naming") - assert schema_direct_nc._normalizers_config["names"] == "direct" - # still after loading the config is "snake" - schema = schema_storage_no_import.load_schema(schema.name) - assert schema._normalizers_config["names"] == "snake_case" - # provide capabilities context - destination_caps = DestinationCapabilitiesContext.generic_capabilities() - destination_caps.naming_convention = "snake_case" - destination_caps.max_identifier_length = 127 - with Container().injectable_context(destination_caps): - # caps are ignored if schema is configured - schema_direct_nc = Schema("direct_naming") - assert schema_direct_nc._normalizers_config["names"] == "direct" - # but length is there - assert schema_direct_nc.naming.max_length == 127 - # also for loaded schema - schema = schema_storage_no_import.load_schema(schema.name) - assert schema._normalizers_config["names"] == "snake_case" - assert schema.naming.max_length == 127 - - def test_simple_regex_validator() -> None: # can validate only simple regexes assert utils.simple_regex_validator(".", "k", "v", str) is False @@ -394,33 +322,6 @@ def test_save_store_schema(schema: Schema, schema_storage: SchemaStorage) -> Non assert_new_schema_props(schema_copy) -def test_save_store_schema_custom_normalizers( - cn_schema: Schema, schema_storage: SchemaStorage -) -> None: - schema_storage.save_schema(cn_schema) - schema_copy = schema_storage.load_schema(cn_schema.name) - 
assert_new_schema_props_custom_normalizers(schema_copy) - - -def test_save_load_incomplete_column( - schema: Schema, schema_storage_no_import: SchemaStorage -) -> None: - # make sure that incomplete column is saved and restored without default hints - incomplete_col = utils.new_column("I", nullable=False) - incomplete_col["primary_key"] = True - incomplete_col["x-special"] = "spec" # type: ignore[typeddict-unknown-key] - table = utils.new_table("table", columns=[incomplete_col]) - schema.update_table(table) - schema_storage_no_import.save_schema(schema) - schema_copy = schema_storage_no_import.load_schema("event") - assert schema_copy.get_table("table")["columns"]["I"] == { - "name": "I", - "nullable": False, - "primary_key": True, - "x-special": "spec", - } - - def test_upgrade_engine_v1_schema() -> None: schema_dict: DictStrAny = load_json_case("schemas/ev1/event.schema") # ensure engine v1 @@ -479,7 +380,7 @@ def test_unknown_engine_upgrade() -> None: def test_preserve_column_order(schema: Schema, schema_storage: SchemaStorage) -> None: # python dicts are ordered from v3.6, add 50 column with random names update: List[TColumnSchema] = [ - schema._infer_column(uniq_id(), pendulum.now().timestamp()) for _ in range(50) + schema._infer_column("t" + uniq_id(), pendulum.now().timestamp()) for _ in range(50) ] schema.update_table(utils.new_table("event_test_order", columns=update)) @@ -496,7 +397,7 @@ def verify_items(table, update) -> None: verify_items(table, update) # add more columns update2: List[TColumnSchema] = [ - schema._infer_column(uniq_id(), pendulum.now().timestamp()) for _ in range(50) + schema._infer_column("t" + uniq_id(), pendulum.now().timestamp()) for _ in range(50) ] loaded_schema.update_table(utils.new_table("event_test_order", columns=update2)) table = loaded_schema.get_table_columns("event_test_order") @@ -648,6 +549,79 @@ def test_merge_hints(schema: Schema) -> None: for k in expected_hints: assert set(expected_hints[k]) == 
set(schema._settings["default_hints"][k]) # type: ignore[index] + # make sure that re:^_dlt_id$ and _dlt_id are equivalent when merging so we can use both forms + alt_form_hints = { + "not_null": ["re:^_dlt_id$"], + "foreign_key": ["_dlt_parent_id"], + } + schema.merge_hints(alt_form_hints) # type: ignore[arg-type] + # we keep the older forms so nothing changed + assert len(expected_hints) == len(schema._settings["default_hints"]) + for k in expected_hints: + assert set(expected_hints[k]) == set(schema._settings["default_hints"][k]) # type: ignore[index] + + # check normalize some regex forms + upper_hints = { + "not_null": [ + "_DLT_ID", + ], + "foreign_key": ["re:^_DLT_PARENT_ID$"], + } + schema.merge_hints(upper_hints) # type: ignore[arg-type] + # all upper form hints can be automatically converted to lower form + assert len(expected_hints) == len(schema._settings["default_hints"]) + for k in expected_hints: + assert set(expected_hints[k]) == set(schema._settings["default_hints"][k]) # type: ignore[index] + + # this form cannot be converted + upper_hints = { + "not_null": [ + "re:TU[b-b]a", + ], + } + schema.merge_hints(upper_hints) # type: ignore[arg-type] + assert "re:TU[b-b]a" in schema.settings["default_hints"]["not_null"] + + +def test_update_preferred_types(schema: Schema) -> None: + # no preferred types in the schema + assert "preferred_types" not in schema.settings + + expected: Dict[TSimpleRegex, TDataType] = { + TSimpleRegex("_dlt_id"): "bigint", + TSimpleRegex("re:^timestamp$"): "timestamp", + } + schema.update_preferred_types(expected) + assert schema.settings["preferred_types"] == expected + # no changes + schema.update_preferred_types(expected) + assert schema.settings["preferred_types"] == expected + + # add and replace, canonical form used to update / replace + updated: Dict[TSimpleRegex, TDataType] = { + TSimpleRegex("_dlt_id"): "decimal", + TSimpleRegex("timestamp"): "date", + TSimpleRegex("re:TU[b-c]a"): "text", + } + 
schema.update_preferred_types(updated) + assert schema.settings["preferred_types"] == { + "_dlt_id": "decimal", + "re:^timestamp$": "date", + "re:TU[b-c]a": "text", + } + + # will normalize some form of regex + updated = { + TSimpleRegex("_DLT_id"): "text", + TSimpleRegex("re:^TIMESTAMP$"): "timestamp", + } + schema.update_preferred_types(updated) + assert schema.settings["preferred_types"] == { + "_dlt_id": "text", + "re:^timestamp$": "timestamp", + "re:TU[b-c]a": "text", + } + def test_default_table_resource() -> None: """Parent tables without `resource` set default to table name""" @@ -766,9 +740,9 @@ def test_normalize_table_identifiers() -> None: assert "reactions___1" in schema.tables["issues"]["columns"] issues_table = deepcopy(schema.tables["issues"]) # this schema is already normalized so normalization is idempotent - assert schema.tables["issues"] == schema.normalize_table_identifiers(issues_table) - assert schema.tables["issues"] == schema.normalize_table_identifiers( - schema.normalize_table_identifiers(issues_table) + assert schema.tables["issues"] == utils.normalize_table_identifiers(issues_table, schema.naming) + assert schema.tables["issues"] == utils.normalize_table_identifiers( + utils.normalize_table_identifiers(issues_table, schema.naming), schema.naming ) @@ -780,7 +754,10 @@ def test_normalize_table_identifiers_merge_columns() -> None: ] # schema normalizing to snake case will conflict on case and Case table = utils.new_table("blend", columns=table_create) # type: ignore[arg-type] - norm_table = Schema("norm").normalize_table_identifiers(table) + table_str = json.dumps(table) + norm_table = utils.normalize_table_identifiers(table, Schema("norm").naming) + # nothing got changed in original table + assert table_str == json.dumps(table) # only one column assert len(norm_table["columns"]) == 1 assert norm_table["columns"]["case"] == { @@ -859,20 +836,21 @@ def test_group_tables_by_resource(schema: Schema) -> None: 
schema.update_table(utils.new_table("a_events", columns=[])) schema.update_table(utils.new_table("b_events", columns=[])) schema.update_table(utils.new_table("c_products", columns=[], resource="products")) - schema.update_table(utils.new_table("a_events__1", columns=[], parent_table_name="a_events")) + schema.update_table(utils.new_table("a_events___1", columns=[], parent_table_name="a_events")) schema.update_table( - utils.new_table("a_events__1__2", columns=[], parent_table_name="a_events__1") + utils.new_table("a_events___1___2", columns=[], parent_table_name="a_events___1") ) - schema.update_table(utils.new_table("b_events__1", columns=[], parent_table_name="b_events")) + schema.update_table(utils.new_table("b_events___1", columns=[], parent_table_name="b_events")) + # print(schema.to_pretty_yaml()) # All resources without filter expected_tables = { "a_events": [ schema.tables["a_events"], - schema.tables["a_events__1"], - schema.tables["a_events__1__2"], + schema.tables["a_events___1"], + schema.tables["a_events___1___2"], ], - "b_events": [schema.tables["b_events"], schema.tables["b_events__1"]], + "b_events": [schema.tables["b_events"], schema.tables["b_events___1"]], "products": [schema.tables["c_products"]], "_dlt_version": [schema.tables["_dlt_version"]], "_dlt_loads": [schema.tables["_dlt_loads"]], @@ -887,10 +865,10 @@ def test_group_tables_by_resource(schema: Schema) -> None: assert result == { "a_events": [ schema.tables["a_events"], - schema.tables["a_events__1"], - schema.tables["a_events__1__2"], + schema.tables["a_events___1"], + schema.tables["a_events___1___2"], ], - "b_events": [schema.tables["b_events"], schema.tables["b_events__1"]], + "b_events": [schema.tables["b_events"], schema.tables["b_events___1"]], } # With resources that has many top level tables @@ -919,3 +897,41 @@ def test_group_tables_by_resource(schema: Schema) -> None: {"columns": {}, "name": "mc_products__sub", "parent": "mc_products"}, ] } + + +def 
test_remove_processing_hints() -> None: + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + # here tables contain processing hints + schema = Schema.from_dict(eth_V9) + assert "x-normalizer" in schema.tables["blocks"] + + # clone with hints removal, note that clone does not bump version + cloned = schema.clone(remove_processing_hints=True) + assert "x-normalizer" not in cloned.tables["blocks"] + # clone does not touch original schema + assert "x-normalizer" in schema.tables["blocks"] + + # to string + to_yaml = schema.to_pretty_yaml() + assert "x-normalizer" in to_yaml + to_yaml = schema.to_pretty_yaml(remove_processing_hints=True) + assert "x-normalizer" not in to_yaml + to_json = schema.to_pretty_json() + assert "x-normalizer" in to_json + to_json = schema.to_pretty_json(remove_processing_hints=True) + assert "x-normalizer" not in to_json + + # load without hints + no_hints = schema.from_dict(eth_V9, remove_processing_hints=True, bump_version=False) + assert no_hints.stored_version_hash == cloned.stored_version_hash + + # now load without hints but with version bump + cloned._bump_version() + no_hints = schema.from_dict(eth_V9, remove_processing_hints=True) + assert no_hints.stored_version_hash == cloned.stored_version_hash + + +# def test_get_new_table_columns() -> None: +# pytest.fail(reason="must implement!") +# pass +# get_new_table_columns() diff --git a/tests/common/schema/test_versioning.py b/tests/common/schema/test_versioning.py index b67b028161..788da09533 100644 --- a/tests/common/schema/test_versioning.py +++ b/tests/common/schema/test_versioning.py @@ -1,6 +1,5 @@ import pytest import yaml -from copy import deepcopy from dlt.common import json from dlt.common.schema import utils diff --git a/tests/common/storages/test_file_storage.py b/tests/common/storages/test_file_storage.py index eae765398b..7a10e29097 100644 --- a/tests/common/storages/test_file_storage.py +++ b/tests/common/storages/test_file_storage.py @@ -39,38 +39,40 @@ def 
test_to_relative_path(test_storage: FileStorage) -> None: def test_make_full_path(test_storage: FileStorage) -> None: # fully within storage relative_path = os.path.join("dir", "to", "file") - path = test_storage.make_full_path(relative_path) + path = test_storage.make_full_path_safe(relative_path) assert path.endswith(os.path.join(TEST_STORAGE_ROOT, relative_path)) # overlapped with storage root_path = os.path.join(TEST_STORAGE_ROOT, relative_path) - path = test_storage.make_full_path(root_path) + path = test_storage.make_full_path_safe(root_path) assert path.endswith(root_path) assert path.count(TEST_STORAGE_ROOT) == 2 # absolute path with different root than TEST_STORAGE_ROOT does not lead into storage so calculating full path impossible with pytest.raises(ValueError): - test_storage.make_full_path(os.path.join("/", root_path)) + test_storage.make_full_path_safe(os.path.join("/", root_path)) # relative path out of the root with pytest.raises(ValueError): - test_storage.make_full_path("..") + test_storage.make_full_path_safe("..") # absolute overlapping path - path = test_storage.make_full_path(os.path.abspath(root_path)) + path = test_storage.make_full_path_safe(os.path.abspath(root_path)) assert path.endswith(root_path) - assert test_storage.make_full_path("") == test_storage.storage_path - assert test_storage.make_full_path(".") == test_storage.storage_path + assert test_storage.make_full_path_safe("") == test_storage.storage_path + assert test_storage.make_full_path_safe(".") == test_storage.storage_path def test_in_storage(test_storage: FileStorage) -> None: # always relative to storage root - assert test_storage.in_storage("a/b/c") is True - assert test_storage.in_storage(f"../{TEST_STORAGE_ROOT}/b/c") is True - assert test_storage.in_storage("../a/b/c") is False - assert test_storage.in_storage("../../../a/b/c") is False - assert test_storage.in_storage("/a") is False - assert test_storage.in_storage(".") is True - assert test_storage.in_storage(os.curdir) 
is True - assert test_storage.in_storage(os.path.realpath(os.curdir)) is False + assert test_storage.is_path_in_storage("a/b/c") is True + assert test_storage.is_path_in_storage(f"../{TEST_STORAGE_ROOT}/b/c") is True + assert test_storage.is_path_in_storage("../a/b/c") is False + assert test_storage.is_path_in_storage("../../../a/b/c") is False + assert test_storage.is_path_in_storage("/a") is False + assert test_storage.is_path_in_storage(".") is True + assert test_storage.is_path_in_storage(os.curdir) is True + assert test_storage.is_path_in_storage(os.path.realpath(os.curdir)) is False assert ( - test_storage.in_storage(os.path.join(os.path.realpath(os.curdir), TEST_STORAGE_ROOT)) + test_storage.is_path_in_storage( + os.path.join(os.path.realpath(os.curdir), TEST_STORAGE_ROOT) + ) is True ) @@ -164,7 +166,7 @@ def test_rmtree_ro(test_storage: FileStorage) -> None: test_storage.create_folder("protected") path = test_storage.save("protected/barbapapa.txt", "barbapapa") os.chmod(path, stat.S_IREAD) - os.chmod(test_storage.make_full_path("protected"), stat.S_IREAD) + os.chmod(test_storage.make_full_path_safe("protected"), stat.S_IREAD) with pytest.raises(PermissionError): test_storage.delete_folder("protected", recursively=True, delete_ro=False) test_storage.delete_folder("protected", recursively=True, delete_ro=True) diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index ecbc5d296d..45bc8d157e 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -8,10 +8,8 @@ from dlt.common import sleep from dlt.common.schema import Schema from dlt.common.storages import PackageStorage, LoadStorage, ParsedLoadJobFileName +from dlt.common.storages.exceptions import LoadPackageAlreadyCompleted, LoadPackageNotCompleted from dlt.common.utils import uniq_id - -from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage -from tests.utils import 
autouse_test_storage from dlt.common.pendulum import pendulum from dlt.common.configuration.container import Container from dlt.common.storages.load_package import ( @@ -23,6 +21,9 @@ clear_destination_state, ) +from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage + def test_is_partially_loaded(load_storage: LoadStorage) -> None: load_id, file_name = start_loading_file( @@ -243,6 +244,177 @@ def test_build_parse_job_path(load_storage: LoadStorage) -> None: ParsedLoadJobFileName.parse("tab.id.wrong_retry.jsonl") +def test_load_package_listings(load_storage: LoadStorage) -> None: + # 100 csv files + load_id = create_load_package(load_storage.new_packages, 100) + new_jobs = load_storage.new_packages.list_new_jobs(load_id) + assert len(new_jobs) == 100 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_1")) == 100 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_2")) == 0 + assert len(load_storage.new_packages.list_all_jobs_with_states(load_id)) == 100 + assert len(load_storage.new_packages.list_started_jobs(load_id)) == 0 + assert len(load_storage.new_packages.list_failed_jobs(load_id)) == 0 + assert load_storage.new_packages.is_package_completed(load_id) is False + with pytest.raises(LoadPackageNotCompleted): + load_storage.new_packages.list_failed_jobs_infos(load_id) + # add a few more files + add_new_jobs(load_storage.new_packages, load_id, 7, "items_2") + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_1")) == 100 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_2")) == 7 + j_w_s = load_storage.new_packages.list_all_jobs_with_states(load_id) + assert len(j_w_s) == 107 + assert all(job[0] == "new_jobs" for job in j_w_s) + with pytest.raises(FileNotFoundError): + load_storage.new_packages.get_job_failed_message(load_id, 
j_w_s[0][1]) + # get package infos + package_jobs = load_storage.new_packages.get_load_package_jobs(load_id) + assert len(package_jobs["new_jobs"]) == 107 + # other folders empty + assert len(package_jobs["started_jobs"]) == 0 + package_info = load_storage.new_packages.get_load_package_info(load_id) + assert len(package_info.jobs["new_jobs"]) == 107 + assert len(package_info.jobs["completed_jobs"]) == 0 + assert package_info.load_id == load_id + # full path + assert package_info.package_path == load_storage.new_packages.storage.make_full_path(load_id) + assert package_info.state == "new" + assert package_info.completed_at is None + + # move some files + new_jobs = sorted(load_storage.new_packages.list_new_jobs(load_id)) + load_storage.new_packages.start_job(load_id, os.path.basename(new_jobs[0])) + load_storage.new_packages.start_job(load_id, os.path.basename(new_jobs[1])) + load_storage.new_packages.start_job(load_id, os.path.basename(new_jobs[-1])) + load_storage.new_packages.start_job(load_id, os.path.basename(new_jobs[-2])) + + assert len(load_storage.new_packages.list_started_jobs(load_id)) == 4 + assert len(load_storage.new_packages.list_new_jobs(load_id)) == 103 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_1")) == 100 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_2")) == 7 + package_jobs = load_storage.new_packages.get_load_package_jobs(load_id) + assert len(package_jobs["new_jobs"]) == 103 + assert len(package_jobs["started_jobs"]) == 4 + package_info = load_storage.new_packages.get_load_package_info(load_id) + assert len(package_info.jobs["new_jobs"]) == 103 + assert len(package_info.jobs["started_jobs"]) == 4 + + # complete and fail some + load_storage.new_packages.complete_job(load_id, os.path.basename(new_jobs[0])) + load_storage.new_packages.fail_job(load_id, os.path.basename(new_jobs[1]), None) + load_storage.new_packages.fail_job(load_id, os.path.basename(new_jobs[-1]), 
"error!") + path = load_storage.new_packages.retry_job(load_id, os.path.basename(new_jobs[-2])) + assert ParsedLoadJobFileName.parse(path).retry_count == 1 + assert ( + load_storage.new_packages.get_job_failed_message( + load_id, ParsedLoadJobFileName.parse(new_jobs[1]) + ) + is None + ) + assert ( + load_storage.new_packages.get_job_failed_message( + load_id, ParsedLoadJobFileName.parse(new_jobs[-1]) + ) + == "error!" + ) + # can't move again + with pytest.raises(FileNotFoundError): + load_storage.new_packages.complete_job(load_id, os.path.basename(new_jobs[0])) + assert len(load_storage.new_packages.list_started_jobs(load_id)) == 0 + # retry back in new + assert len(load_storage.new_packages.list_new_jobs(load_id)) == 104 + package_jobs = load_storage.new_packages.get_load_package_jobs(load_id) + assert len(package_jobs["new_jobs"]) == 104 + assert len(package_jobs["started_jobs"]) == 0 + assert len(package_jobs["completed_jobs"]) == 1 + assert len(package_jobs["failed_jobs"]) == 2 + assert len(load_storage.new_packages.list_failed_jobs(load_id)) == 2 + package_info = load_storage.new_packages.get_load_package_info(load_id) + assert len(package_info.jobs["new_jobs"]) == 104 + assert len(package_info.jobs["started_jobs"]) == 0 + assert len(package_info.jobs["completed_jobs"]) == 1 + assert len(package_info.jobs["failed_jobs"]) == 2 + + # complete package + load_storage.new_packages.complete_loading_package(load_id, "aborted") + assert load_storage.new_packages.is_package_completed(load_id) + with pytest.raises(LoadPackageAlreadyCompleted): + load_storage.new_packages.complete_loading_package(load_id, "aborted") + + for job in package_info.jobs["failed_jobs"] + load_storage.new_packages.list_failed_jobs_infos( # type: ignore[operator] + load_id + ): + if job.job_file_info.table_name == "items_1": + assert job.failed_message is None + elif job.job_file_info.table_name == "items_2": + assert job.failed_message == "error!" 
+ else: + raise AssertionError() + assert job.created_at is not None + assert job.elapsed is not None + assert job.file_size > 0 + assert job.state == "failed_jobs" + # must be abs path! + assert os.path.isabs(job.file_path) + + +def test_get_load_package_info_perf(load_storage: LoadStorage) -> None: + import time + + st_t = time.time() + for _ in range(10000): + load_storage.loaded_packages.storage.make_full_path("198291092.121/new/ABD.CX.gx") + # os.path.basename("198291092.121/new/ABD.CX.gx") + print(time.time() - st_t) + + st_t = time.time() + load_id = create_load_package(load_storage.loaded_packages, 10000) + print(time.time() - st_t) + + st_t = time.time() + # move half of the files to failed + for file_name in load_storage.loaded_packages.list_new_jobs(load_id)[:1000]: + load_storage.loaded_packages.start_job(load_id, os.path.basename(file_name)) + load_storage.loaded_packages.fail_job( + load_id, os.path.basename(file_name), f"FAILED {file_name}" + ) + print(time.time() - st_t) + + st_t = time.time() + load_storage.loaded_packages.get_load_package_info(load_id) + print(time.time() - st_t) + + st_t = time.time() + table_stat = {} + for file in load_storage.loaded_packages.list_new_jobs(load_id): + parsed = ParsedLoadJobFileName.parse(file) + table_stat[parsed.table_name] = parsed + print(time.time() - st_t) + + +def create_load_package( + package_storage: PackageStorage, new_jobs: int, table_name="items_1" +) -> str: + schema = Schema("test") + load_id = create_load_id() + package_storage.create_package(load_id) + package_storage.save_schema(load_id, schema) + add_new_jobs(package_storage, load_id, new_jobs, table_name) + return load_id + + +def add_new_jobs( + package_storage: PackageStorage, load_id: str, new_jobs: int, table_name="items_1" +) -> None: + for _ in range(new_jobs): + file_name = PackageStorage.build_job_file_name( + table_name, ParsedLoadJobFileName.new_file_id(), 0, False, "csv" + ) + file_path = os.path.join(TEST_STORAGE_ROOT, file_name) 
+ with open(file_path, "wt", encoding="utf-8") as f: + f.write("a|b|c") + package_storage.import_job(load_id, file_path) + + def test_migrate_to_load_package_state() -> None: """ Here we test that an existing load package without a state will not error diff --git a/tests/common/storages/test_load_storage.py b/tests/common/storages/test_load_storage.py index e8686ac2f9..49deaff23e 100644 --- a/tests/common/storages/test_load_storage.py +++ b/tests/common/storages/test_load_storage.py @@ -33,7 +33,7 @@ def test_complete_successful_package(load_storage: LoadStorage) -> None: # but completed packages are deleted load_storage.maybe_remove_completed_jobs(load_id) assert not load_storage.loaded_packages.storage.has_folder( - load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") + load_storage.loaded_packages.get_job_state_folder_path(load_id, "completed_jobs") ) assert_package_info(load_storage, load_id, "loaded", "completed_jobs", jobs_count=0) # delete completed package @@ -56,7 +56,7 @@ def test_complete_successful_package(load_storage: LoadStorage) -> None: ) # has completed loads assert load_storage.loaded_packages.storage.has_folder( - load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") + load_storage.loaded_packages.get_job_state_folder_path(load_id, "completed_jobs") ) load_storage.delete_loaded_package(load_id) assert not load_storage.storage.has_folder(load_storage.get_loaded_package_path(load_id)) @@ -82,14 +82,14 @@ def test_complete_package_failed_jobs(load_storage: LoadStorage) -> None: assert load_storage.storage.has_folder(load_storage.get_loaded_package_path(load_id)) # has completed loads assert load_storage.loaded_packages.storage.has_folder( - load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") + load_storage.loaded_packages.get_job_state_folder_path(load_id, "completed_jobs") ) assert_package_info(load_storage, load_id, "loaded", "failed_jobs") # get failed jobs info failed_files = 
sorted(load_storage.loaded_packages.list_failed_jobs(load_id)) - # job + message - assert len(failed_files) == 2 + # only jobs + assert len(failed_files) == 1 assert load_storage.loaded_packages.storage.has_file(failed_files[0]) failed_info = load_storage.list_failed_jobs_in_loaded_package(load_id) assert failed_info[0].file_path == load_storage.loaded_packages.storage.make_full_path( @@ -117,7 +117,7 @@ def test_abort_package(load_storage: LoadStorage) -> None: assert_package_info(load_storage, load_id, "normalized", "failed_jobs") load_storage.complete_load_package(load_id, True) assert load_storage.loaded_packages.storage.has_folder( - load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") + load_storage.loaded_packages.get_job_state_folder_path(load_id, "completed_jobs") ) assert_package_info(load_storage, load_id, "aborted", "failed_jobs") diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index e97fac8a9e..ffbd2ecf1b 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -1,12 +1,10 @@ import os -import shutil import pytest import yaml from dlt.common import json -from dlt.common.normalizers import explicit_normalizers +from dlt.common.normalizers.utils import explicit_normalizers from dlt.common.schema.schema import Schema -from dlt.common.schema.typing import TStoredSchema from dlt.common.storages.exceptions import ( InStorageSchemaModified, SchemaNotFoundError, @@ -20,9 +18,9 @@ ) from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT +from tests.common.storages.utils import prepare_eth_import_folder from tests.common.utils import ( load_yml_case, - yml_case_path, COMMON_TEST_CASES_PATH, IMPORTED_VERSION_HASH_ETH_V9, ) @@ -234,7 +232,7 @@ def test_getter(storage: SchemaStorage) -> None: def test_getter_with_import(ie_storage: SchemaStorage) -> None: with pytest.raises(KeyError): ie_storage["ethereum"] - 
prepare_import_folder(ie_storage) + prepare_eth_import_folder(ie_storage) # schema will be imported schema = ie_storage["ethereum"] assert schema.name == "ethereum" @@ -260,17 +258,17 @@ def test_getter_with_import(ie_storage: SchemaStorage) -> None: def test_save_store_schema_over_import(ie_storage: SchemaStorage) -> None: - prepare_import_folder(ie_storage) + prepare_eth_import_folder(ie_storage) # we have ethereum schema to be imported but we create new schema and save it schema = Schema("ethereum") schema_hash = schema.version_hash ie_storage.save_schema(schema) assert schema.version_hash == schema_hash # we linked schema to import schema - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() # load schema and make sure our new schema is here schema = ie_storage.load_schema("ethereum") - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() assert schema._stored_version_hash == schema_hash assert schema.version_hash == schema_hash assert schema.previous_hashes == [] @@ -283,11 +281,11 @@ def test_save_store_schema_over_import(ie_storage: SchemaStorage) -> None: def test_save_store_schema_over_import_sync(synced_storage: SchemaStorage) -> None: # as in test_save_store_schema_over_import but we export the new schema immediately to overwrite the imported schema - prepare_import_folder(synced_storage) + prepare_eth_import_folder(synced_storage) schema = Schema("ethereum") schema_hash = schema.version_hash synced_storage.save_schema(schema) - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() # import schema is overwritten fs = FileStorage(synced_storage.config.import_schema_path) exported_name = synced_storage._file_name_in_store("ethereum", "yaml") @@ -353,6 +351,28 @@ def test_schema_from_file() -> 
None: ) +def test_save_initial_import_schema(ie_storage: LiveSchemaStorage) -> None: + # no schema in regular storage + with pytest.raises(SchemaNotFoundError): + ie_storage.load_schema("ethereum") + + # save initial import schema where processing hints are removed + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + schema = Schema.from_dict(eth_V9) + ie_storage.save_import_schema_if_not_exists(schema) + # should be available now + eth = ie_storage.load_schema("ethereum") + assert "x-normalizer" not in eth.tables["blocks"] + + # won't overwrite initial schema + del eth_V9["tables"]["blocks__uncles"] + schema = Schema.from_dict(eth_V9) + ie_storage.save_import_schema_if_not_exists(schema) + # should be available now + eth = ie_storage.load_schema("ethereum") + assert "blocks__uncles" in eth.tables + + def test_live_schema_instances(live_storage: LiveSchemaStorage) -> None: schema = Schema("simple") live_storage.save_schema(schema) @@ -474,22 +494,14 @@ def test_new_live_schema_committed(live_storage: LiveSchemaStorage) -> None: # assert schema.settings["schema_sealed"] is True -def prepare_import_folder(storage: SchemaStorage) -> None: - shutil.copy( - yml_case_path("schemas/eth/ethereum_schema_v8"), - os.path.join(storage.storage.storage_path, "../import/ethereum.schema.yaml"), - ) - - def assert_schema_imported(synced_storage: SchemaStorage, storage: SchemaStorage) -> Schema: - prepare_import_folder(synced_storage) - eth_V9: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v9") + prepare_eth_import_folder(synced_storage) schema = synced_storage.load_schema("ethereum") # is linked to imported schema - schema._imported_version_hash = eth_V9["version_hash"] + schema._imported_version_hash = IMPORTED_VERSION_HASH_ETH_V9() # also was saved in storage assert synced_storage.has_schema("ethereum") - # and has link to imported schema s well (load without import) + # and has link to imported schema as well (load without import) schema = 
storage.load_schema("ethereum") - assert schema._imported_version_hash == eth_V9["version_hash"] + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() return schema diff --git a/tests/common/storages/utils.py b/tests/common/storages/utils.py index 3bfc3374a4..1b5a68948b 100644 --- a/tests/common/storages/utils.py +++ b/tests/common/storages/utils.py @@ -21,9 +21,12 @@ ) from dlt.common.storages import DataItemStorage, FileStorage from dlt.common.storages.fsspec_filesystem import FileItem, FileItemDict +from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.typing import StrAny, TDataItems from dlt.common.utils import uniq_id +from tests.common.utils import load_yml_case + TEST_SAMPLE_FILES = "tests/common/storages/samples" MINIMALLY_EXPECTED_RELATIVE_PATHS = { "csv/freshman_kgs.csv", @@ -199,3 +202,12 @@ def assert_package_info( # get dict package_info.asdict() return package_info + + +def prepare_eth_import_folder(storage: SchemaStorage) -> Schema: + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + # remove processing hints before installing as import schema + # ethereum schema is a "dirty" schema with processing hints + eth = Schema.from_dict(eth_V9, remove_processing_hints=True) + storage._export_schema(eth, storage.config.import_schema_path) + return eth diff --git a/tests/common/test_destination.py b/tests/common/test_destination.py index 24b0928463..2c690d94bb 100644 --- a/tests/common/test_destination.py +++ b/tests/common/test_destination.py @@ -1,10 +1,13 @@ +from typing import Dict import pytest from dlt.common.destination.reference import DestinationClientDwhConfiguration, Destination from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.exceptions import InvalidDestinationReference, UnknownDestinationModule from dlt.common.schema import Schema +from dlt.common.typing import is_subclass +from tests.common.configuration.utils import environment from tests.utils 
import ACTIVE_DESTINATIONS @@ -32,6 +35,96 @@ def test_custom_destination_module() -> None: ) # a full type name +def test_arguments_propagated_to_config() -> None: + dest = Destination.from_reference( + "dlt.destinations.duckdb", create_indexes=None, unknown_param="A" + ) + # None for create_indexes is not a default and it is passed on, unknown_param is removed because it is unknown + assert dest.config_params == {"create_indexes": None} + assert dest.caps_params == {} + + # test explicit config value being passed + import dlt + + dest = Destination.from_reference( + "dlt.destinations.duckdb", create_indexes=dlt.config.value, unknown_param="A" + ) + assert dest.config_params == {"create_indexes": dlt.config.value} + assert dest.caps_params == {} + + dest = Destination.from_reference( + "dlt.destinations.weaviate", naming_convention="duck_case", create_indexes=True + ) + # create indexes are not known + assert dest.config_params == {} + + # create explicit caps + dest = Destination.from_reference( + "dlt.destinations.dummy", + naming_convention="duck_case", + recommended_file_size=4000000, + loader_file_format="parquet", + ) + from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration + + assert dest.config_params == {"loader_file_format": "parquet"} + # loader_file_format is a legacy param that is duplicated as preferred_loader_file_format + assert dest.caps_params == { + "naming_convention": "duck_case", + "recommended_file_size": 4000000, + } + # instantiate configs + caps = dest.capabilities() + assert caps.naming_convention == "duck_case" + assert caps.preferred_loader_file_format == "parquet" + assert caps.recommended_file_size == 4000000 + init_config = DummyClientConfiguration() + config = dest.configuration(init_config) + assert config.loader_file_format == "parquet" # type: ignore[attr-defined] + + +def test_factory_config_injection(environment: Dict[str, str]) -> None: + environment["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" + 
from dlt.destinations import dummy + + # caps will resolve from config without client + assert dummy().capabilities().preferred_loader_file_format == "parquet" + + caps = dummy().client(Schema("client")).capabilities + assert caps.preferred_loader_file_format == "parquet" + + environment.clear() + caps = dummy().client(Schema("client")).capabilities + assert caps.preferred_loader_file_format == "jsonl" + + environment["DESTINATION__DUMMY__LOADER_FILE_FORMAT"] = "parquet" + environment["DESTINATION__DUMMY__FAIL_PROB"] = "0.435" + + # config will partially resolve without client + config = dummy().configuration(None, accept_partial=True) + assert config.fail_prob == 0.435 + assert config.loader_file_format == "parquet" + + dummy_ = dummy().client(Schema("client")) + assert dummy_.capabilities.preferred_loader_file_format == "parquet" + assert dummy_.config.fail_prob == 0.435 + + # test named destination + environment.clear() + import os + from dlt.destinations import filesystem + from dlt.destinations.impl.filesystem.configuration import ( + FilesystemDestinationClientConfiguration, + ) + + filesystem_ = filesystem(destination_name="local") + abs_path = os.path.abspath("_storage") + environment["DESTINATION__LOCAL__BUCKET_URL"] = abs_path + init_config = FilesystemDestinationClientConfiguration()._bind_dataset_name(dataset_name="test") + configured_bucket_url = filesystem_.client(Schema("test"), init_config).config.bucket_url + assert configured_bucket_url.endswith("_storage") + + def test_import_module_by_path() -> None: # importing works directly from dlt destinations dest = Destination.from_reference("dlt.destinations.postgres") @@ -54,17 +147,7 @@ def test_import_module_by_path() -> None: def test_import_all_destinations() -> None: # this must pass without the client dependencies being imported for dest_type in ACTIVE_DESTINATIONS: - # generic destination needs a valid callable, otherwise instantiation will fail - additional_args = {} - if dest_type == 
"destination": - - def dest_callable(items, table) -> None: - pass - - additional_args["destination_callable"] = dest_callable - dest = Destination.from_reference( - dest_type, None, dest_type + "_name", "production", **additional_args - ) + dest = Destination.from_reference(dest_type, None, dest_type + "_name", "production") assert dest.destination_type == "dlt.destinations." + dest_type assert dest.destination_name == dest_type + "_name" assert dest.config_params["environment"] == "production" @@ -73,6 +156,44 @@ def dest_callable(items, table) -> None: assert isinstance(dest.capabilities(), DestinationCapabilitiesContext) +def test_instantiate_all_factories() -> None: + from dlt import destinations + + impls = dir(destinations) + for impl in impls: + var_ = getattr(destinations, impl) + if not is_subclass(var_, Destination): + continue + dest = var_() + + assert dest.destination_name + assert dest.destination_type + # custom destination is named after the callable + if dest.destination_type != "dlt.destinations.destination": + assert dest.destination_type.endswith(dest.destination_name) + else: + assert dest.destination_name == "dummy_custom_destination" + assert dest.spec + assert dest.spec() + # partial configuration may always be created + init_config = dest.spec.credentials_type()() + init_config.__is_resolved__ = True + assert dest.configuration(init_config, accept_partial=True) + assert dest.capabilities() + + mod_dest = var_( + destination_name="fake_name", environment="prod", naming_convention="duck_case" + ) + assert ( + mod_dest.config_params.items() + >= {"destination_name": "fake_name", "environment": "prod"}.items() + ) + assert mod_dest.caps_params == {"naming_convention": "duck_case"} + assert mod_dest.destination_name == "fake_name" + caps = mod_dest.capabilities() + assert caps.naming_convention == "duck_case" + + def test_import_destination_config() -> None: # importing destination by type will work dest = 
Destination.from_reference(ref="dlt.destinations.duckdb", environment="stage") @@ -97,6 +218,7 @@ def test_import_destination_config() -> None: ref="duckdb", destination_name="my_destination", environment="devel" ) assert dest.destination_type == "dlt.destinations.duckdb" + assert dest.destination_name == "my_destination" assert dest.config_params["environment"] == "devel" config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore assert config.destination_type == "duckdb" diff --git a/tests/common/utils.py b/tests/common/utils.py index a234937e56..32741128b8 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -9,14 +9,24 @@ from dlt.common import json from dlt.common.typing import StrAny -from dlt.common.schema import utils +from dlt.common.schema import utils, Schema from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.configuration.providers import environ as environ_provider COMMON_TEST_CASES_PATH = "./tests/common/cases/" -# for import schema tests, change when upgrading the schema version -IMPORTED_VERSION_HASH_ETH_V9 = "PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4=" + + +def IMPORTED_VERSION_HASH_ETH_V9() -> str: + # for import schema tests, change when upgrading the schema version + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + assert eth_V9["version_hash"] == "PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4=" + # remove processing hints before installing as import schema + # ethereum schema is a "dirty" schema with processing hints + eth = Schema.from_dict(eth_V9, remove_processing_hints=True) + return eth.stored_version_hash + + # test sentry DSN TEST_SENTRY_DSN = ( "https://797678dd0af64b96937435326c7d30c1@o1061158.ingest.sentry.io/4504306172821504" diff --git a/tests/destinations/test_custom_destination.py b/tests/destinations/test_custom_destination.py index 6834006689..6ebf7f6ef3 100644 --- a/tests/destinations/test_custom_destination.py +++ 
b/tests/destinations/test_custom_destination.py @@ -8,12 +8,13 @@ from copy import deepcopy from dlt.common.configuration.specs.base_configuration import configspec +from dlt.common.schema.schema import Schema from dlt.common.typing import TDataItems from dlt.common.schema import TTableSchema from dlt.common.data_writers.writers import TLoaderFileFormat from dlt.common.destination.reference import Destination from dlt.common.destination.exceptions import InvalidDestinationReference -from dlt.common.configuration.exceptions import ConfigFieldMissingException +from dlt.common.configuration.exceptions import ConfigFieldMissingException, ConfigurationValueError from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.configuration.inject import get_fun_spec from dlt.common.configuration.specs import BaseConfiguration @@ -38,7 +39,7 @@ def _run_through_sink( batch_size: int = 10, ) -> List[Tuple[TDataItems, TTableSchema]]: """ - runs a list of items through the sink destination and returns colleceted calls + runs a list of items through the sink destination and returns collected calls """ calls: List[Tuple[TDataItems, TTableSchema]] = [] @@ -55,7 +56,7 @@ def items_resource() -> TDataItems: nonlocal items yield items - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) p.run([items_resource()]) return calls @@ -126,6 +127,34 @@ def global_sink_func(items: TDataItems, table: TTableSchema) -> None: global_calls.append((items, table)) +def test_capabilities() -> None: + # test default caps + dest = dlt.destination()(global_sink_func)() + caps = dest.capabilities() + assert caps.preferred_loader_file_format == "typed-jsonl" + assert caps.supported_loader_file_formats == ["typed-jsonl", "parquet"] + assert caps.naming_convention == "direct" + assert caps.max_table_nesting == 0 + client_caps = dest.client(Schema("schema")).capabilities + assert dict(caps) 
== dict(client_caps) + + # test modified caps + dest = dlt.destination( + loader_file_format="parquet", + batch_size=0, + name="my_name", + naming_convention="snake_case", + max_table_nesting=10, + )(global_sink_func)() + caps = dest.capabilities() + assert caps.preferred_loader_file_format == "parquet" + assert caps.supported_loader_file_formats == ["typed-jsonl", "parquet"] + assert caps.naming_convention == "snake_case" + assert caps.max_table_nesting == 10 + client_caps = dest.client(Schema("schema")).capabilities + assert dict(caps) == dict(client_caps) + + def test_instantiation() -> None: # also tests _DESTINATIONS calls: List[Tuple[TDataItems, TTableSchema]] = [] @@ -140,23 +169,23 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va # test decorator calls = [] - p = dlt.pipeline("sink_test", destination=dlt.destination()(local_sink_func), full_refresh=True) + p = dlt.pipeline("sink_test", destination=dlt.destination()(local_sink_func), dev_mode=True) p.run([1, 2, 3], table_name="items") assert len(calls) == 1 # local func does not create entry in destinations - assert not _DESTINATIONS + assert "local_sink_func" not in _DESTINATIONS # test passing via from_reference calls = [] p = dlt.pipeline( "sink_test", destination=Destination.from_reference("destination", destination_callable=local_sink_func), - full_refresh=True, + dev_mode=True, ) p.run([1, 2, 3], table_name="items") assert len(calls) == 1 # local func does not create entry in destinations - assert not _DESTINATIONS + assert "local_sink_func" not in _DESTINATIONS # test passing string reference global global_calls @@ -167,7 +196,7 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va "destination", destination_callable="tests.destinations.test_custom_destination.global_sink_func", ), - full_refresh=True, + dev_mode=True, ) p.run([1, 2, 3], table_name="items") assert len(global_calls) == 1 @@ -182,9 +211,9 @@ def local_sink_func(items: 
TDataItems, table: TTableSchema, my_val=dlt.config.va p = dlt.pipeline( "sink_test", destination=Destination.from_reference("destination", destination_callable=None), - full_refresh=True, + dev_mode=True, ) - with pytest.raises(PipelineStepFailed): + with pytest.raises(ConfigurationValueError): p.run([1, 2, 3], table_name="items") # pass invalid string reference will fail on instantiation @@ -194,7 +223,7 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va destination=Destination.from_reference( "destination", destination_callable="does.not.exist" ), - full_refresh=True, + dev_mode=True, ) # using decorator without args will also work @@ -206,7 +235,7 @@ def simple_decorator_sink(items, table, my_val=dlt.config.value): assert my_val == "something" calls.append((items, table)) - p = dlt.pipeline("sink_test", destination=simple_decorator_sink, full_refresh=True) # type: ignore + p = dlt.pipeline("sink_test", destination=simple_decorator_sink, dev_mode=True) # type: ignore p.run([1, 2, 3], table_name="items") assert len(calls) == 1 @@ -265,7 +294,7 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: assert str(i) in collected_items # no errors are set, all items should be processed - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) load_id = p.run([items(), items2()]).loads_ids[0] assert_items_in_range(calls["items"], 0, 100) assert_items_in_range(calls["items2"], 0, 100) @@ -278,7 +307,7 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: # provoke errors calls = {} provoke_error = {"items": 25, "items2": 45} - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) with pytest.raises(PipelineStepFailed): p.run([items(), items2()]) @@ -335,7 +364,7 @@ def snake_sink(items, table): assert 
table["columns"]["snake_case"]["name"] == "snake_case" assert table["columns"]["camel_case"]["name"] == "camel_case" - dlt.pipeline("sink_test", destination=snake_sink, full_refresh=True).run(resource()) + dlt.pipeline("sink_test", destination=snake_sink, dev_mode=True).run(resource()) # check default (which is direct) @dlt.destination() @@ -345,7 +374,7 @@ def direct_sink(items, table): assert table["columns"]["snake_case"]["name"] == "snake_case" assert table["columns"]["camelCase"]["name"] == "camelCase" - dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(resource()) + dlt.pipeline("sink_test", destination=direct_sink, dev_mode=True).run(resource()) def test_file_batch() -> None: @@ -368,7 +397,7 @@ def direct_sink(file_path, table): with pyarrow.parquet.ParquetFile(file_path) as reader: assert reader.metadata.num_rows == (100 if table["name"] == "person" else 50) - dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=direct_sink, dev_mode=True).run( [resource1(), resource2()] ) @@ -384,25 +413,23 @@ def my_sink(file_path, table, my_val=dlt.config.value): # if no value is present, it should raise with pytest.raises(ConfigFieldMissingException): - dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) # we may give the value via __callable__ function - dlt.pipeline("sink_test", destination=my_sink(my_val="something"), full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_sink(my_val="something"), dev_mode=True).run( [1, 2, 3], table_name="items" ) # right value will pass os.environ["DESTINATION__MY_SINK__MY_VAL"] = "something" - dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( - [1, 2, 3], table_name="items" - ) + dlt.pipeline("sink_test", destination=my_sink, dev_mode=True).run([1, 2, 3], table_name="items") # wrong value will 
raise os.environ["DESTINATION__MY_SINK__MY_VAL"] = "wrong" with pytest.raises(PipelineStepFailed): - dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) @@ -413,13 +440,13 @@ def other_sink(file_path, table, my_val=dlt.config.value): # if no value is present, it should raise with pytest.raises(ConfigFieldMissingException): - dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=other_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) # right value will pass os.environ["DESTINATION__SOME_NAME__MY_VAL"] = "something" - dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=other_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) @@ -437,7 +464,7 @@ def my_gcp_sink( # missing spec with pytest.raises(ConfigFieldMissingException): - dlt.pipeline("sink_test", destination=my_gcp_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_gcp_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) @@ -447,7 +474,7 @@ def my_gcp_sink( os.environ["CREDENTIALS__USERNAME"] = "my_user_name" # now it will run - dlt.pipeline("sink_test", destination=my_gcp_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_gcp_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) @@ -471,14 +498,14 @@ def sink_func_with_spec( # call fails because `my_predefined_val` is required part of spec, even if not injected with pytest.raises(ConfigFieldMissingException): - info = dlt.pipeline("sink_test", destination=sink_func_with_spec(), full_refresh=True).run( + info = dlt.pipeline("sink_test", destination=sink_func_with_spec(), dev_mode=True).run( [1, 2, 3], table_name="items" ) info.raise_on_failed_jobs() # call happens now os.environ["MY_PREDEFINED_VAL"] = "VAL" - info = dlt.pipeline("sink_test", 
destination=sink_func_with_spec(), full_refresh=True).run( + info = dlt.pipeline("sink_test", destination=sink_func_with_spec(), dev_mode=True).run( [1, 2, 3], table_name="items" ) info.raise_on_failed_jobs() @@ -550,7 +577,7 @@ def test_sink(items, table): found_dlt_column_value = True # test with and without removing - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) p.run([{"id": 1, "value": "1"}], table_name="some_table") assert found_dlt_column != remove_stuff @@ -579,7 +606,7 @@ def nesting_sink(items, table): def source(): yield dlt.resource(data, name="data") - p = dlt.pipeline("sink_test_max_nesting", destination=nesting_sink, full_refresh=True) + p = dlt.pipeline("sink_test_max_nesting", destination=nesting_sink, dev_mode=True) p.run(source()) # fall back to source setting diff --git a/tests/extract/data_writers/test_buffered_writer.py b/tests/extract/data_writers/test_buffered_writer.py index b6da132de9..5cad5a35b9 100644 --- a/tests/extract/data_writers/test_buffered_writer.py +++ b/tests/extract/data_writers/test_buffered_writer.py @@ -264,6 +264,27 @@ def test_import_file(writer_type: Type[DataWriter]) -> None: assert metrics.file_size == 231 +@pytest.mark.parametrize("writer_type", ALL_WRITERS) +def test_import_file_with_extension(writer_type: Type[DataWriter]) -> None: + now = time.time() + with get_writer(writer_type) as writer: + # won't destroy the original + metrics = writer.import_file( + "tests/extract/cases/imported.any", + DataWriterMetrics("", 1, 231, 0, 0), + with_extension="any", + ) + assert len(writer.closed_files) == 1 + assert os.path.isfile(metrics.file_path) + # extension is correctly set + assert metrics.file_path.endswith(".any") + assert writer.closed_files[0] == metrics + assert metrics.created <= metrics.last_modified + assert metrics.created >= now + assert metrics.items_count == 1 + assert metrics.file_size == 231 + + 
@pytest.mark.parametrize( "disable_compression", [True, False], ids=["no_compression", "compression"] ) diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py index db888c95e4..f9775fd218 100644 --- a/tests/extract/test_decorators.py +++ b/tests/extract/test_decorators.py @@ -42,7 +42,7 @@ ) from dlt.extract.items import TableNameMeta -from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9 +from tests.common.utils import load_yml_case def test_default_resource() -> None: @@ -107,7 +107,10 @@ def test_load_schema_for_callable() -> None: schema = s.schema assert schema.name == "ethereum" == s.name # the schema in the associated file has this hash - assert schema.stored_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + eth_v9 = load_yml_case("schemas/eth/ethereum_schema_v9") + # source removes processing hints so we do + reference_schema = Schema.from_dict(eth_v9, remove_processing_hints=True) + assert schema.stored_version_hash == reference_schema.stored_version_hash def test_unbound_parametrized_transformer() -> None: @@ -341,6 +344,41 @@ class Columns3(BaseModel): assert t["columns"]["b"]["data_type"] == "double" +def test_not_normalized_identifiers_in_hints() -> None: + @dlt.resource( + primary_key="ID", + merge_key=["Month", "Day"], + columns=[{"name": "Col1", "data_type": "bigint"}], + table_name="🐫Camels", + ) + def CamelResource(): + yield ["🐫"] * 10 + + camels = CamelResource() + # original names are kept + assert camels.name == "CamelResource" + assert camels.table_name == "🐫Camels" + assert camels.columns == {"Col1": {"data_type": "bigint", "name": "Col1"}} + table = camels.compute_table_schema() + columns = table["columns"] + assert "ID" in columns + assert "Month" in columns + assert "Day" in columns + assert "Col1" in columns + assert table["name"] == "🐫Camels" + + # define as part of a source + camel_source = DltSource(Schema("snake_case"), "camel_section", [camels]) + schema = camel_source.discover_schema() + # all 
normalized + table = schema.get_table("_camels") + columns = table["columns"] + assert "id" in columns + assert "month" in columns + assert "day" in columns + assert "col1" in columns + + def test_resource_name_from_generator() -> None: def some_data(): yield [1, 2, 3] @@ -565,6 +603,21 @@ def created_global(): _assert_source_schema(created_global(), "global") +def test_source_schema_removes_processing_hints() -> None: + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + assert "x-normalizer" in eth_V9["tables"]["blocks"] + + @dlt.source(schema=Schema.from_dict(eth_V9)) + def created_explicit(): + schema = dlt.current.source_schema() + assert schema.name == "ethereum" + assert "x-normalizer" not in schema.tables["blocks"] + return dlt.resource([1, 2, 3], name="res") + + source = created_explicit() + assert "x-normalizer" not in source.schema.tables["blocks"] + + def test_source_state_context() -> None: @dlt.resource(selected=False) def main(): @@ -849,6 +902,18 @@ def test_standalone_transformer(next_item_mode: str) -> None: ] +def test_transformer_required_args() -> None: + @dlt.transformer + def path_params(id_, workspace_id, load_id, base: bool = False): + yield {"id": id_, "workspace_id": workspace_id, "load_id": load_id} + + data = list([1, 2, 3] | path_params(121, 343)) + assert len(data) == 3 + assert data[0] == {"id": 1, "workspace_id": 121, "load_id": 343} + + # @dlt + + @dlt.transformer(standalone=True, name=lambda args: args["res_name"]) def standalone_tx_with_name(item: TDataItem, res_name: str, init: int = dlt.config.value): return res_name * item * init diff --git a/tests/extract/test_extract.py b/tests/extract/test_extract.py index dc978b997a..dbec417f97 100644 --- a/tests/extract/test_extract.py +++ b/tests/extract/test_extract.py @@ -125,6 +125,7 @@ def with_table_hints(): {"id": 1, "pk2": "B"}, make_hints( write_disposition="merge", + file_format="preferred", columns=[{"name": "id", "precision": 16}, {"name": "text", "data_type": 
"decimal"}], primary_key="pk2", ), @@ -143,6 +144,7 @@ def with_table_hints(): assert "pk" in table["columns"] assert "text" in table["columns"] assert table["write_disposition"] == "merge" + assert table["file_format"] == "preferred" # make table name dynamic yield dlt.mark.with_hints( diff --git a/tests/extract/test_extract_pipe.py b/tests/extract/test_extract_pipe.py index 9bf580b76a..d285181c55 100644 --- a/tests/extract/test_extract_pipe.py +++ b/tests/extract/test_extract_pipe.py @@ -510,6 +510,19 @@ def test_pipe_copy_on_fork() -> None: assert elems[0].item is not elems[1].item +def test_pipe_pass_empty_list() -> None: + def _gen(): + yield [] + + pipe = Pipe.from_data("data", _gen()) + elems = list(PipeIterator.from_pipe(pipe)) + assert elems[0].item == [] + + pipe = Pipe.from_data("data", [[]]) + elems = list(PipeIterator.from_pipe(pipe)) + assert elems[0].item == [] + + def test_clone_single_pipe() -> None: doc = {"e": 1, "l": 2} parent = Pipe.from_data("data", [doc]) diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index bb6fb70983..49437d7b74 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -17,6 +17,7 @@ from dlt.common.configuration.specs.base_configuration import configspec, BaseConfiguration from dlt.common.configuration import ConfigurationValueError from dlt.common.pendulum import pendulum, timedelta +from dlt.common import Decimal from dlt.common.pipeline import NormalizeInfo, StateInjectableContext, resource_state from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id, digest128, chunks @@ -786,6 +787,43 @@ def some_data(first: bool, last_timestamp=dlt.sources.incremental("ts")): p.run(some_data(False)) +@pytest.mark.parametrize("item_type", set(ALL_TEST_DATA_ITEM_FORMATS) - {"pandas"}) +@pytest.mark.parametrize( + "id_value", + ("1231231231231271872", b"1231231231231271872", pendulum.now(), 1271.78, Decimal("1231.87")), +) +def 
test_primary_key_types(item_type: TestDataItemFormat, id_value: Any) -> None: + """Case when deduplication filter is empty for an Arrow table.""" + p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb") + now = pendulum.now() + + data = [ + { + "delta": str(i), + "ts": now.add(days=i), + "_id": id_value, + } + for i in range(-10, 10) + ] + source_items = data_to_item_format(item_type, data) + start = now.add(days=-10) + + @dlt.resource + def some_data( + last_timestamp=dlt.sources.incremental("ts", initial_value=start, primary_key="_id"), + ): + yield from source_items + + info = p.run(some_data()) + info.raise_on_failed_jobs() + norm_info = p.last_trace.last_normalize_info + assert norm_info.row_counts["some_data"] == 20 + # load incrementally + info = p.run(some_data()) + norm_info = p.last_trace.last_normalize_info + assert "some_data" not in norm_info.row_counts + + @pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) def test_replace_resets_state(item_type: TestDataItemFormat) -> None: p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb") diff --git a/tests/extract/test_sources.py b/tests/extract/test_sources.py index 7b2613776d..8287da69d4 100644 --- a/tests/extract/test_sources.py +++ b/tests/extract/test_sources.py @@ -1274,6 +1274,8 @@ def empty_gen(): primary_key=["a", "b"], merge_key=["c", "a"], schema_contract="freeze", + table_format="delta", + file_format="jsonl", ) table = empty_r.compute_table_schema() assert table["columns"]["a"] == { @@ -1288,11 +1290,15 @@ def empty_gen(): assert table["parent"] == "parent" assert empty_r.table_name == "table" assert table["schema_contract"] == "freeze" + assert table["table_format"] == "delta" + assert table["file_format"] == "jsonl" # reset empty_r.apply_hints( table_name="", parent_table_name="", + table_format="", + file_format="", primary_key=[], merge_key="", columns={}, diff --git a/tests/libs/pyarrow/test_pyarrow_normalizer.py b/tests/libs/pyarrow/test_pyarrow_normalizer.py 
index 63abcbc92a..d975702ad8 100644 --- a/tests/libs/pyarrow/test_pyarrow_normalizer.py +++ b/tests/libs/pyarrow/test_pyarrow_normalizer.py @@ -3,8 +3,8 @@ import pyarrow as pa import pytest -from dlt.common.libs.pyarrow import normalize_py_arrow_item, NameNormalizationClash -from dlt.common.normalizers import explicit_normalizers, import_normalizers +from dlt.common.libs.pyarrow import normalize_py_arrow_item, NameNormalizationCollision +from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers from dlt.common.schema.utils import new_column, TColumnSchema from dlt.common.destination import DestinationCapabilitiesContext @@ -65,7 +65,7 @@ def test_field_normalization_clash() -> None: {"col^New": "hello", "col_new": 1}, ] ) - with pytest.raises(NameNormalizationClash): + with pytest.raises(NameNormalizationCollision): _normalize(table, []) diff --git a/tests/load/athena_iceberg/test_athena_adapter.py b/tests/load/athena_iceberg/test_athena_adapter.py index 3144eb9cc9..19c176a374 100644 --- a/tests/load/athena_iceberg/test_athena_adapter.py +++ b/tests/load/athena_iceberg/test_athena_adapter.py @@ -2,7 +2,7 @@ import dlt from dlt.destinations import filesystem -from dlt.destinations.impl.athena.athena_adapter import athena_adapter, athena_partition +from dlt.destinations.adapters import athena_adapter, athena_partition # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -40,7 +40,7 @@ def not_partitioned_table(): "athena_test", destination="athena", staging=filesystem("s3://not-a-real-bucket"), - full_refresh=True, + dev_mode=True, ) pipeline.extract([partitioned_table, not_partitioned_table]) diff --git a/tests/load/athena_iceberg/test_athena_iceberg.py b/tests/load/athena_iceberg/test_athena_iceberg.py index 4fe01752ee..0ef935a8bc 100644 --- a/tests/load/athena_iceberg/test_athena_iceberg.py +++ b/tests/load/athena_iceberg/test_athena_iceberg.py @@ -1,15 +1,9 @@ import pytest import os -import datetime # noqa: 
I251 from typing import Iterator, Any import dlt -from dlt.common import pendulum -from dlt.common.utils import uniq_id -from tests.cases import table_update_and_row, assert_all_data_types_row -from tests.pipeline.utils import assert_load_info, load_table_counts - -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.pipeline.utils import load_table_counts from dlt.destinations.exceptions import DatabaseTerminalException diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index b16790b07d..e8b5dab8fd 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -22,7 +22,7 @@ from dlt.destinations.impl.bigquery.bigquery import BigQueryClient, BigQueryClientConfiguration from dlt.destinations.exceptions import LoadJobNotExistsException, LoadJobTerminalException -from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, preserve_environ +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage from tests.common.utils import json_case_path as common_json_case_path from tests.common.configuration.utils import environment from tests.load.utils import ( diff --git a/tests/load/bigquery/test_bigquery_streaming_insert.py b/tests/load/bigquery/test_bigquery_streaming_insert.py index c80f6ed65a..c950a46f91 100644 --- a/tests/load/bigquery/test_bigquery_streaming_insert.py +++ b/tests/load/bigquery/test_bigquery_streaming_insert.py @@ -1,7 +1,7 @@ import pytest import dlt -from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter +from dlt.destinations.adapters import bigquery_adapter from tests.pipeline.utils import assert_load_info @@ -12,7 +12,7 @@ def test_resource(): bigquery_adapter(test_resource, insert_api="streaming") - pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery", full_refresh=True) + pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery", 
dev_mode=True) pack = pipe.run(test_resource, table_name="test_streaming_items44") assert_load_info(pack) @@ -41,10 +41,12 @@ def test_resource(): pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery") info = pipe.run(test_resource) + # pick the failed job + failed_job = info.load_packages[0].jobs["failed_jobs"][0] assert ( """BigQuery streaming insert can only be used with `append`""" """ write_disposition, while the given resource has `merge`.""" - ) in info.asdict()["load_packages"][0]["jobs"][0]["failed_message"] + ) in failed_job.failed_message def test_bigquery_streaming_nested_data(): @@ -54,7 +56,7 @@ def test_resource(): bigquery_adapter(test_resource, insert_api="streaming") - pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery", full_refresh=True) + pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery", dev_mode=True) pack = pipe.run(test_resource, table_name="test_streaming_items") assert_load_info(pack) diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index df564192dc..66ea4a319f 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -21,17 +21,23 @@ from dlt.common.schema import Schema from dlt.common.utils import custom_environ from dlt.common.utils import uniq_id + from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate +from dlt.destinations import bigquery from dlt.destinations.impl.bigquery.bigquery import BigQueryClient -from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter +from dlt.destinations.adapters import bigquery_adapter from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration + from dlt.extract import DltResource -from tests.load.pipeline.utils import ( + +from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, drop_active_pipeline_data, + TABLE_UPDATE, + 
sequence_generator, + empty_schema, ) -from tests.load.utils import TABLE_UPDATE, sequence_generator, empty_schema # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -58,7 +64,7 @@ def gcp_client(empty_schema: Schema) -> BigQueryClient: creds = GcpServiceAccountCredentials() creds.project_id = "test_project_id" # noinspection PydanticTypeChecker - return BigQueryClient( + return bigquery().client( empty_schema, BigQueryClientConfiguration(credentials=creds)._bind_dataset_name( dataset_name=f"test_{uniq_id()}" @@ -89,9 +95,9 @@ def test_create_table(gcp_client: BigQueryClient) -> None: sqlfluff.parse(sql, dialect="bigquery") assert sql.startswith("CREATE TABLE") assert "event_test_table" in sql - assert "`col1` INTEGER NOT NULL" in sql + assert "`col1` INT64 NOT NULL" in sql assert "`col2` FLOAT64 NOT NULL" in sql - assert "`col3` BOOLEAN NOT NULL" in sql + assert "`col3` BOOL NOT NULL" in sql assert "`col4` TIMESTAMP NOT NULL" in sql assert "`col5` STRING " in sql assert "`col6` NUMERIC(38,9) NOT NULL" in sql @@ -100,7 +106,7 @@ def test_create_table(gcp_client: BigQueryClient) -> None: assert "`col9` JSON NOT NULL" in sql assert "`col10` DATE" in sql assert "`col11` TIME" in sql - assert "`col1_precision` INTEGER NOT NULL" in sql + assert "`col1_precision` INT64 NOT NULL" in sql assert "`col4_precision` TIMESTAMP NOT NULL" in sql assert "`col5_precision` STRING(25) " in sql assert "`col6_precision` NUMERIC(6,2) NOT NULL" in sql @@ -119,9 +125,9 @@ def test_alter_table(gcp_client: BigQueryClient) -> None: assert sql.startswith("ALTER TABLE") assert sql.count("ALTER TABLE") == 1 assert "event_test_table" in sql - assert "ADD COLUMN `col1` INTEGER NOT NULL" in sql + assert "ADD COLUMN `col1` INT64 NOT NULL" in sql assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql - assert "ADD COLUMN `col3` BOOLEAN NOT NULL" in sql + assert "ADD COLUMN `col3` BOOL NOT NULL" in sql assert "ADD COLUMN `col4` TIMESTAMP NOT NULL" in sql assert "ADD COLUMN 
`col5` STRING" in sql assert "ADD COLUMN `col6` NUMERIC(38,9) NOT NULL" in sql @@ -130,7 +136,7 @@ def test_alter_table(gcp_client: BigQueryClient) -> None: assert "ADD COLUMN `col9` JSON NOT NULL" in sql assert "ADD COLUMN `col10` DATE" in sql assert "ADD COLUMN `col11` TIME" in sql - assert "ADD COLUMN `col1_precision` INTEGER NOT NULL" in sql + assert "ADD COLUMN `col1_precision` INT64 NOT NULL" in sql assert "ADD COLUMN `col4_precision` TIMESTAMP NOT NULL" in sql assert "ADD COLUMN `col5_precision` STRING(25)" in sql assert "ADD COLUMN `col6_precision` NUMERIC(6,2) NOT NULL" in sql @@ -946,7 +952,7 @@ def sources() -> List[DltResource]: pipeline = destination_config.setup_pipeline( f"bigquery_{uniq_id()}", - full_refresh=True, + dev_mode=True, ) pipeline.run(sources()) diff --git a/tests/load/cases/loading/csv_header.csv b/tests/load/cases/loading/csv_header.csv new file mode 100644 index 0000000000..14c7514e51 --- /dev/null +++ b/tests/load/cases/loading/csv_header.csv @@ -0,0 +1,3 @@ +id|name|description|ordered_at|price +1|item|value|2024-04-12|128.4 +1|"item"|value with space|2024-04-12|128.4 \ No newline at end of file diff --git a/tests/load/cases/loading/csv_no_header.csv b/tests/load/cases/loading/csv_no_header.csv new file mode 100644 index 0000000000..1e3a63494e --- /dev/null +++ b/tests/load/cases/loading/csv_no_header.csv @@ -0,0 +1,2 @@ +1|item|value|2024-04-12|128.4 +1|"item"|value with space|2024-04-12|128.4 \ No newline at end of file diff --git a/tests/load/cases/loading/csv_no_header.csv.gz b/tests/load/cases/loading/csv_no_header.csv.gz new file mode 100644 index 0000000000..310950f484 Binary files /dev/null and b/tests/load/cases/loading/csv_no_header.csv.gz differ diff --git a/tests/load/cases/loading/header.jsonl b/tests/load/cases/loading/header.jsonl new file mode 100644 index 0000000000..c2f9fee551 --- /dev/null +++ b/tests/load/cases/loading/header.jsonl @@ -0,0 +1,2 @@ +{"id": 1, "name": "item", "description": "value", "ordered_at": 
"2024-04-12", "price": 128.4} +{"id": 1, "name": "item", "description": "value with space", "ordered_at": "2024-04-12", "price": 128.4} \ No newline at end of file diff --git a/tests/load/clickhouse/test_clickhouse_adapter.py b/tests/load/clickhouse/test_clickhouse_adapter.py index 36d3ac07f7..ea3116c25b 100644 --- a/tests/load/clickhouse/test_clickhouse_adapter.py +++ b/tests/load/clickhouse/test_clickhouse_adapter.py @@ -19,7 +19,7 @@ def not_annotated_resource(): clickhouse_adapter(merge_tree_resource, table_engine_type="merge_tree") clickhouse_adapter(replicated_merge_tree_resource, table_engine_type="replicated_merge_tree") - pipe = dlt.pipeline(pipeline_name="adapter_test", destination="clickhouse", full_refresh=True) + pipe = dlt.pipeline(pipeline_name="adapter_test", destination="clickhouse", dev_mode=True) pack = pipe.run([merge_tree_resource, replicated_merge_tree_resource, not_annotated_resource]) assert_load_info(pack) diff --git a/tests/load/clickhouse/test_clickhouse_gcs_s3_compatibility.py b/tests/load/clickhouse/test_clickhouse_gcs_s3_compatibility.py index 481cd420c6..b2edb12d49 100644 --- a/tests/load/clickhouse/test_clickhouse_gcs_s3_compatibility.py +++ b/tests/load/clickhouse/test_clickhouse_gcs_s3_compatibility.py @@ -22,7 +22,7 @@ def dummy_data() -> Generator[Dict[str, int], None, None]: pipeline_name="gcs_s3_compatibility", destination="clickhouse", staging=gcp_bucket, - full_refresh=True, + dev_mode=True, ) pack = pipe.run([dummy_data]) assert_load_info(pack) diff --git a/tests/load/clickhouse/test_clickhouse_table_builder.py b/tests/load/clickhouse/test_clickhouse_table_builder.py index fd3bf50907..867102dde9 100644 --- a/tests/load/clickhouse/test_clickhouse_table_builder.py +++ b/tests/load/clickhouse/test_clickhouse_table_builder.py @@ -6,6 +6,8 @@ from dlt.common.schema import Schema from dlt.common.utils import custom_environ, digest128 from dlt.common.utils import uniq_id + +from dlt.destinations import clickhouse from 
dlt.destinations.impl.clickhouse.clickhouse import ClickHouseClient from dlt.destinations.impl.clickhouse.configuration import ( ClickHouseCredentials, @@ -18,7 +20,7 @@ def clickhouse_client(empty_schema: Schema) -> ClickHouseClient: # Return a client without opening connection. creds = ClickHouseCredentials() - return ClickHouseClient( + return clickhouse().client( empty_schema, ClickHouseClientConfiguration(credentials=creds)._bind_dataset_name(f"test_{uniq_id()}"), ) diff --git a/tests/load/conftest.py b/tests/load/conftest.py index fefaeee077..a110b1198f 100644 --- a/tests/load/conftest.py +++ b/tests/load/conftest.py @@ -2,8 +2,8 @@ import pytest from typing import Iterator -from tests.load.utils import ALL_BUCKETS, DEFAULT_BUCKETS, WITH_GDRIVE_BUCKETS -from tests.utils import preserve_environ +from tests.load.utils import ALL_BUCKETS, DEFAULT_BUCKETS, WITH_GDRIVE_BUCKETS, drop_pipeline +from tests.utils import preserve_environ, patch_home_dir @pytest.fixture(scope="function", params=DEFAULT_BUCKETS) diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index cc353f5894..f6a06180c9 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -6,7 +6,6 @@ from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.common.configuration import resolve_configuration -from tests.utils import preserve_environ # mark all tests as essential, do not remove pytestmark = pytest.mark.essential diff --git a/tests/load/dremio/test_dremio_client.py b/tests/load/dremio/test_dremio_client.py index d0002dc343..efc72c0652 100644 --- a/tests/load/dremio/test_dremio_client.py +++ b/tests/load/dremio/test_dremio_client.py @@ -1,6 +1,8 @@ import pytest from dlt.common.schema import TColumnSchema, Schema + +from dlt.destinations import dremio from dlt.destinations.impl.dremio.configuration import 
DremioClientConfiguration, DremioCredentials from dlt.destinations.impl.dremio.dremio import DremioClient from tests.load.utils import empty_schema @@ -10,11 +12,11 @@ def dremio_client(empty_schema: Schema) -> DremioClient: creds = DremioCredentials() creds.database = "test_database" - return DremioClient( + # ignore any configured values + creds.resolve() + return dremio(credentials=creds).client( empty_schema, - DremioClientConfiguration(credentials=creds)._bind_dataset_name( - dataset_name="test_dataset" - ), + DremioClientConfiguration()._bind_dataset_name(dataset_name="test_dataset"), ) diff --git a/tests/load/duckdb/test_duckdb_client.py b/tests/load/duckdb/test_duckdb_client.py index 8f6bf195e2..ebbe959874 100644 --- a/tests/load/duckdb/test_duckdb_client.py +++ b/tests/load/duckdb/test_duckdb_client.py @@ -15,9 +15,8 @@ from dlt.destinations.impl.duckdb.exceptions import InvalidInMemoryDuckdbCredentials from dlt.pipeline.exceptions import PipelineStepFailed -from tests.load.pipeline.utils import drop_pipeline from tests.pipeline.utils import assert_table -from tests.utils import patch_home_dir, autouse_test_storage, preserve_environ, TEST_STORAGE_ROOT +from tests.utils import patch_home_dir, autouse_test_storage, TEST_STORAGE_ROOT # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -57,7 +56,7 @@ def test_duckdb_open_conn_default() -> None: delete_quack_db() -def test_duckdb_in_memory_mode_via_factory(preserve_environ): +def test_duckdb_in_memory_mode_via_factory(): delete_quack_db() try: import duckdb diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 545f182ece..85f86ce84d 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -5,6 +5,7 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema +from dlt.destinations import duckdb from dlt.destinations.impl.duckdb.duck import 
DuckDbClient from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration @@ -22,7 +23,7 @@ @pytest.fixture def client(empty_schema: Schema) -> DuckDbClient: # return client without opening connection - return DuckDbClient( + return duckdb().client( empty_schema, DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_" + uniq_id()), ) @@ -117,7 +118,7 @@ def test_create_table_with_hints(client: DuckDbClient) -> None: assert '"col4" TIMESTAMP WITH TIME ZONE NOT NULL' in sql # same thing with indexes - client = DuckDbClient( + client = duckdb().client( client.schema, DuckDbClientConfiguration(create_indexes=True)._bind_dataset_name( dataset_name="test_" + uniq_id() diff --git a/tests/load/duckdb/test_motherduck_client.py b/tests/load/duckdb/test_motherduck_client.py index 2a1d703c87..764e1654c6 100644 --- a/tests/load/duckdb/test_motherduck_client.py +++ b/tests/load/duckdb/test_motherduck_client.py @@ -14,7 +14,7 @@ MotherDuckClientConfiguration, ) -from tests.utils import patch_home_dir, preserve_environ, skip_if_not_active +from tests.utils import patch_home_dir, skip_if_not_active # mark all tests as essential, do not remove pytestmark = pytest.mark.essential diff --git a/tests/load/filesystem/test_aws_credentials.py b/tests/load/filesystem/test_aws_credentials.py index 1a41144744..5e0a3c3fd0 100644 --- a/tests/load/filesystem/test_aws_credentials.py +++ b/tests/load/filesystem/test_aws_credentials.py @@ -1,6 +1,7 @@ import pytest from typing import Dict +from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration from dlt.common.utils import digest128 from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs.aws_credentials import AwsCredentials @@ -8,7 +9,7 @@ from tests.common.configuration.utils import environment from tests.load.utils import ALL_FILESYSTEM_DRIVERS -from tests.utils import preserve_environ, autouse_test_storage +from tests.utils import 
autouse_test_storage # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -101,6 +102,11 @@ def test_aws_credentials_from_boto3(environment: Dict[str, str]) -> None: assert c.aws_access_key_id == "fake_access_key" +def test_aws_credentials_from_unknown_object() -> None: + with pytest.raises(InvalidBoto3Session): + AwsCredentials().parse_native_representation(CredentialsConfiguration()) + + def test_aws_credentials_for_profile(environment: Dict[str, str]) -> None: import botocore.exceptions diff --git a/tests/load/filesystem/test_azure_credentials.py b/tests/load/filesystem/test_azure_credentials.py index 4ee2ec46db..2353491737 100644 --- a/tests/load/filesystem/test_azure_credentials.py +++ b/tests/load/filesystem/test_azure_credentials.py @@ -17,7 +17,7 @@ from dlt.common.storages.configuration import FilesystemConfiguration from tests.load.utils import ALL_FILESYSTEM_DRIVERS, AZ_BUCKET from tests.common.configuration.utils import environment -from tests.utils import preserve_environ, autouse_test_storage +from tests.utils import autouse_test_storage from dlt.common.storages.fsspec_filesystem import fsspec_from_config # mark all tests as essential, do not remove diff --git a/tests/load/filesystem/test_filesystem_client.py b/tests/load/filesystem/test_filesystem_client.py index 53e54c2f06..f16e75c7e6 100644 --- a/tests/load/filesystem/test_filesystem_client.py +++ b/tests/load/filesystem/test_filesystem_client.py @@ -2,13 +2,22 @@ import os from unittest import mock from pathlib import Path +from urllib.parse import urlparse import pytest +from dlt.common.configuration.specs.azure_credentials import AzureCredentials +from dlt.common.configuration.specs.base_configuration import ( + CredentialsConfiguration, + extract_inner_hint, +) +from dlt.common.schema.schema import Schema +from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.time import ensure_pendulum_datetime from dlt.common.utils import 
digest128, uniq_id from dlt.common.storages import FileStorage, ParsedLoadJobFileName +from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import ( FilesystemDestinationClientConfiguration, INIT_FILE_NAME, @@ -52,6 +61,32 @@ def test_filesystem_destination_configuration(url, exp) -> None: assert FilesystemDestinationClientConfiguration(bucket_url=url).fingerprint() == exp +def test_filesystem_factory_buckets(with_gdrive_buckets_env: str) -> None: + proto = urlparse(with_gdrive_buckets_env).scheme + credentials_type = extract_inner_hint( + FilesystemConfiguration.PROTOCOL_CREDENTIALS.get(proto, CredentialsConfiguration) + ) + + # test factory figuring out the right credentials + filesystem_ = filesystem(with_gdrive_buckets_env) + client = filesystem_.client( + Schema("test"), + initial_config=FilesystemDestinationClientConfiguration()._bind_dataset_name("test"), + ) + assert client.config.protocol == (proto or "file") + assert isinstance(client.config.credentials, credentials_type) + assert issubclass(client.config.credentials_type(client.config), credentials_type) + assert filesystem_.capabilities() + + # factory gets initial credentials + filesystem_ = filesystem(with_gdrive_buckets_env, credentials=credentials_type()) + client = filesystem_.client( + Schema("test"), + initial_config=FilesystemDestinationClientConfiguration()._bind_dataset_name("test"), + ) + assert isinstance(client.config.credentials, credentials_type) + + @pytest.mark.parametrize("write_disposition", ("replace", "append", "merge")) @pytest.mark.parametrize("layout", TEST_FILE_LAYOUTS) def test_successful_load(write_disposition: str, layout: str, with_gdrive_buckets_env: str) -> None: diff --git a/tests/load/filesystem/test_filesystem_common.py b/tests/load/filesystem/test_filesystem_common.py index 270e1ff70c..a7b1371f9f 100644 --- a/tests/load/filesystem/test_filesystem_common.py +++ b/tests/load/filesystem/test_filesystem_common.py @@ -20,9 +20,10 @@ from
dlt.destinations.impl.filesystem.configuration import ( FilesystemDestinationClientConfiguration, ) +from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders from tests.common.storages.utils import assert_sample_files from tests.load.utils import ALL_FILESYSTEM_DRIVERS, AWS_BUCKET -from tests.utils import preserve_environ, autouse_test_storage +from tests.utils import autouse_test_storage from .utils import self_signed_cert from tests.common.configuration.utils import environment @@ -199,7 +200,7 @@ def test_s3_wrong_client_certificate(default_buckets_env: str, self_signed_cert: def test_filesystem_destination_config_reports_unused_placeholders(mocker) -> None: with custom_environ({"DATASET_NAME": "BOBO"}): - extra_placeholders = { + extra_placeholders: TExtraPlaceholders = { "value": 1, "otters": "lab", "dlt": "labs", @@ -211,7 +212,7 @@ def test_filesystem_destination_config_reports_unused_placeholders(mocker) -> No FilesystemDestinationClientConfiguration( bucket_url="file:///tmp/dirbobo", layout="{schema_name}/{table_name}/{otters}-x-{x}/{load_id}.{file_id}.{timestamp}.{ext}", - extra_placeholders=extra_placeholders, # type: ignore + extra_placeholders=extra_placeholders, ) ) logger_spy.assert_called_once_with("Found unused layout placeholders: value, dlt, dlthub") @@ -227,7 +228,7 @@ def test_filesystem_destination_passed_parameters_override_config_values() -> No "DESTINATION__FILESYSTEM__EXTRA_PLACEHOLDERS": json.dumps(config_extra_placeholders), } ): - extra_placeholders = { + extra_placeholders: TExtraPlaceholders = { "new_value": 1, "dlt": "labs", "dlthub": "platform", diff --git a/tests/load/filesystem/test_object_store_rs_credentials.py b/tests/load/filesystem/test_object_store_rs_credentials.py index 4e43b7c5d8..524cd4425d 100644 --- a/tests/load/filesystem/test_object_store_rs_credentials.py +++ b/tests/load/filesystem/test_object_store_rs_credentials.py @@ -29,9 +29,11 @@ FS_CREDS: Dict[str, Any] = 
dlt.secrets.get("destination.filesystem.credentials") -assert ( - FS_CREDS is not None -), "`destination.filesystem.credentials` must be configured for these tests." +if FS_CREDS is None: + pytest.skip( + "`destination.filesystem.credentials` must be configured for these tests.", + allow_module_level=True, + ) def can_connect(bucket_url: str, object_store_rs_credentials: Dict[str, str]) -> bool: @@ -86,6 +88,7 @@ def test_aws_object_store_rs_credentials() -> None: creds = AwsCredentials( aws_access_key_id=FS_CREDS["aws_access_key_id"], aws_secret_access_key=FS_CREDS["aws_secret_access_key"], + # region_name must be configured in order for data lake to work region_name=FS_CREDS["region_name"], ) assert creds.aws_session_token is None @@ -138,6 +141,7 @@ def test_gcp_object_store_rs_credentials() -> None: creds = GcpServiceAccountCredentialsWithoutDefaults( project_id=FS_CREDS["project_id"], private_key=FS_CREDS["private_key"], + # private_key_id must be configured in order for data lake to work private_key_id=FS_CREDS["private_key_id"], client_email=FS_CREDS["client_email"], ) diff --git a/tests/load/lancedb/__init__.py b/tests/load/lancedb/__init__.py new file mode 100644 index 0000000000..fb4bf0b35d --- /dev/null +++ b/tests/load/lancedb/__init__.py @@ -0,0 +1,3 @@ +from tests.utils import skip_if_not_active + +skip_if_not_active("lancedb") diff --git a/tests/load/lancedb/test_config.py b/tests/load/lancedb/test_config.py new file mode 100644 index 0000000000..c1d658d4fe --- /dev/null +++ b/tests/load/lancedb/test_config.py @@ -0,0 +1,35 @@ +import os +from typing import Iterator + +import pytest + +from dlt.common.configuration import resolve_configuration +from dlt.common.utils import digest128 +from dlt.destinations.impl.lancedb.configuration import ( + LanceDBClientConfiguration, +) +from tests.load.utils import ( + drop_active_pipeline_data, +) + + +# Mark all tests as essential, do not remove.
+pytestmark = pytest.mark.essential + + +@pytest.fixture(autouse=True) +def drop_lancedb_data() -> Iterator[None]: + yield + drop_active_pipeline_data() + + +def test_lancedb_configuration() -> None: + os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL_PROVIDER"] = "colbert" + os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL"] = "text-embedding-3-small" + + config = resolve_configuration( + LanceDBClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "lancedb"), + ) + assert config.embedding_model_provider == "colbert" + assert config.embedding_model == "text-embedding-3-small" diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py new file mode 100644 index 0000000000..a89153f629 --- /dev/null +++ b/tests/load/lancedb/test_pipeline.py @@ -0,0 +1,435 @@ +from typing import Iterator, Generator, Any, List + +import pytest + +import dlt +from dlt.common import json +from dlt.common.typing import DictStrStr, DictStrAny +from dlt.common.utils import uniq_id +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + lancedb_adapter, + VECTORIZE_HINT, +) +from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient +from tests.load.lancedb.utils import assert_table +from tests.load.utils import sequence_generator, drop_active_pipeline_data +from tests.pipeline.utils import assert_load_info + + +# Mark all tests as essential, do not remove. 
+pytestmark = pytest.mark.essential + + +@pytest.fixture(autouse=True) +def drop_lancedb_data() -> Iterator[None]: + yield + drop_active_pipeline_data() + + +def test_adapter_and_hints() -> None: + generator_instance1 = sequence_generator() + + @dlt.resource(columns=[{"name": "content", "data_type": "text"}]) + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + assert some_data.columns["content"] == {"name": "content", "data_type": "text"} # type: ignore[index] + + lancedb_adapter( + some_data, + embed=["content"], + ) + + assert some_data.columns["content"] == { # type: ignore + "name": "content", + "data_type": "text", + "x-lancedb-embed": True, + } + + +def test_basic_state_and_schema() -> None: + generator_instance1 = sequence_generator() + + @dlt.resource + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + lancedb_adapter( + some_data, + embed=["content"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"test_pipeline_append_dataset{uniq_id()}", + ) + info = pipeline.run( + some_data(), + ) + assert_load_info(info) + + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore + # Check if we can get a stored schema and state. 
+ schema = client.get_stored_schema() + print("Print dataset name", client.dataset_name) + assert schema + state = client.get_stored_state("test_pipeline_append") + assert state + + +def test_pipeline_append() -> None: + generator_instance1 = sequence_generator() + generator_instance2 = sequence_generator() + + @dlt.resource + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + lancedb_adapter( + some_data, + embed=["content"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + ) + info = pipeline.run( + some_data(), + ) + assert_load_info(info) + + data = next(generator_instance2) + assert_table(pipeline, "some_data", items=data) + + info = pipeline.run( + some_data(), + ) + assert_load_info(info) + + data.extend(next(generator_instance2)) + assert_table(pipeline, "some_data", items=data) + + +def test_explicit_append() -> None: + """Append should work even when the primary key is specified.""" + data = [ + {"doc_id": 1, "content": "1"}, + {"doc_id": 2, "content": "2"}, + {"doc_id": 3, "content": "3"}, + ] + + @dlt.resource(primary_key="doc_id") + def some_data() -> Generator[List[DictStrAny], Any, None]: + yield data + + lancedb_adapter( + some_data, + embed=["content"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + ) + info = pipeline.run( + some_data(), + ) + + assert_table(pipeline, "some_data", items=data) + + info = pipeline.run( + some_data(), + write_disposition="append", + ) + assert_load_info(info) + + data.extend(data) + assert_table(pipeline, "some_data", items=data) + + +def test_pipeline_replace() -> None: + generator_instance1 = sequence_generator() + generator_instance2 = sequence_generator() + + @dlt.resource + def some_data() -> Generator[DictStrStr, Any, None]: + yield from 
next(generator_instance1) + + lancedb_adapter( + some_data, + embed=["content"], + ) + + uid = uniq_id() + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_replace", + destination="lancedb", + dataset_name="test_pipeline_replace_dataset" + + uid, # lancedb doesn't mandate any name normalization + ) + + info = pipeline.run( + some_data(), + write_disposition="replace", + ) + assert_load_info(info) + assert info.dataset_name == f"test_pipeline_replace_dataset{uid}" + + data = next(generator_instance2) + assert_table(pipeline, "some_data", items=data) + + info = pipeline.run( + some_data(), + write_disposition="replace", + ) + assert_load_info(info) + + data = next(generator_instance2) + assert_table(pipeline, "some_data", items=data) + + +def test_pipeline_merge() -> None: + data = [ + { + "doc_id": 1, + "merge_id": "shawshank-redemption-1994", + "title": "The Shawshank Redemption", + "description": ( + "Two imprisoned men find redemption through acts of decency over the years." + ), + }, + { + "doc_id": 2, + "merge_id": "the-godfather-1972", + "title": "The Godfather", + "description": ( + "A crime dynasty's aging patriarch transfers control to his reluctant son." + ), + }, + { + "doc_id": 3, + "merge_id": "the-dark-knight-2008", + "title": "The Dark Knight", + "description": ( + "The Joker wreaks havoc on Gotham, challenging The Dark Knight's ability to fight" + " injustice." + ), + }, + { + "doc_id": 4, + "merge_id": "pulp-fiction-1994", + "title": "Pulp Fiction", + "description": ( + "The lives of two mob hitmen, a boxer, a gangster and his wife, and a pair of diner" + " bandits intertwine in four tales of violence and redemption." + ), + }, + { + "doc_id": 5, + "merge_id": "schindlers-list-1993", + "title": "Schindler's List", + "description": ( + "In German-occupied Poland during World War II, industrialist Oskar Schindler" + " gradually becomes concerned for his Jewish workforce after witnessing their" + " persecution by the Nazis." 
+ ), + }, + { + "doc_id": 6, + "merge_id": "the-lord-of-the-rings-the-return-of-the-king-2003", + "title": "The Lord of the Rings: The Return of the King", + "description": ( + "Gandalf and Aragorn lead the World of Men against Sauron's army to draw his gaze" + " from Frodo and Sam as they approach Mount Doom with the One Ring." + ), + }, + { + "doc_id": 7, + "merge_id": "the-matrix-1999", + "title": "The Matrix", + "description": ( + "A computer hacker learns from mysterious rebels about the true nature of his" + " reality and his role in the war against its controllers." + ), + }, + ] + + @dlt.resource(primary_key="doc_id") + def movies_data() -> Any: + yield data + + @dlt.resource(primary_key="doc_id", merge_key=["merge_id", "title"]) + def movies_data_explicit_merge_keys() -> Any: + yield data + + lancedb_adapter( + movies_data, + embed=["description"], + ) + + lancedb_adapter( + movies_data_explicit_merge_keys, + embed=["description"], + ) + + pipeline = dlt.pipeline( + pipeline_name="movies", + destination="lancedb", + dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + ) + info = pipeline.run( + movies_data(), + write_disposition="merge", + dataset_name=f"MoviesDataset{uniq_id()}", + ) + assert_load_info(info) + assert_table(pipeline, "movies_data", items=data) + + # Change some data. + data[0]["title"] = "The Shawshank Redemption 2" + + info = pipeline.run( + movies_data(), + write_disposition="merge", + ) + assert_load_info(info) + assert_table(pipeline, "movies_data", items=data) + + info = pipeline.run( + movies_data(), + write_disposition="merge", + ) + assert_load_info(info) + assert_table(pipeline, "movies_data", items=data) + + # Test with explicit merge keys. 
+ info = pipeline.run( + movies_data_explicit_merge_keys(), + write_disposition="merge", + ) + assert_load_info(info) + assert_table(pipeline, "movies_data_explicit_merge_keys", items=data) + + +def test_pipeline_with_schema_evolution() -> None: + data = [ + { + "doc_id": 1, + "content": "1", + }, + { + "doc_id": 2, + "content": "2", + }, + ] + + @dlt.resource() + def some_data() -> Generator[List[DictStrAny], Any, None]: + yield data + + lancedb_adapter(some_data, embed=["content"]) + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"TestSchemaEvolutionDataset{uniq_id()}", + ) + pipeline.run( + some_data(), + ) + + assert_table(pipeline, "some_data", items=data) + + aggregated_data = data.copy() + + data = [ + { + "doc_id": 3, + "content": "3", + "new_column": "new", + }, + { + "doc_id": 4, + "content": "4", + "new_column": "new", + }, + ] + + pipeline.run( + some_data(), + ) + + table_schema = pipeline.default_schema.tables["some_data"] + assert "new_column" in table_schema["columns"] + + aggregated_data.extend(data) + + assert_table(pipeline, "some_data", items=aggregated_data) + + +def test_merge_github_nested() -> None: + pipe = dlt.pipeline(destination="lancedb", dataset_name="github1", full_refresh=True) + assert pipe.dataset_name.startswith("github1_202") + + with open( + "tests/normalize/cases/github.issues.load_page_5_duck.json", + "r", + encoding="utf-8", + ) as f: + data = json.load(f) + + info = pipe.run( + lancedb_adapter(data[:17], embed=["title", "body"]), + table_name="issues", + write_disposition="merge", + primary_key="id", + ) + assert_load_info(info) + # assert if schema contains tables with right names + print(pipe.default_schema.tables.keys()) + assert set(pipe.default_schema.tables.keys()) == { + "_dlt_version", + "_dlt_loads", + "issues", + "_dlt_pipeline_state", + "issues__labels", + "issues__assignees", + } + assert {t["name"] for t in pipe.default_schema.data_tables()} == { + 
"issues", + "issues__labels", + "issues__assignees", + } + assert {t["name"] for t in pipe.default_schema.dlt_tables()} == { + "_dlt_version", + "_dlt_loads", + "_dlt_pipeline_state", + } + issues = pipe.default_schema.tables["issues"] + assert issues["columns"]["id"]["primary_key"] is True + # Make sure vectorization is enabled for. + assert issues["columns"]["title"][VECTORIZE_HINT] # type: ignore[literal-required] + assert issues["columns"]["body"][VECTORIZE_HINT] # type: ignore[literal-required] + assert VECTORIZE_HINT not in issues["columns"]["url"] + assert_table(pipe, "issues", expected_items_count=17) + + +def test_empty_dataset_allowed() -> None: + # dataset_name is optional so dataset name won't be autogenerated when not explicitly passed. + pipe = dlt.pipeline(destination="lancedb", full_refresh=True) + client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] + + assert pipe.dataset_name is None + info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) + # Dataset in load info is empty. + assert info.dataset_name is None + client = pipe.destination_client() # type: ignore[assignment] + assert client.dataset_name is None + assert client.sentinel_table == "dltSentinelTable" + assert_table(pipe, "content", expected_items_count=3) diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py new file mode 100644 index 0000000000..dc3ea5304b --- /dev/null +++ b/tests/load/lancedb/utils.py @@ -0,0 +1,74 @@ +from typing import Union, List, Any, Dict + +import numpy as np +from lancedb.embeddings import TextEmbeddingFunction # type: ignore + +import dlt +from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient + + +def assert_unordered_dicts_equal( + dict_list1: List[Dict[str, Any]], dict_list2: List[Dict[str, Any]] +) -> None: + """ + Assert that two lists of dictionaries contain the same dictionaries, ignoring None values. 
+ + Args: + dict_list1 (List[Dict[str, Any]]): The first list of dictionaries to compare. + dict_list2 (List[Dict[str, Any]]): The second list of dictionaries to compare. + + Raises: + AssertionError: If the lists have different lengths or contain different dictionaries. + """ + assert len(dict_list1) == len(dict_list2), "Lists have different length" + + dict_set1 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1} + dict_set2 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2} + + assert dict_set1 == dict_set2, "Lists contain different dictionaries" + + +def assert_table( + pipeline: dlt.Pipeline, + table_name: str, + expected_items_count: int = None, + items: List[Any] = None, +) -> None: + client: LanceDBClient = pipeline.destination_client() # type: ignore[assignment] + qualified_table_name = client.make_qualified_table_name(table_name) + + exists = client.table_exists(qualified_table_name) + assert exists + + records = client.db_client.open_table(qualified_table_name).search().limit(50).to_list() + + if expected_items_count is not None: + assert expected_items_count == len(records) + + if items is None: + return + + drop_keys = [ + "_dlt_id", + "_dlt_load_id", + dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", + ] + objects_without_dlt_or_special_keys = [ + {k: v for k, v in record.items() if k not in drop_keys} for record in records + ] + + assert_unordered_dicts_equal(objects_without_dlt_or_special_keys, items) + + +class MockEmbeddingFunc(TextEmbeddingFunction): + def generate_embeddings( + self, + texts: Union[List[str], np.ndarray], # type: ignore[type-arg] + *args, + **kwargs, + ) -> List[np.ndarray]: # type: ignore[type-arg] + return [np.array(None)] + + def ndims(self) -> int: + return 2 diff --git a/tests/load/mssql/test_mssql_credentials.py 
b/tests/load/mssql/test_mssql_configuration.py similarity index 77% rename from tests/load/mssql/test_mssql_credentials.py rename to tests/load/mssql/test_mssql_configuration.py index 7d49196531..75af101e23 100644 --- a/tests/load/mssql/test_mssql_credentials.py +++ b/tests/load/mssql/test_mssql_configuration.py @@ -1,15 +1,46 @@ +import os import pyodbc import pytest from dlt.common.configuration import resolve_configuration, ConfigFieldMissingException from dlt.common.exceptions import SystemConfigurationException +from dlt.common.schema import Schema -from dlt.destinations.impl.mssql.configuration import MsSqlCredentials +from dlt.destinations import mssql +from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, MsSqlClientConfiguration # mark all tests as essential, do not remove pytestmark = pytest.mark.essential +def test_mssql_factory() -> None: + schema = Schema("schema") + dest = mssql() + client = dest.client(schema, MsSqlClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is False + assert client.config.has_case_sensitive_identifiers is False + assert client.capabilities.has_case_sensitive_identifiers is False + assert client.capabilities.casefold_identifier is str + + # set args explicitly + dest = mssql(has_case_sensitive_identifiers=True, create_indexes=True) + client = dest.client(schema, MsSqlClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is True + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + # set args via config + os.environ["DESTINATION__CREATE_INDEXES"] = "True" + os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"] = "True" + dest = mssql() + client = dest.client(schema, MsSqlClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is True + assert 
client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + def test_mssql_credentials_defaults() -> None: creds = MsSqlCredentials() assert creds.port == 1433 diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index f7a87c14ee..d6cf3ec3e8 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ b/tests/load/mssql/test_mssql_table_builder.py @@ -6,7 +6,8 @@ pytest.importorskip("dlt.destinations.impl.mssql.mssql", reason="MSSQL ODBC driver not installed") -from dlt.destinations.impl.mssql.mssql import MsSqlClient +from dlt.destinations import mssql +from dlt.destinations.impl.mssql.mssql import MsSqlJobClient from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration, MsSqlCredentials from tests.load.utils import TABLE_UPDATE, empty_schema @@ -16,9 +17,9 @@ @pytest.fixture -def client(empty_schema: Schema) -> MsSqlClient: +def client(empty_schema: Schema) -> MsSqlJobClient: # return client without opening connection - return MsSqlClient( + return mssql().client( empty_schema, MsSqlClientConfiguration(credentials=MsSqlCredentials())._bind_dataset_name( dataset_name="test_" + uniq_id() @@ -26,7 +27,7 @@ def client(empty_schema: Schema) -> MsSqlClient: ) -def test_create_table(client: MsSqlClient) -> None: +def test_create_table(client: MsSqlJobClient) -> None: # non existing table sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0] sqlfluff.parse(sql, dialect="tsql") @@ -50,7 +51,7 @@ def test_create_table(client: MsSqlClient) -> None: assert '"col11_precision" time(3) NOT NULL' in sql -def test_alter_table(client: MsSqlClient) -> None: +def test_alter_table(client: MsSqlJobClient) -> None: # existing table has no columns sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)[0] sqlfluff.parse(sql, dialect="tsql") diff --git 
a/tests/load/pipeline/conftest.py b/tests/load/pipeline/conftest.py index 34227a8041..a2ba65494b 100644 --- a/tests/load/pipeline/conftest.py +++ b/tests/load/pipeline/conftest.py @@ -1,8 +1,2 @@ -from tests.utils import ( - patch_home_dir, - preserve_environ, - autouse_test_storage, - duckdb_pipeline_location, -) +from tests.utils import autouse_test_storage, duckdb_pipeline_location from tests.pipeline.utils import drop_dataset_from_env -from tests.load.pipeline.utils import drop_pipeline diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 0bddfaabee..630d84a28c 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -9,14 +9,14 @@ import dlt from dlt.common import pendulum -from dlt.common.time import reduce_pendulum_datetime_precision, ensure_pendulum_datetime +from dlt.common.time import reduce_pendulum_datetime_precision from dlt.common.utils import uniq_id + from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.pipeline.utils import assert_load_info, select_data from tests.utils import ( TestDataItemFormat, arrow_item_from_pandas, - preserve_environ, TPythonTableFormat, ) from tests.cases import arrow_table_all_data_types diff --git a/tests/load/pipeline/test_athena.py b/tests/load/pipeline/test_athena.py index 272cc701d5..3197a19d14 100644 --- a/tests/load/pipeline/test_athena.py +++ b/tests/load/pipeline/test_athena.py @@ -9,15 +9,15 @@ from tests.pipeline.utils import assert_load_info, load_table_counts from tests.pipeline.utils import load_table_counts from dlt.destinations.exceptions import CantExtractTablePrefix -from dlt.destinations.impl.athena.athena_adapter import athena_partition, athena_adapter -from dlt.destinations.fs_client import FSClientBase +from dlt.destinations.adapters import athena_partition, athena_adapter -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration 
from tests.load.utils import ( TEST_FILE_LAYOUTS, FILE_LAYOUT_MANY_TABLES_ONE_FOLDER, FILE_LAYOUT_CLASSIC, FILE_LAYOUT_TABLE_NOT_FIRST, + destinations_configs, + DestinationTestConfiguration, ) # mark all tests as essential, do not remove @@ -208,7 +208,7 @@ def my_source() -> Any: @pytest.mark.parametrize("layout", TEST_FILE_LAYOUTS) def test_athena_file_layouts(destination_config: DestinationTestConfiguration, layout) -> None: # test wether strange file layouts still work in all staging configs - pipeline = destination_config.setup_pipeline("athena_file_layout", full_refresh=True) + pipeline = destination_config.setup_pipeline("athena_file_layout", dev_mode=True) os.environ["DESTINATION__FILESYSTEM__LAYOUT"] = layout resources = [ @@ -242,7 +242,7 @@ def test_athena_file_layouts(destination_config: DestinationTestConfiguration, l ) def test_athena_partitioned_iceberg_table(destination_config: DestinationTestConfiguration): """Load an iceberg table with partition hints and verifiy partitions are created correctly.""" - pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), dev_mode=True) data_items = [ (1, "A", datetime.date.fromisoformat("2021-01-01")), diff --git a/tests/load/pipeline/test_bigquery.py b/tests/load/pipeline/test_bigquery.py index 68533a5d43..0618ff9d3d 100644 --- a/tests/load/pipeline/test_bigquery.py +++ b/tests/load/pipeline/test_bigquery.py @@ -3,8 +3,7 @@ from dlt.common import Decimal from tests.pipeline.utils import assert_load_info -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration -from tests.load.utils import delete_dataset +from tests.load.utils import destinations_configs, DestinationTestConfiguration # mark all tests as essential, do not remove pytestmark = pytest.mark.essential diff --git a/tests/load/pipeline/test_clickhouse.py b/tests/load/pipeline/test_clickhouse.py index 
2ba5cfdcb8..8ad3a7f1a7 100644 --- a/tests/load/pipeline/test_clickhouse.py +++ b/tests/load/pipeline/test_clickhouse.py @@ -5,10 +5,7 @@ import dlt from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id -from tests.load.pipeline.utils import ( - destinations_configs, - DestinationTestConfiguration, -) +from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.pipeline.utils import load_table_counts @@ -18,7 +15,7 @@ ids=lambda x: x.name, ) def test_clickhouse_destination_append(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline(f"clickhouse_{uniq_id()}", full_refresh=True) + pipeline = destination_config.setup_pipeline(f"clickhouse_{uniq_id()}", dev_mode=True) try: diff --git a/tests/load/pipeline/test_csv_loading.py b/tests/load/pipeline/test_csv_loading.py new file mode 100644 index 0000000000..6a2be2eb40 --- /dev/null +++ b/tests/load/pipeline/test_csv_loading.py @@ -0,0 +1,172 @@ +import os +from typing import List +import pytest + +import dlt +from dlt.common.data_writers.configuration import CsvFormatConfiguration +from dlt.common.schema.typing import TColumnSchema +from dlt.common.typing import TLoaderFileFormat +from dlt.common.utils import uniq_id + +from tests.cases import arrow_table_all_data_types, prepare_shuffled_tables +from tests.pipeline.utils import ( + assert_data_table_counts, + assert_load_info, + assert_only_table_columns, + load_tables_to_dicts, +) +from tests.load.utils import destinations_configs, DestinationTestConfiguration +from tests.utils import TestDataItemFormat + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres", "snowflake"]), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("item_type", ["object", "table"]) +def test_load_csv( + destination_config: DestinationTestConfiguration, item_type: TestDataItemFormat +) -> None: + 
os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) + # do not save state so the state job is not created + pipeline.config.restore_from_destination = False + + table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() + # convert to pylist when loading from objects, this will kick the csv-reader in + if item_type == "object": + table, shuffled_table, shuffled_removed_column = ( + table.to_pylist(), + shuffled_table.to_pylist(), + shuffled_removed_column.to_pylist(), + ) + + load_info = pipeline.run( + [shuffled_removed_column, shuffled_table, table], + table_name="table", + loader_file_format="csv", + ) + assert_load_info(load_info) + job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path + assert job.endswith("csv") + assert_data_table_counts(pipeline, {"table": 5432 * 3}) + load_tables_to_dicts(pipeline, "table") + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres", "snowflake"]), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("file_format", (None, "csv")) +@pytest.mark.parametrize("compression", (True, False)) +def test_custom_csv_no_header( + destination_config: DestinationTestConfiguration, + file_format: TLoaderFileFormat, + compression: bool, +) -> None: + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = str(not compression) + csv_format = CsvFormatConfiguration(delimiter="|", include_header=False) + # apply to collected config + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) + # this will apply this to config when client instance is created + pipeline.destination.config_params["csv_format"] = csv_format + # verify + assert pipeline.destination_client().config.csv_format == csv_format # type: ignore[attr-defined] + # create a resource that imports file + + columns: List[TColumnSchema] = [ + {"name": "id", "data_type": 
"bigint"}, + {"name": "name", "data_type": "text"}, + {"name": "description", "data_type": "text"}, + {"name": "ordered_at", "data_type": "date"}, + {"name": "price", "data_type": "decimal"}, + ] + hints = dlt.mark.make_hints(columns=columns) + import_file = "tests/load/cases/loading/csv_no_header.csv" + if compression: + import_file += ".gz" + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "csv", 2, hints=hints)], + table_name="no_header", + loader_file_format=file_format, + ) + info.raise_on_failed_jobs() + print(info) + assert_only_table_columns(pipeline, "no_header", [col["name"] for col in columns]) + rows = load_tables_to_dicts(pipeline, "no_header") + assert len(rows["no_header"]) == 2 + # we should have two files loaded + jobs = info.load_packages[0].jobs["completed_jobs"] + assert len(jobs) == 2 + job_extensions = [os.path.splitext(job.job_file_info.file_name())[1] for job in jobs] + assert ".csv" in job_extensions + # we allow state to be saved to make sure it is not in csv format (which would break) + # the loading. state is always saved in destination preferred format + preferred_ext = "." 
+ pipeline.destination.capabilities().preferred_loader_file_format + assert preferred_ext in job_extensions + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres", "snowflake"]), + ids=lambda x: x.name, +) +def test_custom_wrong_header(destination_config: DestinationTestConfiguration) -> None: + csv_format = CsvFormatConfiguration(delimiter="|", include_header=True) + # apply to collected config + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) + # this will apply this to config when client instance is created + pipeline.destination.config_params["csv_format"] = csv_format + # verify + assert pipeline.destination_client().config.csv_format == csv_format # type: ignore[attr-defined] + # create a resource that imports file + + columns: List[TColumnSchema] = [ + {"name": "object_id", "data_type": "bigint", "nullable": False}, + {"name": "name", "data_type": "text"}, + {"name": "description", "data_type": "text"}, + {"name": "ordered_at", "data_type": "date"}, + {"name": "price", "data_type": "decimal"}, + ] + hints = dlt.mark.make_hints(columns=columns) + import_file = "tests/load/cases/loading/csv_header.csv" + # snowflake will pass here because we do not match + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "csv", 2, hints=hints)], + table_name="no_header", + ) + assert info.has_failed_jobs + assert len(info.load_packages[0].jobs["failed_jobs"]) == 1 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres", "snowflake"]), + ids=lambda x: x.name, +) +def test_empty_csv_from_arrow(destination_config: DestinationTestConfiguration) -> None: + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + os.environ["RESTORE_FROM_DESTINATION"] = "False" + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) + table, _, _ = 
arrow_table_all_data_types("arrow-table", include_json=False) + + load_info = pipeline.run( + table.schema.empty_table(), table_name="arrow_table", loader_file_format="csv" + ) + assert_load_info(load_info) + assert len(load_info.load_packages[0].jobs["completed_jobs"]) == 1 + job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path + assert job.endswith("csv") + assert_data_table_counts(pipeline, {"arrow_table": 0}) + with pipeline.sql_client() as client: + with client.execute_query("SELECT * FROM arrow_table") as cur: + columns = [col.name for col in cur.description] + assert len(cur.fetchall()) == 0 + + # all columns in order, also casefold to the destination casing (we use cursor.description) + casefold = pipeline.destination.capabilities().casefold_identifier + assert columns == list( + map(casefold, pipeline.default_schema.get_table_columns("arrow_table").keys()) + ) diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index 1dc225594f..86ee1a646e 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -11,8 +11,8 @@ from dlt.helpers.dbt.exceptions import DBTProcessingError, PrerequisitesException from tests.pipeline.utils import select_data +from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.utils import ACTIVE_SQL_DESTINATIONS -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration # uncomment add motherduck tests # NOTE: the tests are passing but we disable them due to frequent ATTACH DATABASE timeouts diff --git a/tests/load/pipeline/test_dremio.py b/tests/load/pipeline/test_dremio.py index 9a4c96c922..66d1b0be4f 100644 --- a/tests/load/pipeline/test_dremio.py +++ b/tests/load/pipeline/test_dremio.py @@ -12,9 +12,7 @@ ids=lambda x: x.name, ) def test_dremio(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline( - "dremio-test", 
dataset_name="bar", full_refresh=True - ) + pipeline = destination_config.setup_pipeline("dremio-test", dataset_name="bar", dev_mode=True) @dlt.resource(name="items", write_disposition="replace") def items() -> Iterator[Any]: diff --git a/tests/load/pipeline/test_drop.py b/tests/load/pipeline/test_drop.py index 313ba63a2c..e1c6ec9d79 100644 --- a/tests/load/pipeline/test_drop.py +++ b/tests/load/pipeline/test_drop.py @@ -17,11 +17,11 @@ ) from dlt.destinations.job_client_impl import SqlJobClientBase -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.utils import destinations_configs, DestinationTestConfiguration def _attach(pipeline: Pipeline) -> Pipeline: - return dlt.attach(pipeline.pipeline_name, pipeline.pipelines_dir) + return dlt.attach(pipeline.pipeline_name, pipelines_dir=pipeline.pipelines_dir) @dlt.source(section="droppable", name="droppable") @@ -91,13 +91,14 @@ def assert_dropped_resource_tables(pipeline: Pipeline, resources: List[str]) -> client: SqlJobClientBase with pipeline.destination_client(pipeline.default_schema_name) as client: # type: ignore[assignment] # Check all tables supposed to be dropped are not in dataset - for table in dropped_tables: - exists, _ = client.get_storage_table(table) - assert not exists + storage_tables = list(client.get_storage_tables(dropped_tables)) + # no columns in all tables + assert all(len(table[1]) == 0 for table in storage_tables) + # Check tables not from dropped resources still exist - for table in expected_tables: - exists, _ = client.get_storage_table(table) - assert exists + storage_tables = list(client.get_storage_tables(expected_tables)) + # all tables have columns + assert all(len(table[1]) > 0 for table in storage_tables) def assert_dropped_resource_states(pipeline: Pipeline, resources: List[str]) -> None: @@ -178,7 +179,7 @@ def test_drop_command_only_state(destination_config: DestinationTestConfiguratio def 
test_drop_command_only_tables(destination_config: DestinationTestConfiguration) -> None: """Test drop only tables and makes sure that schema and state are synced""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) pipeline.run(source) sources_state = pipeline.state["sources"] @@ -334,9 +335,8 @@ def test_drop_all_flag(destination_config: DestinationTestConfiguration) -> None # Verify original _dlt tables were not deleted with attached._sql_job_client(attached.default_schema) as client: - for tbl in dlt_tables: - exists, _ = client.get_storage_table(tbl) - assert exists + storage_tables = list(client.get_storage_tables(dlt_tables)) + assert all(len(table[1]) > 0 for table in storage_tables) @pytest.mark.parametrize( diff --git a/tests/load/pipeline/test_duckdb.py b/tests/load/pipeline/test_duckdb.py index 3f9821cee0..3dcfffe348 100644 --- a/tests/load/pipeline/test_duckdb.py +++ b/tests/load/pipeline/test_duckdb.py @@ -1,16 +1,14 @@ import pytest import os +from dlt.common.schema.exceptions import SchemaIdentifierNormalizationCollision from dlt.common.time import ensure_pendulum_datetime from dlt.destinations.exceptions import DatabaseTerminalException from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import TABLE_UPDATE_ALL_INT_PRECISIONS, TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS +from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.pipeline.utils import airtable_emojis, load_table_counts -from tests.load.pipeline.utils import ( - destinations_configs, - DestinationTestConfiguration, -) @pytest.mark.parametrize( @@ -44,7 +42,7 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No "🦚Peacock__peacock": 3, "🦚Peacocks🦚": 1, "🦚WidePeacock": 1, - "🦚WidePeacock__peacock": 3, + "🦚WidePeacock__Peacock": 3, } # this will fail - 
duckdb preserves case but is case insensitive when comparing identifiers @@ -54,7 +52,10 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No table_name="🦚peacocks🦚", loader_file_format=destination_config.file_format, ) - assert isinstance(pip_ex.value.__context__, DatabaseTerminalException) + assert isinstance(pip_ex.value.__context__, SchemaIdentifierNormalizationCollision) + assert pip_ex.value.__context__.conflict_identifier_name == "🦚Peacocks🦚" + assert pip_ex.value.__context__.identifier_name == "🦚peacocks🦚" + assert pip_ex.value.__context__.identifier_type == "table" # show tables and columns with pipeline.sql_client() as client: diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index efbdc082f1..210ad76b8a 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -14,13 +14,14 @@ from dlt.common.utils import uniq_id from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import FilesystemClient +from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import arrow_table_all_data_types, table_update_and_row, assert_all_data_types_row from tests.common.utils import load_json_case from tests.utils import ALL_TEST_DATA_ITEM_FORMATS, TestDataItemFormat, skip_if_not_active from dlt.destinations.path_utils import create_path -from tests.load.pipeline.utils import ( +from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, ) @@ -34,7 +35,7 @@ @pytest.fixture def local_filesystem_pipeline() -> dlt.Pipeline: os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "_storage" - return dlt.pipeline(pipeline_name="fs_pipe", destination="filesystem", full_refresh=True) + return dlt.pipeline(pipeline_name="fs_pipe", destination="filesystem", dev_mode=True) def 
test_pipeline_merge_write_disposition(default_buckets_env: str) -> None: @@ -499,7 +500,7 @@ def count(*args, **kwargs) -> Any: return count - extra_placeholders = { + extra_placeholders: TExtraPlaceholders = { "who": "marcin", "action": "says", "what": "no potato", @@ -653,8 +654,8 @@ def some_data(): # test accessors for state s1 = c1.get_stored_state("p1") s2 = c1.get_stored_state("p2") - assert s1.dlt_load_id == load_id_1_2 # second load - assert s2.dlt_load_id == load_id_2_1 # first load + assert s1._dlt_load_id == load_id_1_2 # second load + assert s2._dlt_load_id == load_id_2_1 # first load assert s1_old.version != s1.version assert s2_old.version == s2.version @@ -797,13 +798,15 @@ def table_3(): # check opening of file values = [] - for line in fs_client.read_text(t1_files[0]).split("\n"): + for line in fs_client.read_text(t1_files[0], encoding="utf-8").split("\n"): if line: values.append(json.loads(line)["value"]) assert values == [1, 2, 3, 4, 5] # check binary read - assert fs_client.read_bytes(t1_files[0]) == str.encode(fs_client.read_text(t1_files[0])) + assert fs_client.read_bytes(t1_files[0]) == str.encode( + fs_client.read_text(t1_files[0], encoding="utf-8") + ) # check truncate fs_client.truncate_tables(["table_1"]) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index a3f5083ae6..2c1d1346f1 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -19,7 +19,11 @@ from dlt.pipeline.exceptions import PipelineStepFailed from tests.pipeline.utils import assert_load_info, load_table_counts, select_data -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.utils import ( + normalize_storage_table_cols, + destinations_configs, + DestinationTestConfiguration, +) # uncomment add motherduck tests # NOTE: the tests are passing but we disable them due to frequent ATTACH DATABASE timeouts @@ 
-38,7 +42,7 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio # make block uncles unseen to trigger filtering loader in loader for child tables if has_table_seen_data(schema.tables["blocks__uncles"]): - del schema.tables["blocks__uncles"]["x-normalizer"] # type: ignore[typeddict-item] + del schema.tables["blocks__uncles"]["x-normalizer"] assert not has_table_seen_data(schema.tables["blocks__uncles"]) with open( @@ -307,9 +311,10 @@ def test_merge_keys_non_existing_columns(destination_config: DestinationTestConf github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) assert github_2_counts["issues"] == 100 - 45 + 1 with p._sql_job_client(p.default_schema) as job_c: - _, table_schema = job_c.get_storage_table("issues") - assert "url" in table_schema - assert "m_a1" not in table_schema # unbound columns were not created + _, storage_cols = job_c.get_storage_table("issues") + storage_cols = normalize_storage_table_cols("issues", storage_cols, p.default_schema) + assert "url" in storage_cols + assert "m_a1" not in storage_cols # unbound columns were not created @pytest.mark.parametrize( @@ -319,6 +324,8 @@ def test_merge_keys_non_existing_columns(destination_config: DestinationTestConf ) def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("github_3", dev_mode=True) + # do not save state to destination so jobs counting is easier + p.config.restore_from_destination = False github_data = github() # generate some complex types github_data.max_table_nesting = 2 @@ -985,7 +992,7 @@ def test_invalid_merge_strategy(destination_config: DestinationTestConfiguration def r(): yield {"foo": "bar"} - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) with pytest.raises(PipelineStepFailed) as pip_ex: p.run(r()) assert isinstance(pip_ex.value.__context__, 
SchemaException) diff --git a/tests/load/test_parallelism.py b/tests/load/pipeline/test_parallelism.py similarity index 98% rename from tests/load/test_parallelism.py rename to tests/load/pipeline/test_parallelism.py index a1a09a4d6b..656357fb00 100644 --- a/tests/load/test_parallelism.py +++ b/tests/load/pipeline/test_parallelism.py @@ -55,7 +55,7 @@ def t() -> TDataItems: yield {"num": i} # we load n items for 3 tables in one run - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) p.run( [ dlt.resource(table_name="t1")(t), diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index ad44cd6f5c..a12c29168f 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -13,7 +13,8 @@ from dlt.common.destination.reference import WithStagingDataset from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.schema.schema import Schema -from dlt.common.schema.typing import VERSION_TABLE_NAME +from dlt.common.schema.typing import PIPELINE_STATE_TABLE_NAME, VERSION_TABLE_NAME +from dlt.common.schema.utils import pipeline_state_table from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id @@ -26,7 +27,7 @@ PipelineStepFailed, ) -from tests.utils import TEST_STORAGE_ROOT, data_to_item_format, preserve_environ +from tests.utils import TEST_STORAGE_ROOT, data_to_item_format from tests.pipeline.utils import ( assert_data_table_counts, assert_load_info, @@ -40,12 +41,11 @@ TABLE_UPDATE_COLUMNS_SCHEMA, assert_all_data_types_row, delete_dataset, -) -from tests.load.pipeline.utils import ( drop_active_pipeline_data, - REPLACE_STRATEGIES, + destinations_configs, + DestinationTestConfiguration, ) -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.pipeline.utils import REPLACE_STRATEGIES # mark all tests as essential, do 
not remove pytestmark = pytest.mark.essential @@ -137,10 +137,27 @@ def data_fun() -> Iterator[Any]: destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), ids=lambda x: x.name, ) -def test_default_schema_name(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("use_single_dataset", [True, False]) +@pytest.mark.parametrize( + "naming_convention", + [ + "duck_case", + "snake_case", + "sql_cs_v1", + ], +) +def test_default_schema_name( + destination_config: DestinationTestConfiguration, + use_single_dataset: bool, + naming_convention: str, +) -> None: + os.environ["SCHEMA__NAMING"] = naming_convention destination_config.setup() dataset_name = "dataset_" + uniq_id() - data = ["a", "b", "c"] + data = [ + {"id": idx, "CamelInfo": uniq_id(), "GEN_ERIC": alpha} + for idx, alpha in [(0, "A"), (0, "B"), (0, "C")] + ] p = dlt.pipeline( "test_default_schema_name", @@ -149,16 +166,25 @@ def test_default_schema_name(destination_config: DestinationTestConfiguration) - staging=destination_config.staging, dataset_name=dataset_name, ) + p.config.use_single_dataset = use_single_dataset p.extract(data, table_name="test", schema=Schema("default")) p.normalize() info = p.load() + print(info) # try to restore pipeline r_p = dlt.attach("test_default_schema_name", TEST_STORAGE_ROOT) schema = r_p.default_schema assert schema.name == "default" - assert_table(p, "test", data, info=info) + # check if dlt tables have exactly the required schemas + # TODO: uncomment to check dlt tables schemas + # assert ( + # r_p.default_schema.tables[PIPELINE_STATE_TABLE_NAME]["columns"] + # == pipeline_state_table()["columns"] + # ) + + # assert_table(p, "test", data, info=info) @pytest.mark.parametrize( @@ -947,8 +973,7 @@ def table_3(make_data=False): load_table_counts(pipeline, "table_3") assert "x-normalizer" not in pipeline.default_schema.tables["table_3"] assert ( - 
pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] - is True + pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["seen-data"] is True ) # load with one empty job, table 3 not created @@ -990,18 +1015,9 @@ def table_3(make_data=False): # print(v5) # check if seen data is market correctly - assert ( - pipeline.default_schema.tables["table_3"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] - is True - ) - assert ( - pipeline.default_schema.tables["table_2"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] - is True - ) - assert ( - pipeline.default_schema.tables["table_1"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] - is True - ) + assert pipeline.default_schema.tables["table_3"]["x-normalizer"]["seen-data"] is True + assert pipeline.default_schema.tables["table_2"]["x-normalizer"]["seen-data"] is True + assert pipeline.default_schema.tables["table_1"]["x-normalizer"]["seen-data"] is True job_client, _ = pipeline._get_destination_clients(schema) diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index a64ee300cd..a4001b7faa 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -6,45 +6,11 @@ from dlt.common.utils import uniq_id -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration -from tests.cases import arrow_table_all_data_types, prepare_shuffled_tables -from tests.pipeline.utils import assert_data_table_counts, assert_load_info, load_tables_to_dicts +from tests.load.utils import destinations_configs, DestinationTestConfiguration +from tests.pipeline.utils import assert_load_info, load_tables_to_dicts from tests.utils import TestDataItemFormat -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -@pytest.mark.parametrize("item_type", 
["object", "table"]) -def test_postgres_load_csv( - destination_config: DestinationTestConfiguration, item_type: TestDataItemFormat -) -> None: - os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" - pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) - table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() - - # convert to pylist when loading from objects, this will kick the csv-reader in - if item_type == "object": - table, shuffled_table, shuffled_removed_column = ( - table.to_pylist(), - shuffled_table.to_pylist(), - shuffled_removed_column.to_pylist(), - ) - - load_info = pipeline.run( - [shuffled_removed_column, shuffled_table, table], - table_name="table", - loader_file_format="csv", - ) - assert_load_info(load_info) - job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path - assert job.endswith("csv") - assert_data_table_counts(pipeline, {"table": 5432 * 3}) - load_tables_to_dicts(pipeline, "table") - - @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True, subset=["postgres"]), @@ -64,7 +30,7 @@ def test_postgres_encoded_binary( blob_table = blob_table.to_pylist() print(blob_table) - pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) load_info = pipeline.run(blob_table, table_name="table", loader_file_format="csv") assert_load_info(load_info) job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path @@ -76,31 +42,3 @@ def test_postgres_encoded_binary( # print(bytes(data["table"][0]["hash"])) # data in postgres equals unencoded blob assert data["table"][0]["hash"].tobytes() == blob - - -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -def test_postgres_empty_csv_from_arrow(destination_config: 
DestinationTestConfiguration) -> None: - os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" - os.environ["RESTORE_FROM_DESTINATION"] = "False" - pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) - table, _, _ = arrow_table_all_data_types("arrow-table", include_json=False) - - load_info = pipeline.run( - table.schema.empty_table(), table_name="table", loader_file_format="csv" - ) - assert_load_info(load_info) - assert len(load_info.load_packages[0].jobs["completed_jobs"]) == 1 - job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path - assert job.endswith("csv") - assert_data_table_counts(pipeline, {"table": 0}) - with pipeline.sql_client() as client: - with client.execute_query('SELECT * FROM "table"') as cur: - columns = [col.name for col in cur.description] - assert len(cur.fetchall()) == 0 - - # all columns in order - assert columns == list(pipeline.default_schema.get_table_columns("table").keys()) diff --git a/tests/load/pipeline/test_redshift.py b/tests/load/pipeline/test_redshift.py index 29293693f5..bfdc15459c 100644 --- a/tests/load/pipeline/test_redshift.py +++ b/tests/load/pipeline/test_redshift.py @@ -4,7 +4,7 @@ import dlt from dlt.common.utils import uniq_id -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.cases import table_update_and_row, assert_all_data_types_row from tests.pipeline.utils import assert_load_info diff --git a/tests/load/pipeline/test_refresh_modes.py b/tests/load/pipeline/test_refresh_modes.py index 02ed560068..de557ba118 100644 --- a/tests/load/pipeline/test_refresh_modes.py +++ b/tests/load/pipeline/test_refresh_modes.py @@ -8,7 +8,7 @@ from dlt.common.typing import DictStrAny from dlt.common.pipeline import pipeline_state as current_pipeline_state -from tests.utils import clean_test_storage, preserve_environ +from tests.utils import 
clean_test_storage from tests.pipeline.utils import ( assert_load_info, load_tables_to_dicts, diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 464b5aea1f..12bc69abe0 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -4,12 +4,12 @@ from dlt.common.utils import uniq_id from tests.pipeline.utils import assert_load_info, load_table_counts, load_tables_to_dicts -from tests.load.pipeline.utils import ( +from tests.load.utils import ( drop_active_pipeline_data, destinations_configs, DestinationTestConfiguration, - REPLACE_STRATEGIES, ) +from tests.load.pipeline.utils import REPLACE_STRATEGIES @pytest.mark.essential diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index b287619e8c..37f999ff86 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -6,7 +6,9 @@ import dlt from dlt.common import pendulum -from dlt.common.schema.schema import Schema +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.schema.schema import Schema, utils +from dlt.common.schema.utils import normalize_table_identifiers from dlt.common.utils import uniq_id from dlt.common.destination.exceptions import DestinationUndefinedEntity @@ -14,7 +16,6 @@ from dlt.pipeline.exceptions import SqlClientNotAvailable from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.state_sync import ( - STATE_TABLE_COLUMNS, load_pipeline_state_from_destination, state_resource, ) @@ -24,12 +25,12 @@ from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_DECODED from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9, yml_case_path as common_yml_case_path from tests.common.configuration.utils import environment -from tests.load.pipeline.utils import drop_active_pipeline_data from tests.pipeline.utils import assert_query_data from 
tests.load.utils import ( destinations_configs, DestinationTestConfiguration, get_normalized_dataset_name, + drop_active_pipeline_data, ) @@ -77,15 +78,17 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - initial_state["_local"]["_last_extracted_at"] = pendulum.now() initial_state["_local"]["_last_extracted_hash"] = initial_state["_version_hash"] # add _dlt_id and _dlt_load_id - resource, _ = state_resource(initial_state) + resource, _ = state_resource(initial_state, "not_used_load_id") resource.apply_hints( columns={ "_dlt_id": {"name": "_dlt_id", "data_type": "text", "nullable": False}, "_dlt_load_id": {"name": "_dlt_load_id", "data_type": "text", "nullable": False}, - **STATE_TABLE_COLUMNS, + **utils.pipeline_state_table()["columns"], } ) - schema.update_table(schema.normalize_table_identifiers(resource.compute_table_schema())) + schema.update_table( + normalize_table_identifiers(resource.compute_table_schema(), schema.naming) + ) # do not bump version here or in sync_schema, dlt won't recognize that schema changed and it won't update it in storage # so dlt in normalize stage infers _state_version table again but with different column order and the column order in schema is different # then in database. parquet is created in schema order and in Redshift it must exactly match the order. 
@@ -183,6 +186,7 @@ def test_silently_skip_on_invalid_credentials( destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) +@pytest.mark.essential @pytest.mark.parametrize( "destination_config", destinations_configs( @@ -191,13 +195,25 @@ def test_silently_skip_on_invalid_credentials( ids=lambda x: x.name, ) @pytest.mark.parametrize("use_single_dataset", [True, False]) +@pytest.mark.parametrize( + "naming_convention", + [ + "tests.common.cases.normalizers.title_case", + "snake_case", + ], +) def test_get_schemas_from_destination( - destination_config: DestinationTestConfiguration, use_single_dataset: bool + destination_config: DestinationTestConfiguration, + use_single_dataset: bool, + naming_convention: str, ) -> None: + set_naming_env(destination_config.destination, naming_convention) + pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + assert_naming_to_caps(destination_config.destination, p.destination.capabilities()) p.config.use_single_dataset = use_single_dataset def _make_dn_name(schema_name: str) -> str: @@ -268,18 +284,34 @@ def _make_dn_name(schema_name: str) -> str: assert len(restored_schemas) == 3 +@pytest.mark.essential @pytest.mark.parametrize( "destination_config", destinations_configs( - default_sql_configs=True, default_vector_configs=True, all_buckets_filesystem_configs=True + default_sql_configs=True, + all_staging_configs=True, + default_vector_configs=True, + all_buckets_filesystem_configs=True, ), ids=lambda x: x.name, ) -def test_restore_state_pipeline(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize( + "naming_convention", + [ + "tests.common.cases.normalizers.title_case", + "snake_case", + ], +) +def test_restore_state_pipeline( + destination_config: DestinationTestConfiguration, naming_convention: str +) -> None: + 
set_naming_env(destination_config.destination, naming_convention) + # enable restoring from destination os.environ["RESTORE_FROM_DESTINATION"] = "True" pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + assert_naming_to_caps(destination_config.destination, p.destination.capabilities()) def some_data_gen(param: str) -> Any: dlt.current.source_state()[param] = param @@ -366,7 +398,7 @@ def some_data(): p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) # now attach locally os.environ["RESTORE_FROM_DESTINATION"] = "True" - p = dlt.attach(pipeline_name=pipeline_name) + p = destination_config.attach_pipeline(pipeline_name=pipeline_name) assert p.dataset_name == dataset_name assert p.default_schema_name is None # restore @@ -451,6 +483,9 @@ def test_restore_schemas_while_import_schemas_exist( # make sure schema got imported schema = p.schemas["ethereum"] assert "blocks" in schema.tables + # allow to modify tables even if naming convention is changed. some of the tables in ethereum schema + # have processing hints that lock the table schema. 
so when weaviate changes naming convention we have an exception + os.environ["SCHEMA__ALLOW_IDENTIFIER_CHANGE_ON_TABLE_WITH_DATA"] = "true" # extract some additional data to upgrade schema in the pipeline p.run( @@ -467,7 +502,7 @@ def test_restore_schemas_while_import_schemas_exist( assert normalized_labels in schema.tables # re-attach the pipeline - p = dlt.attach(pipeline_name=pipeline_name) + p = destination_config.attach_pipeline(pipeline_name=pipeline_name) p.run( ["C", "D", "E"], table_name="annotations", loader_file_format=destination_config.file_format ) @@ -496,7 +531,7 @@ def test_restore_schemas_while_import_schemas_exist( assert normalized_annotations in schema.tables # check if attached to import schema - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() # extract some data with restored pipeline p.run( ["C", "D", "E"], table_name="blacklist", loader_file_format=destination_config.file_format @@ -604,7 +639,9 @@ def some_data(param: str) -> Any: prod_state = production_p.state assert p.state["_state_version"] == prod_state["_state_version"] - 1 # re-attach production and sync - ra_production_p = dlt.attach(pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT) + ra_production_p = destination_config.attach_pipeline( + pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT + ) ra_production_p.sync_destination() # state didn't change because production is ahead of local with its version # nevertheless this is potentially dangerous situation 🤷 @@ -613,10 +650,18 @@ def some_data(param: str) -> Any: # get all the states, notice version 4 twice (one from production, the other from local) try: with p.sql_client() as client: + # use sql_client to escape identifiers properly state_table = client.make_qualified_table_name(p.default_schema.state_table_name) - + c_version = client.escape_column_name( + p.default_schema.naming.normalize_identifier("version") + 
) + c_created_at = client.escape_column_name( + p.default_schema.naming.normalize_identifier("created_at") + ) assert_query_data( - p, f"SELECT version FROM {state_table} ORDER BY created_at DESC", [5, 4, 4, 3, 2] + p, + f"SELECT {c_version} FROM {state_table} ORDER BY {c_created_at} DESC", + [5, 4, 4, 3, 2], ) except SqlClientNotAvailable: pytest.skip(f"destination {destination_config.destination} does not support sql client") @@ -669,7 +714,7 @@ def some_data(param: str) -> Any: assert p.dataset_name == dataset_name print("---> no state sync last attach") - p = dlt.attach(pipeline_name=pipeline_name) + p = destination_config.attach_pipeline(pipeline_name=pipeline_name) # this will prevent from creating of _dlt_pipeline_state p.config.restore_from_destination = False data4 = some_data("state4") @@ -686,7 +731,7 @@ def some_data(param: str) -> Any: assert p.state["_local"]["first_run"] is False # attach again to make the `run` method check the destination print("---> last attach") - p = dlt.attach(pipeline_name=pipeline_name) + p = destination_config.attach_pipeline(pipeline_name=pipeline_name) p.config.restore_from_destination = True data5 = some_data("state4") data5.apply_hints(table_name="state1_data5") @@ -696,8 +741,31 @@ def some_data(param: str) -> Any: def prepare_import_folder(p: Pipeline) -> None: - os.makedirs(p._schema_storage.config.import_schema_path, exist_ok=True) - shutil.copy( - common_yml_case_path("schemas/eth/ethereum_schema_v5"), - os.path.join(p._schema_storage.config.import_schema_path, "ethereum.schema.yaml"), - ) + from tests.common.storages.utils import prepare_eth_import_folder + + prepare_eth_import_folder(p._schema_storage) + + +def set_naming_env(destination: str, naming_convention: str) -> None: + # snake case is for default convention so do not set it + if naming_convention != "snake_case": + # path convention to test weaviate ci_naming + if destination == "weaviate": + if naming_convention.endswith("sql_upper"): + 
pytest.skip(f"{naming_convention} not supported on weaviate") + else: + naming_convention = "dlt.destinations.impl.weaviate.ci_naming" + os.environ["SCHEMA__NAMING"] = naming_convention + + +def assert_naming_to_caps(destination: str, caps: DestinationCapabilitiesContext) -> None: + naming = Schema("test").naming + if ( + not caps.has_case_sensitive_identifiers + and caps.casefold_identifier is not str + and naming.is_case_sensitive + ): + pytest.skip( + f"Skipping for case insensitive destination {destination} with case folding because" + f" naming {naming.name()} is case sensitive" + ) diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py index e8baa33ff3..b33c5a2590 100644 --- a/tests/load/pipeline/test_scd2.py +++ b/tests/load/pipeline/test_scd2.py @@ -17,12 +17,11 @@ from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import arrow_table_all_data_types -from tests.pipeline.utils import assert_load_info, load_table_counts -from tests.load.pipeline.utils import ( +from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, ) -from tests.pipeline.utils import load_tables_to_dicts +from tests.pipeline.utils import load_tables_to_dicts, assert_load_info, load_table_counts from tests.utils import TPythonTableFormat @@ -104,7 +103,7 @@ def test_core_functionality( validity_column_names: List[str], active_record_timestamp: Optional[pendulum.DateTime], ) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", @@ -243,7 +242,7 @@ def r(data): ) @pytest.mark.parametrize("simple", [True, False]) def test_child_table(destination_config: DestinationTestConfiguration, simple: bool) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", 
write_disposition={"disposition": "merge", "strategy": "scd2"} @@ -386,7 +385,7 @@ def r(data): ids=lambda x: x.name, ) def test_grandchild_table(destination_config: DestinationTestConfiguration) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", write_disposition={"disposition": "merge", "strategy": "scd2"} @@ -479,7 +478,7 @@ def r(data): ids=lambda x: x.name, ) def test_validity_column_name_conflict(destination_config: DestinationTestConfiguration) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", @@ -525,7 +524,7 @@ def test_active_record_timestamp( destination_config: DestinationTestConfiguration, active_record_timestamp: Optional[TAnyDateTime], ) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", @@ -572,7 +571,7 @@ def _make_scd2_r(table_: Any) -> DltResource: }, ).add_map(add_row_hash_to_table("row_hash")) - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) info = p.run(_make_scd2_r(table), loader_file_format=destination_config.file_format) assert_load_info(info) # make sure we have scd2 columns in schema @@ -608,7 +607,7 @@ def _make_scd2_r(table_: Any) -> DltResource: ids=lambda x: x.name, ) def test_user_provided_row_hash(destination_config: DestinationTestConfiguration) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", diff --git a/tests/load/pipeline/test_snowflake_pipeline.py b/tests/load/pipeline/test_snowflake_pipeline.py new file mode 
100644 index 0000000000..3cfa9e8b21 --- /dev/null +++ b/tests/load/pipeline/test_snowflake_pipeline.py @@ -0,0 +1,55 @@ +import pytest + +import dlt +from dlt.common import Decimal + +from dlt.common.utils import uniq_id +from dlt.destinations.exceptions import DatabaseUndefinedRelation +from tests.pipeline.utils import assert_load_info +from tests.load.utils import destinations_configs, DestinationTestConfiguration + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["snowflake"]), + ids=lambda x: x.name, +) +def test_snowflake_case_sensitive_identifiers( + destination_config: DestinationTestConfiguration, +) -> None: + snow_ = dlt.destinations.snowflake(naming_convention="sql_cs_v1") + + dataset_name = "CaseSensitive_Dataset_" + uniq_id() + pipeline = destination_config.setup_pipeline( + "test_snowflake_case_sensitive_identifiers", dataset_name=dataset_name, destination=snow_ + ) + caps = pipeline.destination.capabilities() + assert caps.naming_convention == "sql_cs_v1" + + destination_client = pipeline.destination_client() + # assert snowflake caps to be in case sensitive mode + assert destination_client.capabilities.casefold_identifier is str + + # load some case sensitive data + info = pipeline.run([{"Id": 1, "Capital": 0.0}], table_name="Expenses") + assert_load_info(info) + with pipeline.sql_client() as client: + assert client.has_dataset() + # use the same case sensitive dataset + with client.with_alternative_dataset_name(dataset_name): + assert client.has_dataset() + # make it case insensitive (upper) + with client.with_alternative_dataset_name(dataset_name.upper()): + assert not client.has_dataset() + # keep case sensitive but make lowercase + with client.with_alternative_dataset_name(dataset_name.lower()): + assert not client.has_dataset() + + # must use quoted identifiers + rows = client.execute_sql('SELECT "Id", 
"Capital" FROM "Expenses"') + print(rows) + with pytest.raises(DatabaseUndefinedRelation): + client.execute_sql('SELECT "Id", "Capital" FROM Expenses') diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index e0e2154b57..7f1427f20f 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -8,15 +8,13 @@ from dlt.common.schema.typing import TDataType from tests.load.pipeline.test_merge_disposition import github -from tests.pipeline.utils import load_table_counts -from tests.pipeline.utils import assert_load_info +from tests.pipeline.utils import load_table_counts, assert_load_info from tests.load.utils import ( - TABLE_ROW_ALL_DATA_TYPES, - TABLE_UPDATE_COLUMNS_SCHEMA, + destinations_configs, + DestinationTestConfiguration, assert_all_data_types_row, ) from tests.cases import table_update_and_row -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration @dlt.resource( @@ -65,12 +63,17 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: ) == 4 ) + # pipeline state is loaded with preferred format, so allows (possibly) for two job formats + caps = pipeline.destination.capabilities() + # NOTE: preferred_staging_file_format goes first because here we test staged loading and + # default caps will be modified so preferred_staging_file_format is used as main + preferred_format = caps.preferred_staging_file_format or caps.preferred_loader_file_format assert ( len( [ x for x in package_info.jobs["completed_jobs"] - if x.job_file_info.file_format == destination_config.file_format + if x.job_file_info.file_format in (destination_config.file_format, preferred_format) ] ) == 4 diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py index 16c589352e..ba2f6bf172 100644 --- a/tests/load/pipeline/test_write_disposition_changes.py +++ 
b/tests/load/pipeline/test_write_disposition_changes.py @@ -1,7 +1,7 @@ import pytest import dlt from typing import Any -from tests.load.pipeline.utils import ( +from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, ) @@ -124,9 +124,13 @@ def source(): ) # schemaless destinations allow adding of root key without the pipeline failing - # for now this is only the case for dremio + # they do not mind adding NOT NULL columns to tables with existing data (if NOT NULL is supported at all) # doing this will result in somewhat useless behavior - destination_allows_adding_root_key = destination_config.destination in ["dremio", "clickhouse"] + destination_allows_adding_root_key = destination_config.destination in [ + "dremio", + "clickhouse", + "athena", + ] if destination_allows_adding_root_key and not with_root_key: pipeline.run( diff --git a/tests/load/pipeline/utils.py b/tests/load/pipeline/utils.py index d762029ddd..679c2d6da9 100644 --- a/tests/load/pipeline/utils.py +++ b/tests/load/pipeline/utils.py @@ -1,67 +1 @@ -from typing import Any, Iterator, List, Sequence, TYPE_CHECKING, Callable -import pytest - -import dlt -from dlt.common.destination.reference import WithStagingDataset - -from dlt.common.configuration.container import Container -from dlt.common.pipeline import LoadInfo, PipelineContext - -from tests.load.utils import DestinationTestConfiguration, destinations_configs -from dlt.destinations.exceptions import CantExtractTablePrefix - -if TYPE_CHECKING: - from dlt.destinations.impl.filesystem.filesystem import FilesystemClient - REPLACE_STRATEGIES = ["truncate-and-insert", "insert-from-staging", "staging-optimized"] - - -@pytest.fixture(autouse=True) -def drop_pipeline(request) -> Iterator[None]: - yield - if "no_load" in request.keywords: - return - try: - drop_active_pipeline_data() - except CantExtractTablePrefix: - # for some tests we test that this exception is raised, - # so we suppress it here - pass - - -def 
drop_active_pipeline_data() -> None: - """Drops all the datasets for currently active pipeline, wipes the working folder and then deactivated it.""" - if Container()[PipelineContext].is_active(): - # take existing pipeline - p = dlt.pipeline() - - def _drop_dataset(schema_name: str) -> None: - with p.destination_client(schema_name) as client: - try: - client.drop_storage() - print("dropped") - except Exception as exc: - print(exc) - if isinstance(client, WithStagingDataset): - with client.with_staging_dataset(): - try: - client.drop_storage() - print("staging dropped") - except Exception as exc: - print(exc) - - # drop_func = _drop_dataset_fs if _is_filesystem(p) else _drop_dataset_sql - # take all schemas and if destination was set - if p.destination: - if p.config.use_single_dataset: - # drop just the dataset for default schema - if p.default_schema_name: - _drop_dataset(p.default_schema_name) - else: - # for each schema, drop the dataset - for schema_name in p.schema_names: - _drop_dataset(schema_name) - - # p._wipe_working_folder() - # deactivate context - Container()[PipelineContext].deactivate() diff --git a/tests/load/postgres/test_postgres_client.py b/tests/load/postgres/test_postgres_client.py index a0fbd85b5b..d8cd996dcf 100644 --- a/tests/load/postgres/test_postgres_client.py +++ b/tests/load/postgres/test_postgres_client.py @@ -11,7 +11,7 @@ from dlt.destinations.impl.postgres.postgres import PostgresClient from dlt.destinations.impl.postgres.sql_client import psycopg2 -from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, skipifpypy, preserve_environ +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, skipifpypy from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage from tests.common.configuration.utils import environment diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index 7566b8afce..5ba68be67c 100644 --- 
a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -4,8 +4,9 @@ from dlt.common.exceptions import TerminalValueError from dlt.common.utils import uniq_id -from dlt.common.schema import Schema +from dlt.common.schema import Schema, utils +from dlt.destinations import postgres from dlt.destinations.impl.postgres.postgres import PostgresClient from dlt.destinations.impl.postgres.configuration import ( PostgresClientConfiguration, @@ -25,13 +26,23 @@ @pytest.fixture def client(empty_schema: Schema) -> PostgresClient: + return create_client(empty_schema) + + +@pytest.fixture +def cs_client(empty_schema: Schema) -> PostgresClient: + # change normalizer to case sensitive + empty_schema._normalizers_config["names"] = "tests.common.cases.normalizers.title_case" + empty_schema.update_normalizers() + return create_client(empty_schema) + + +def create_client(empty_schema: Schema) -> PostgresClient: # return client without opening connection - return PostgresClient( - empty_schema, - PostgresClientConfiguration(credentials=PostgresCredentials())._bind_dataset_name( - dataset_name="test_" + uniq_id() - ), + config = PostgresClientConfiguration(credentials=PostgresCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() ) + return postgres().client(empty_schema, config) def test_create_table(client: PostgresClient) -> None: @@ -102,7 +113,7 @@ def test_alter_table(client: PostgresClient) -> None: assert '"col11_precision" time (3) without time zone NOT NULL' in sql -def test_create_table_with_hints(client: PostgresClient) -> None: +def test_create_table_with_hints(client: PostgresClient, empty_schema: Schema) -> None: mod_update = deepcopy(TABLE_UPDATE) # timestamp mod_update[0]["primary_key"] = True @@ -119,8 +130,8 @@ def test_create_table_with_hints(client: PostgresClient) -> None: assert '"col4" timestamp with time zone NOT NULL' in sql # same thing without indexes - client = PostgresClient( - 
client.schema, + client = postgres().client( + empty_schema, PostgresClientConfiguration( create_indexes=False, credentials=PostgresCredentials(), @@ -129,3 +140,20 @@ def test_create_table_with_hints(client: PostgresClient) -> None: sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql, dialect="postgres") assert '"col2" double precision NOT NULL' in sql + + +def test_create_table_case_sensitive(cs_client: PostgresClient) -> None: + cs_client.schema.update_table( + utils.new_table("event_test_table", columns=deepcopy(TABLE_UPDATE)) + ) + sql = cs_client._get_table_update_sql( + "Event_test_tablE", + list(cs_client.schema.get_table_columns("Event_test_tablE").values()), + False, + )[0] + sqlfluff.parse(sql, dialect="postgres") + # everything capitalized + assert cs_client.sql_client.fully_qualified_dataset_name(escape=False)[0] == "T" # Test + # every line starts with "Col" + for line in sql.split("\n")[1:]: + assert line.startswith('"Col') diff --git a/tests/load/qdrant/test_pipeline.py b/tests/load/qdrant/test_pipeline.py index d50b50282a..e0cb9dab84 100644 --- a/tests/load/qdrant/test_pipeline.py +++ b/tests/load/qdrant/test_pipeline.py @@ -5,6 +5,7 @@ from dlt.common import json from dlt.common.utils import uniq_id +from dlt.destinations.adapters import qdrant_adapter from dlt.destinations.impl.qdrant.qdrant_adapter import qdrant_adapter, VECTORIZE_HINT from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient from tests.pipeline.utils import assert_load_info @@ -68,6 +69,8 @@ def some_data(): assert schema state = client.get_stored_state("test_pipeline_append") assert state + state = client.get_stored_state("unknown_pipeline") + assert state is None def test_pipeline_append() -> None: @@ -316,8 +319,8 @@ def test_merge_github_nested() -> None: primary_key="id", ) assert_load_info(info) + # assert if schema contains tables with right names - print(p.default_schema.tables.keys()) assert 
set(p.default_schema.tables.keys()) == { "_dlt_version", "_dlt_loads", diff --git a/tests/load/qdrant/utils.py b/tests/load/qdrant/utils.py index 74d5db9715..3b12d15f86 100644 --- a/tests/load/qdrant/utils.py +++ b/tests/load/qdrant/utils.py @@ -20,16 +20,16 @@ def assert_collection( expected_items_count: int = None, items: List[Any] = None, ) -> None: - client: QdrantClient = pipeline.destination_client() # type: ignore[assignment] + client: QdrantClient + with pipeline.destination_client() as client: # type: ignore[assignment] + # Check if collection exists + exists = client._collection_exists(collection_name) + assert exists - # Check if collection exists - exists = client._collection_exists(collection_name) - assert exists - - qualified_collection_name = client._make_qualified_collection_name(collection_name) - point_records, offset = client.db_client.scroll( - qualified_collection_name, with_payload=True, limit=50 - ) + qualified_collection_name = client._make_qualified_collection_name(collection_name) + point_records, offset = client.db_client.scroll( + qualified_collection_name, with_payload=True, limit=50 + ) if expected_items_count is not None: assert expected_items_count == len(point_records) @@ -55,10 +55,11 @@ def has_collections(client): if Container()[PipelineContext].is_active(): # take existing pipeline p = dlt.pipeline() - client: QdrantClient = p.destination_client() # type: ignore[assignment] + client: QdrantClient - if has_collections(client): - client.drop_storage() + with p.destination_client() as client: # type: ignore[assignment] + if has_collections(client): + client.drop_storage() p._wipe_working_folder() # deactivate context diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index 03bb57c3b4..bb923df673 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -6,13 +6,18 @@ from dlt.common import json, pendulum from 
dlt.common.configuration.resolve import resolve_configuration +from dlt.common.schema.schema import Schema from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.storages import FileStorage from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.utils import uniq_id from dlt.destinations.exceptions import DatabaseTerminalException -from dlt.destinations.impl.redshift.configuration import RedshiftCredentials +from dlt.destinations import redshift +from dlt.destinations.impl.redshift.configuration import ( + RedshiftCredentials, + RedshiftClientConfiguration, +) from dlt.destinations.impl.redshift.redshift import RedshiftClient, psycopg2 from tests.common.utils import COMMON_TEST_CASES_PATH @@ -42,6 +47,34 @@ def test_postgres_and_redshift_credentials_defaults() -> None: assert red_cred.port == 5439 +def test_redshift_factory() -> None: + schema = Schema("schema") + dest = redshift() + client = dest.client(schema, RedshiftClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.staging_iam_role is None + assert client.config.has_case_sensitive_identifiers is False + assert client.capabilities.has_case_sensitive_identifiers is False + assert client.capabilities.casefold_identifier is str.lower + + # set args explicitly + dest = redshift(has_case_sensitive_identifiers=True, staging_iam_role="LOADER") + client = dest.client(schema, RedshiftClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.staging_iam_role == "LOADER" + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + # set args via config + os.environ["DESTINATION__STAGING_IAM_ROLE"] = "LOADER" + os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"] = "True" + dest = redshift() + client = dest.client(schema, RedshiftClientConfiguration()._bind_dataset_name("dataset")) + assert 
client.config.staging_iam_role == "LOADER" + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + @skipifpypy def test_text_too_long(client: RedshiftClient, file_storage: FileStorage) -> None: caps = client.capabilities diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index 2427bc7cfe..de6f450134 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -6,6 +6,7 @@ from dlt.common.schema import Schema from dlt.common.configuration import resolve_configuration +from dlt.destinations import redshift from dlt.destinations.impl.redshift.redshift import RedshiftClient from dlt.destinations.impl.redshift.configuration import ( RedshiftClientConfiguration, @@ -21,7 +22,7 @@ @pytest.fixture def client(empty_schema: Schema) -> RedshiftClient: # return client without opening connection - return RedshiftClient( + return redshift().client( empty_schema, RedshiftClientConfiguration(credentials=RedshiftCredentials())._bind_dataset_name( dataset_name="test_" + uniq_id() diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index 691f0b5a64..10d93d104c 100644 --- a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -121,10 +121,10 @@ def test_only_authenticator() -> None: } -def test_no_query(environment) -> None: - c = SnowflakeCredentials("snowflake://user1:pass1@host1/db1") - assert str(c.to_url()) == "snowflake://user1:pass1@host1/db1" - print(c.to_url()) +# def test_no_query(environment) -> None: +# c = SnowflakeCredentials("snowflake://user1:pass1@host1/db1") +# assert str(c.to_url()) == "snowflake://user1:pass1@host1/db1" +# print(c.to_url()) def test_query_additional_params() -> 
None: diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index bdbe888fb5..4bb69085da 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -5,12 +5,12 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema +from dlt.destinations import snowflake from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient from dlt.destinations.impl.snowflake.configuration import ( SnowflakeClientConfiguration, SnowflakeCredentials, ) -from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate from tests.load.utils import TABLE_UPDATE, empty_schema @@ -22,7 +22,7 @@ def snowflake_client(empty_schema: Schema) -> SnowflakeClient: # return client without opening connection creds = SnowflakeCredentials() - return SnowflakeClient( + return snowflake().client( empty_schema, SnowflakeClientConfiguration(credentials=creds)._bind_dataset_name( dataset_name="test_" + uniq_id() diff --git a/tests/load/synapse/test_synapse_configuration.py b/tests/load/synapse/test_synapse_configuration.py index f366d87d09..8aaea03b0f 100644 --- a/tests/load/synapse/test_synapse_configuration.py +++ b/tests/load/synapse/test_synapse_configuration.py @@ -1,8 +1,11 @@ +import os import pytest from dlt.common.configuration import resolve_configuration from dlt.common.exceptions import SystemConfigurationException +from dlt.common.schema import Schema +from dlt.destinations import synapse from dlt.destinations.impl.synapse.configuration import ( SynapseClientConfiguration, SynapseCredentials, @@ -14,7 +17,42 @@ def test_synapse_configuration() -> None: # By default, unique indexes should not be created. 
- assert SynapseClientConfiguration().create_indexes is False + c = SynapseClientConfiguration() + assert c.create_indexes is False + assert c.has_case_sensitive_identifiers is False + assert c.staging_use_msi is False + + +def test_synapse_factory() -> None: + schema = Schema("schema") + dest = synapse() + client = dest.client(schema, SynapseClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is False + assert client.config.staging_use_msi is False + assert client.config.has_case_sensitive_identifiers is False + assert client.capabilities.has_case_sensitive_identifiers is False + assert client.capabilities.casefold_identifier is str + + # set args explicitly + dest = synapse(has_case_sensitive_identifiers=True, create_indexes=True, staging_use_msi=True) + client = dest.client(schema, SynapseClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is True + assert client.config.staging_use_msi is True + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + # set args via config + os.environ["DESTINATION__CREATE_INDEXES"] = "True" + os.environ["DESTINATION__STAGING_USE_MSI"] = "True" + os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"] = "True" + dest = synapse() + client = dest.client(schema, SynapseClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is True + assert client.config.staging_use_msi is True + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str def test_parse_native_representation() -> None: diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py index 9ee2ebe202..1a92a20f1e 100644 --- 
a/tests/load/synapse/test_synapse_table_builder.py +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -7,17 +7,18 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema, TColumnHint -from dlt.destinations.impl.synapse.synapse import SynapseClient +from dlt.destinations import synapse +from dlt.destinations.impl.synapse.synapse import ( + SynapseClient, + HINT_TO_SYNAPSE_ATTR, + TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR, +) from dlt.destinations.impl.synapse.configuration import ( SynapseClientConfiguration, SynapseCredentials, ) from tests.load.utils import TABLE_UPDATE, empty_schema -from dlt.destinations.impl.synapse.synapse import ( - HINT_TO_SYNAPSE_ATTR, - TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR, -) # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -26,7 +27,7 @@ @pytest.fixture def client(empty_schema: Schema) -> SynapseClient: # return client without opening connection - client = SynapseClient( + client = synapse().client( empty_schema, SynapseClientConfiguration(credentials=SynapseCredentials())._bind_dataset_name( dataset_name="test_" + uniq_id() @@ -39,7 +40,7 @@ def client(empty_schema: Schema) -> SynapseClient: @pytest.fixture def client_with_indexes_enabled(empty_schema: Schema) -> SynapseClient: # return client without opening connection - client = SynapseClient( + client = synapse().client( empty_schema, SynapseClientConfiguration( credentials=SynapseCredentials(), create_indexes=True diff --git a/tests/load/synapse/test_synapse_table_indexing.py b/tests/load/synapse/test_synapse_table_indexing.py index a9d426ad4a..d877b769cc 100644 --- a/tests/load/synapse/test_synapse_table_indexing.py +++ b/tests/load/synapse/test_synapse_table_indexing.py @@ -1,20 +1,14 @@ import os import pytest from typing import Iterator, List, Any, Union -from textwrap import dedent import dlt from dlt.common.schema import TColumnSchema -from dlt.destinations.sql_client import SqlClientBase - -from dlt.destinations.impl.synapse 
import synapse_adapter +from dlt.destinations.adapters import synapse_adapter from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType from tests.load.utils import TABLE_UPDATE, TABLE_ROW_ALL_DATA_TYPES -from tests.load.pipeline.utils import ( - drop_pipeline, -) # this import ensures all test data gets removed from tests.load.synapse.utils import get_storage_table_index_type # mark all tests as essential, do not remove diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 30de51f069..be917672f1 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -4,11 +4,11 @@ from unittest import mock import pytest from unittest.mock import patch -from typing import List +from typing import List, Tuple from dlt.common.exceptions import TerminalException, TerminalValueError from dlt.common.storages import FileStorage, PackageStorage, ParsedLoadJobFileName -from dlt.common.storages.load_package import LoadJobInfo +from dlt.common.storages.load_package import LoadJobInfo, TJobState from dlt.common.storages.load_storage import JobFileFormatUnsupported from dlt.common.destination.reference import LoadJob, TDestination from dlt.common.schema.utils import ( @@ -31,7 +31,6 @@ clean_test_storage, init_test_logging, TEST_DICT_CONFIG_PROVIDER, - preserve_environ, ) from tests.load.utils import prepare_load_package from tests.utils import skip_if_not_active, TEST_STORAGE_ROOT @@ -97,15 +96,11 @@ def test_unsupported_write_disposition() -> None: with ThreadPoolExecutor() as pool: load.run(pool) # job with unsupported write disp. 
is failed - exception_file = [ - f - for f in load.load_storage.normalized_packages.list_failed_jobs(load_id) - if f.endswith(".exception") - ][0] - assert ( - "LoadClientUnsupportedWriteDisposition" - in load.load_storage.normalized_packages.storage.load(exception_file) + failed_job = load.load_storage.normalized_packages.list_failed_jobs(load_id)[0] + failed_message = load.load_storage.normalized_packages.get_job_failed_message( + load_id, ParsedLoadJobFileName.parse(failed_job) ) + assert "LoadClientUnsupportedWriteDisposition" in failed_message def test_get_new_jobs_info() -> None: @@ -125,7 +120,7 @@ def test_get_completed_table_chain_single_job_per_table() -> None: schema.tables[table_name] = fill_hints_from_parent_and_clone_table(schema.tables, table) top_job_table = get_top_level_table(schema.tables, "event_user") - all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + all_jobs = load.load_storage.normalized_packages.list_all_jobs_with_states(load_id) assert get_completed_table_chain(schema, all_jobs, top_job_table) is None # fake being completed assert ( @@ -144,12 +139,12 @@ def test_get_completed_table_chain_single_job_per_table() -> None: load.load_storage.normalized_packages.start_job( load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) - all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + all_jobs = load.load_storage.normalized_packages.list_all_jobs_with_states(load_id) assert get_completed_table_chain(schema, all_jobs, loop_top_job_table) is None load.load_storage.normalized_packages.complete_job( load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) - all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + all_jobs = load.load_storage.normalized_packages.list_all_jobs_with_states(load_id) assert get_completed_table_chain(schema, all_jobs, loop_top_job_table) == [ schema.get_table("event_loop_interrupted") ] @@ -485,9 +480,7 @@ def 
test_extend_table_chain() -> None: # no jobs for bot assert _extend_tables_with_table_chain(schema, ["event_bot"], ["event_user"]) == set() # skip unseen tables - del schema.tables["event_user__parse_data__entities"][ # type:ignore[typeddict-item] - "x-normalizer" - ] + del schema.tables["event_user__parse_data__entities"]["x-normalizer"] entities_chain = { name for name in schema.data_table_names() @@ -533,25 +526,15 @@ def test_get_completed_table_chain_cases() -> None: # child completed, parent not event_user = schema.get_table("event_user") event_user_entities = schema.get_table("event_user__parse_data__entities") - event_user_job = LoadJobInfo( + event_user_job: Tuple[TJobState, ParsedLoadJobFileName] = ( "started_jobs", - "path", - 0, - None, - 0, ParsedLoadJobFileName("event_user", "event_user_id", 0, "jsonl"), - None, ) - event_user_entities_job = LoadJobInfo( + event_user_entities_job: Tuple[TJobState, ParsedLoadJobFileName] = ( "completed_jobs", - "path", - 0, - None, - 0, ParsedLoadJobFileName( "event_user__parse_data__entities", "event_user__parse_data__entities_id", 0, "jsonl" ), - None, ) chain = get_completed_table_chain(schema, [event_user_job, event_user_entities_job], event_user) assert chain is None @@ -561,24 +544,21 @@ def test_get_completed_table_chain_cases() -> None: schema, [event_user_job, event_user_entities_job], event_user, - event_user_job.job_file_info.job_id(), + event_user_job[1].job_id(), ) # full chain assert chain == [event_user, event_user_entities] # parent failed, child completed chain = get_completed_table_chain( - schema, [event_user_job._replace(state="failed_jobs"), event_user_entities_job], event_user + schema, [("failed_jobs", event_user_job[1]), event_user_entities_job], event_user ) assert chain == [event_user, event_user_entities] # both failed chain = get_completed_table_chain( schema, - [ - event_user_job._replace(state="failed_jobs"), - event_user_entities_job._replace(state="failed_jobs"), - ], + [("failed_jobs", 
event_user_job[1]), ("failed_jobs", event_user_entities_job[1])], event_user, ) assert chain == [event_user, event_user_entities] @@ -589,16 +569,16 @@ def test_get_completed_table_chain_cases() -> None: event_user["write_disposition"] = w_d # type:ignore[typeddict-item] chain = get_completed_table_chain( - schema, [event_user_job], event_user, event_user_job.job_file_info.job_id() + schema, [event_user_job], event_user, event_user_job[1].job_id() ) assert chain == user_chain # but if child is present and incomplete... chain = get_completed_table_chain( schema, - [event_user_job, event_user_entities_job._replace(state="new_jobs")], + [event_user_job, ("new_jobs", event_user_entities_job[1])], event_user, - event_user_job.job_file_info.job_id(), + event_user_job[1].job_id(), ) # noting is returned assert chain is None @@ -607,9 +587,9 @@ def test_get_completed_table_chain_cases() -> None: deep_child = schema.tables[ "event_user__parse_data__response_selector__default__response__response_templates" ] - del deep_child["x-normalizer"] # type:ignore[typeddict-item] + del deep_child["x-normalizer"] chain = get_completed_table_chain( - schema, [event_user_job], event_user, event_user_job.job_file_info.job_id() + schema, [event_user_job], event_user, event_user_job[1].job_id() ) user_chain.remove(deep_child) assert chain == user_chain @@ -784,7 +764,7 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No assert not load.load_storage.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) ) - completed_path = load.load_storage.loaded_packages.get_job_folder_path( + completed_path = load.load_storage.loaded_packages.get_job_state_folder_path( load_id, "completed_jobs" ) if should_delete_completed: diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 1c035f7f68..38155a8b09 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -11,10 +11,14 @@ 
from dlt.destinations.insert_job_client import InsertValuesJobClient from tests.utils import TEST_STORAGE_ROOT, skipifpypy -from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage -from tests.load.pipeline.utils import destinations_configs +from tests.load.utils import ( + expect_load_file, + prepare_table, + yield_client_with_storage, + destinations_configs, +) -DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "mssql", "synapse"] +DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "mssql", "synapse", "motherduck"] @pytest.fixture @@ -176,7 +180,6 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage ids=lambda x: x.name, ) def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -> None: - mocked_caps = client.sql_client.__class__.capabilities writer_type = client.capabilities.insert_values_writer_type insert_sql = prepare_insert_statement(10, writer_type) @@ -185,10 +188,10 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - elif writer_type == "select_union": pre, post, sep = ("SELECT ", "", " UNION ALL\n") + # caps are instance and are attr of sql client instance so it is safe to mock them + client.sql_client.capabilities.max_query_length = 2 # this guarantees that we execute inserts line by line - with patch.object(mocked_caps, "max_query_length", 2), patch.object( - client.sql_client, "execute_fragments" - ) as mocked_fragments: + with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # print(mocked_fragments.mock_calls) @@ -211,9 +214,8 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - # set query length so it reads data until separator ("," or " UNION ALL") (followed by \n) query_length = (idx - start_idx - 1) * 2 - with patch.object(mocked_caps, "max_query_length", 
query_length), patch.object( - client.sql_client, "execute_fragments" - ) as mocked_fragments: + client.sql_client.capabilities.max_query_length = query_length + with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # split in 2 on ',' @@ -221,9 +223,8 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - # so it reads until "\n" query_length = (idx - start_idx) * 2 - with patch.object(mocked_caps, "max_query_length", query_length), patch.object( - client.sql_client, "execute_fragments" - ) as mocked_fragments: + client.sql_client.capabilities.max_query_length = query_length + with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # split in 2 on separator ("," or " UNION ALL") @@ -235,9 +236,8 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - elif writer_type == "select_union": offset = 1 query_length = (len(insert_sql) - start_idx - offset) * 2 - with patch.object(mocked_caps, "max_query_length", query_length), patch.object( - client.sql_client, "execute_fragments" - ) as mocked_fragments: + client.sql_client.capabilities.max_query_length = query_length + with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # split in 2 on ',' @@ -251,22 +251,21 @@ def assert_load_with_max_query( max_query_length: int, ) -> None: # load and check for real - mocked_caps = client.sql_client.__class__.capabilities - with patch.object(mocked_caps, "max_query_length", max_query_length): - user_table_name = prepare_table(client) - insert_sql = prepare_insert_statement( - insert_lines, client.capabilities.insert_values_writer_type - 
) - expect_load_file(client, file_storage, insert_sql, user_table_name) - canonical_name = client.sql_client.make_qualified_table_name(user_table_name) - rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] - assert rows_count == insert_lines - # get all uniq ids in order - rows = client.sql_client.execute_sql( - f"SELECT _dlt_id FROM {canonical_name} ORDER BY timestamp ASC;" - ) - v_ids = list(map(lambda i: i[0], rows)) - assert list(map(str, range(0, insert_lines))) == v_ids + client.sql_client.capabilities.max_query_length = max_query_length + user_table_name = prepare_table(client) + insert_sql = prepare_insert_statement( + insert_lines, client.capabilities.insert_values_writer_type + ) + expect_load_file(client, file_storage, insert_sql, user_table_name) + canonical_name = client.sql_client.make_qualified_table_name(user_table_name) + rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] + assert rows_count == insert_lines + # get all uniq ids in order + rows = client.sql_client.execute_sql( + f"SELECT _dlt_id FROM {canonical_name} ORDER BY timestamp ASC;" + ) + v_ids = list(map(lambda i: i[0], rows)) + assert list(map(str, range(0, insert_lines))) == v_ids client.sql_client.execute_sql(f"DELETE FROM {canonical_name}") diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 7e360a6664..35b988d46e 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -5,7 +5,7 @@ from unittest.mock import patch import pytest import datetime # noqa: I251 -from typing import Iterator, Tuple, List, Dict, Any, Mapping, MutableMapping +from typing import Iterator, Tuple, List, Dict, Any from dlt.common import json, pendulum from dlt.common.schema import Schema @@ -15,7 +15,7 @@ TWriteDisposition, TTableSchema, ) -from dlt.common.schema.utils import new_table, new_column +from dlt.common.schema.utils import new_table, new_column, pipeline_state_table from 
dlt.common.storages import FileStorage from dlt.common.schema import TTableSchemaColumns from dlt.common.utils import uniq_id @@ -26,7 +26,7 @@ ) from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.common.destination.reference import WithStagingDataset +from dlt.common.destination.reference import StateInfo, WithStagingDataset from tests.cases import table_update_and_row, assert_all_data_types_row from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage @@ -41,8 +41,13 @@ cm_yield_client_with_storage, write_dataset, prepare_table, + normalize_storage_table_cols, + destinations_configs, + DestinationTestConfiguration, ) -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential @pytest.fixture @@ -69,13 +74,18 @@ def test_initialize_storage(client: SqlJobClientBase) -> None: ) def test_get_schema_on_empty_storage(client: SqlJobClientBase) -> None: # test getting schema on empty dataset without any tables - exists, _ = client.get_storage_table(VERSION_TABLE_NAME) - assert exists is False + table_name, table_columns = list(client.get_storage_tables([VERSION_TABLE_NAME]))[0] + assert table_name == VERSION_TABLE_NAME + assert len(table_columns) == 0 schema_info = client.get_stored_schema() assert schema_info is None schema_info = client.get_stored_schema_by_hash("8a0298298823928939") assert schema_info is None + # now try to get several non existing tables + storage_tables = list(client.get_storage_tables(["no_table_1", "no_table_2"])) + assert [("no_table_1", {}), ("no_table_2", {})] == storage_tables + @pytest.mark.order(3) @pytest.mark.parametrize( @@ -90,17 +100,17 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: # check is event slot has variant assert schema_update["event_slot"]["columns"]["value"]["variant"] is True # now we have dlt tables - exists, _ = client.get_storage_table(VERSION_TABLE_NAME) - 
assert exists is True - exists, _ = client.get_storage_table(LOADS_TABLE_NAME) - assert exists is True + storage_tables = list(client.get_storage_tables([VERSION_TABLE_NAME, LOADS_TABLE_NAME])) + assert set([table[0] for table in storage_tables]) == {VERSION_TABLE_NAME, LOADS_TABLE_NAME} + assert [len(table[1]) > 0 for table in storage_tables] == [True, True] # verify if schemas stored this_schema = client.get_stored_schema_by_hash(schema.version_hash) newest_schema = client.get_stored_schema() # should point to the same schema assert this_schema == newest_schema # check fields - assert this_schema.version == 1 == schema.version + # NOTE: schema version == 2 because we updated default hints after loading the schema + assert this_schema.version == 2 == schema.version assert this_schema.version_hash == schema.stored_version_hash assert this_schema.engine_version == schema.ENGINE_VERSION assert this_schema.schema_name == schema.name @@ -120,7 +130,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: this_schema = client.get_stored_schema_by_hash(schema.version_hash) newest_schema = client.get_stored_schema() assert this_schema == newest_schema - assert this_schema.version == schema.version == 2 + assert this_schema.version == schema.version == 3 assert this_schema.version_hash == schema.stored_version_hash # simulate parallel write: initial schema is modified differently and written alongside the first one @@ -128,14 +138,14 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: first_schema = Schema.from_dict(json.loads(first_version_schema)) first_schema.tables["event_bot"]["write_disposition"] = "replace" first_schema._bump_version() - assert first_schema.version == this_schema.version == 2 + assert first_schema.version == this_schema.version == 3 # wait to make load_newest_schema deterministic sleep(1) client._update_schema_in_storage(first_schema) this_schema = client.get_stored_schema_by_hash(first_schema.version_hash) 
newest_schema = client.get_stored_schema() assert this_schema == newest_schema # error - assert this_schema.version == first_schema.version == 2 + assert this_schema.version == first_schema.version == 3 assert this_schema.version_hash == first_schema.stored_version_hash # get schema with non existing hash @@ -157,7 +167,6 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: assert this_schema == newest_schema -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -190,11 +199,11 @@ def test_complete_load(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( "client", - destinations_configs(default_sql_configs=True, subset=["redshift", "postgres", "duckdb"]), + destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name, ) -def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: +def test_schema_update_create_table(client: SqlJobClientBase) -> None: # infer typical rasa event schema schema = client.schema table_name = "event_test_table" + uniq_id() @@ -215,8 +224,8 @@ def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: assert table_update["timestamp"]["sort"] is True assert table_update["sender_id"]["cluster"] is True assert table_update["_dlt_id"]["unique"] is True - exists, _ = client.get_storage_table(table_name) - assert exists is True + _, storage_columns = list(client.get_storage_tables([table_name]))[0] + assert len(storage_columns) > 0 @pytest.mark.parametrize( @@ -225,7 +234,15 @@ def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: indirect=True, ids=lambda x: x.name, ) -def test_schema_update_create_table_bigquery(client: SqlJobClientBase) -> None: +@pytest.mark.parametrize("dataset_name", (None, "_hidden_ds")) +def test_schema_update_create_table_bigquery(client: SqlJobClientBase, dataset_name: str) -> None: + # patch dataset name + if 
dataset_name: + # drop existing dataset + client.drop_storage() + client.sql_client.dataset_name = dataset_name + "_" + uniq_id() + client.initialize_storage() + # infer typical rasa event schema schema = client.schema # this will be partition @@ -241,14 +258,11 @@ def test_schema_update_create_table_bigquery(client: SqlJobClientBase) -> None: table_update = schema_update["event_test_table"]["columns"] assert table_update["timestamp"]["partition"] is True assert table_update["_dlt_id"]["nullable"] is False - exists, storage_table = client.get_storage_table("event_test_table") - assert exists is True - assert storage_table["timestamp"]["partition"] is True - assert storage_table["sender_id"]["cluster"] is True - exists, storage_table = client.get_storage_table("_dlt_version") - assert exists is True - assert storage_table["version"]["partition"] is False - assert storage_table["version"]["cluster"] is False + _, storage_columns = client.get_storage_table("event_test_table") + # check if all columns present + assert storage_columns.keys() == client.schema.tables["event_test_table"]["columns"].keys() + _, storage_columns = client.get_storage_table("_dlt_version") + assert storage_columns.keys() == client.schema.tables["_dlt_version"]["columns"].keys() @pytest.mark.parametrize( @@ -285,10 +299,11 @@ def test_schema_update_alter_table(client: SqlJobClientBase) -> None: assert len(schema_update[table_name]["columns"]) == 2 assert schema_update[table_name]["columns"]["col3"]["data_type"] == "double" assert schema_update[table_name]["columns"]["col4"]["data_type"] == "timestamp" - _, storage_table = client.get_storage_table(table_name) + _, storage_table_cols = client.get_storage_table(table_name) # 4 columns - assert len(storage_table) == 4 - assert storage_table["col4"]["data_type"] == "timestamp" + assert len(storage_table_cols) == 4 + storage_table_cols = normalize_storage_table_cols(table_name, storage_table_cols, schema) + assert 
storage_table_cols["col4"]["data_type"] == "timestamp" @pytest.mark.parametrize( @@ -341,9 +356,7 @@ def test_drop_tables(client: SqlJobClientBase) -> None: client.drop_tables(*tables_to_drop, delete_schema=False) # Verify requested tables are dropped - for tbl in tables_to_drop: - exists, _ = client.get_storage_table(tbl) - assert not exists + assert all(len(table[1]) == 0 for table in client.get_storage_tables(tables_to_drop)) # Verify _dlt_version schema is updated and old versions deleted table_name = client.sql_client.make_qualified_table_name(VERSION_TABLE_NAME) @@ -376,14 +389,13 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: for name, column in table_update.items(): assert column.items() >= TABLE_UPDATE_COLUMNS_SCHEMA[name].items() # now get the actual schema from the db - exists, storage_table = client.get_storage_table(table_name) - assert exists is True + _, storage_table = list(client.get_storage_tables([table_name]))[0] + assert len(storage_table) > 0 # column order must match TABLE_UPDATE storage_columns = list(storage_table.values()) for c, expected_c in zip(TABLE_UPDATE, storage_columns): - # print(c["name"]) - # print(c["data_type"]) - assert c["name"] == expected_c["name"] + # storage columns are returned with column names as in information schema + assert client.capabilities.casefold_identifier(c["name"]) == expected_c["name"] # athena does not know wei data type and has no JSON type, time is not supported with parquet tables if client.config.destination_type == "athena" and c["data_type"] in ( "wei", @@ -429,8 +441,7 @@ def _assert_columns_order(sql_: str) -> None: if hasattr(client.sql_client, "escape_ddl_identifier"): col_name = client.sql_client.escape_ddl_identifier(c["name"]) else: - col_name = client.capabilities.escape_identifier(c["name"]) - print(col_name) + col_name = client.sql_client.escape_column_name(c["name"]) # find column names idx = sql_.find(col_name, idx) assert idx > 0, f"column {col_name} 
not found in script" @@ -716,6 +727,53 @@ def test_default_schema_name_init_storage(destination_config: DestinationTestCon assert client.sql_client.has_dataset() +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +@pytest.mark.parametrize( + "naming_convention", + [ + "tests.common.cases.normalizers.title_case", + "snake_case", + ], +) +def test_get_stored_state( + destination_config: DestinationTestConfiguration, + naming_convention: str, + file_storage: FileStorage, +) -> None: + os.environ["SCHEMA__NAMING"] = naming_convention + + with cm_yield_client_with_storage( + destination_config.destination, default_config_values={"default_schema_name": None} + ) as client: + # event schema with event table + if not client.capabilities.preferred_loader_file_format: + pytest.skip( + "preferred loader file format not set, destination will only work with staging" + ) + # load pipeline state + state_table = pipeline_state_table() + partial = client.schema.update_table(state_table) + print(partial) + client.schema._bump_version() + client.update_stored_schema() + + state_info = StateInfo(1, 4, "pipeline", "compressed", pendulum.now(), None, "_load_id") + doc = state_info.as_doc() + norm_doc = {client.schema.naming.normalize_identifier(k): v for k, v in doc.items()} + with io.BytesIO() as f: + # use normalized columns + write_dataset(client, f, [norm_doc], partial["columns"]) + query = f.getvalue().decode() + expect_load_file(client, file_storage, query, partial["name"]) + client.complete_load("_load_id") + + # get state + stored_state = client.get_stored_state("pipeline") + assert doc == stored_state.as_doc() + + @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name ) diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 26d7884179..fa31f1db65 100644 --- a/tests/load/test_sql_client.py +++ 
b/tests/load/test_sql_client.py @@ -22,8 +22,15 @@ from dlt.common.time import ensure_pendulum_datetime from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage -from tests.load.utils import yield_client_with_storage, prepare_table, AWS_BUCKET -from tests.load.pipeline.utils import destinations_configs +from tests.load.utils import ( + yield_client_with_storage, + prepare_table, + AWS_BUCKET, + destinations_configs, +) + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential @pytest.fixture @@ -141,7 +148,6 @@ def test_malformed_execute_parameters(client: SqlJobClientBase) -> None: assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -189,7 +195,6 @@ def test_execute_sql(client: SqlJobClientBase) -> None: assert len(rows) == 0 -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -212,7 +217,6 @@ def test_execute_ddl(client: SqlJobClientBase) -> None: assert rows[0][0] == Decimal("1.0") -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -255,7 +259,6 @@ def test_execute_query(client: SqlJobClientBase) -> None: assert len(rows) == 0 -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -307,7 +310,6 @@ def test_execute_df(client: SqlJobClientBase) -> None: assert df_3 is None -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -490,7 +492,7 @@ def test_transaction_isolation(client: SqlJobClientBase) -> None: def test_thread(thread_id: Decimal) -> None: # make a copy of the sql_client 
thread_client = client.sql_client.__class__( - client.sql_client.dataset_name, client.sql_client.credentials + client.sql_client.dataset_name, client.sql_client.credentials, client.capabilities ) with thread_client: with thread_client.begin_transaction(): diff --git a/tests/load/utils.py b/tests/load/utils.py index 8048d9fe51..00ed4e3bf3 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -24,13 +24,15 @@ from dlt.common.destination import TLoaderFileFormat, Destination from dlt.common.destination.reference import DEFAULT_FILE_LAYOUT from dlt.common.data_writers import DataWriter +from dlt.common.pipeline import PipelineContext from dlt.common.schema import TTableSchemaColumns, Schema from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration -from dlt.common.schema.utils import new_table +from dlt.common.schema.utils import new_table, normalize_table_identifiers from dlt.common.storages import ParsedLoadJobFileName, LoadStorage, PackageStorage from dlt.common.typing import StrAny from dlt.common.utils import uniq_id +from dlt.destinations.exceptions import CantExtractTablePrefix from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.job_client_impl import SqlJobClientBase @@ -126,6 +128,7 @@ class DestinationTestConfiguration: force_iceberg: bool = False supports_dbt: bool = True disable_compression: bool = False + dev_mode: bool = False @property def name(self) -> str: @@ -140,15 +143,26 @@ def name(self) -> str: name += f"-{self.extra_info}" return name + @property + def factory_kwargs(self) -> Dict[str, Any]: + return { + k: getattr(self, k) + for k in [ + "bucket_url", + "stage_name", + "staging_iam_role", + "staging_use_msi", + "force_iceberg", + ] + if getattr(self, k, None) is not None + } + def setup(self) -> None: """Sets up environment variables for this destination configuration""" - os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = self.bucket_url or "" - os.environ["DESTINATION__STAGE_NAME"] 
= self.stage_name or "" - os.environ["DESTINATION__STAGING_IAM_ROLE"] = self.staging_iam_role or "" - os.environ["DESTINATION__STAGING_USE_MSI"] = str(self.staging_use_msi) or "" - os.environ["DESTINATION__FORCE_ICEBERG"] = str(self.force_iceberg) or "" + for k, v in self.factory_kwargs.items(): + os.environ[f"DESTINATION__{k.upper()}"] = str(v) - """For the filesystem destinations we disable compression to make analyzing the result easier""" + # For the filesystem destinations we disable compression to make analyzing the result easier if self.destination == "filesystem" or self.disable_compression: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" @@ -156,17 +170,24 @@ def setup_pipeline( self, pipeline_name: str, dataset_name: str = None, dev_mode: bool = False, **kwargs ) -> dlt.Pipeline: """Convenience method to setup pipeline with this configuration""" + self.dev_mode = dev_mode self.setup() pipeline = dlt.pipeline( pipeline_name=pipeline_name, - destination=self.destination, - staging=self.staging, + destination=kwargs.pop("destination", self.destination), + staging=kwargs.pop("staging", self.staging), dataset_name=dataset_name or pipeline_name, dev_mode=dev_mode, **kwargs, ) return pipeline + def attach_pipeline(self, pipeline_name: str, **kwargs) -> dlt.Pipeline: + """Attach to existing pipeline keeping the dev_mode""" + # remember dev_mode from setup_pipeline + pipeline = dlt.attach(pipeline_name, **kwargs) + return pipeline + def destinations_configs( default_sql_configs: bool = False, @@ -255,8 +276,10 @@ def destinations_configs( assert set(SQL_DESTINATIONS) == {d.destination for d in destination_configs} if default_vector_configs: - # for now only weaviate - destination_configs += [DestinationTestConfiguration(destination="weaviate")] + destination_configs += [ + DestinationTestConfiguration(destination="weaviate"), + DestinationTestConfiguration(destination="lancedb"), + ] if default_staging_configs or all_staging_configs: destination_configs += 
[ @@ -489,6 +512,60 @@ def destinations_configs( return destination_configs +@pytest.fixture(autouse=True) +def drop_pipeline(request, preserve_environ) -> Iterator[None]: + # NOTE: keep `preserve_environ` to make sure fixtures are executed in order`` + yield + if "no_load" in request.keywords: + return + try: + drop_active_pipeline_data() + except CantExtractTablePrefix: + # for some tests we test that this exception is raised, + # so we suppress it here + pass + + +def drop_active_pipeline_data() -> None: + """Drops all the datasets for currently active pipeline, wipes the working folder and then deactivated it.""" + if Container()[PipelineContext].is_active(): + try: + # take existing pipeline + p = dlt.pipeline() + + def _drop_dataset(schema_name: str) -> None: + with p.destination_client(schema_name) as client: + try: + client.drop_storage() + print("dropped") + except Exception as exc: + print(exc) + if isinstance(client, WithStagingDataset): + with client.with_staging_dataset(): + try: + client.drop_storage() + print("staging dropped") + except Exception as exc: + print(exc) + + # drop_func = _drop_dataset_fs if _is_filesystem(p) else _drop_dataset_sql + # take all schemas and if destination was set + if p.destination: + if p.config.use_single_dataset: + # drop just the dataset for default schema + if p.default_schema_name: + _drop_dataset(p.default_schema_name) + else: + # for each schema, drop the dataset + for schema_name in p.schema_names: + _drop_dataset(schema_name) + + # p._wipe_working_folder() + finally: + # always deactivate context, working directory will be wiped when the next test starts + Container()[PipelineContext].deactivate() + + @pytest.fixture def empty_schema() -> Schema: schema = Schema("event") @@ -580,6 +657,9 @@ def yield_client( ) schema_storage = SchemaStorage(storage_config) schema = schema_storage.load_schema(schema_name) + schema.update_normalizers() + # NOTE: schema version is bumped because new default hints are added + 
schema._bump_version() # create client and dataset client: SqlJobClientBase = None @@ -680,7 +760,7 @@ def prepare_load_package( shutil.copy( path, load_storage.new_packages.storage.make_full_path( - load_storage.new_packages.get_job_folder_path(load_id, "new_jobs") + load_storage.new_packages.get_job_state_folder_path(load_id, "new_jobs") ), ) schema_path = Path("./tests/load/cases/loading/schema.json") @@ -708,3 +788,15 @@ def sequence_generator() -> Generator[List[Dict[str, str]], None, None]: while True: yield [{"content": str(count + i)} for i in range(3)] count += 3 + + +def normalize_storage_table_cols( + table_name: str, cols: TTableSchemaColumns, schema: Schema +) -> TTableSchemaColumns: + """Normalize storage table columns back into schema naming""" + # go back to schema naming convention. this is a hack - will work here to + # reverse snowflake UPPER case folding + storage_table = normalize_table_identifiers( + new_table(table_name, columns=cols.values()), schema.naming # type: ignore[arg-type] + ) + return storage_table["columns"] diff --git a/tests/load/weaviate/test_pipeline.py b/tests/load/weaviate/test_pipeline.py index ee42ab59d8..fc46d00d05 100644 --- a/tests/load/weaviate/test_pipeline.py +++ b/tests/load/weaviate/test_pipeline.py @@ -4,9 +4,13 @@ import dlt from dlt.common import json +from dlt.common.schema.exceptions import ( + SchemaCorruptedException, + SchemaIdentifierNormalizationCollision, +) from dlt.common.utils import uniq_id -from dlt.destinations.impl.weaviate import weaviate_adapter +from dlt.destinations.adapters import weaviate_adapter from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT from dlt.destinations.impl.weaviate.weaviate_client import WeaviateClient @@ -244,7 +248,8 @@ def movies_data(): assert_class(pipeline, "MoviesData", items=data) -def test_pipeline_with_schema_evolution(): 
+@pytest.mark.parametrize("vectorized", (True, False), ids=("vectorized", "not-vectorized")) +def test_pipeline_with_schema_evolution(vectorized: bool): data = [ { "doc_id": 1, @@ -260,7 +265,8 @@ def test_pipeline_with_schema_evolution(): def some_data(): yield data - weaviate_adapter(some_data, vectorize=["content"]) + if vectorized: + weaviate_adapter(some_data, vectorize=["content"]) pipeline = dlt.pipeline( pipeline_name="test_pipeline_append", @@ -280,17 +286,22 @@ def some_data(): "doc_id": 3, "content": "3", "new_column": "new", + "new_vec_column": "lorem lorem", }, { "doc_id": 4, "content": "4", "new_column": "new", + "new_vec_column": "lorem lorem", }, ] - pipeline.run( - some_data(), - ) + some_data_2 = some_data() + + if vectorized: + weaviate_adapter(some_data_2, vectorize=["new_vec_column"]) + + pipeline.run(some_data_2) table_schema = pipeline.default_schema.tables["SomeData"] assert "new_column" in table_schema["columns"] @@ -298,6 +309,8 @@ def some_data(): aggregated_data.extend(data) aggregated_data[0]["new_column"] = None aggregated_data[1]["new_column"] = None + aggregated_data[0]["new_vec_column"] = None + aggregated_data[1]["new_vec_column"] = None assert_class(pipeline, "SomeData", items=aggregated_data) @@ -391,7 +404,7 @@ def test_vectorize_property_without_data() -> None: primary_key="vAlue", columns={"vAlue": {"data_type": "text"}}, ) - assert isinstance(pipe_ex.value.__context__, PropertyNameConflict) + assert isinstance(pipe_ex.value.__context__, SchemaIdentifierNormalizationCollision) # set the naming convention to case insensitive os.environ["SCHEMA__NAMING"] = "dlt.destinations.impl.weaviate.ci_naming" diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index 8c3344f152..dc2110d2f6 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -5,6 +5,7 @@ from dlt.common.schema import Schema from dlt.common.configuration.container 
import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.common.schema.exceptions import SchemaIdentifierNormalizationCollision from dlt.common.utils import uniq_id from dlt.common.schema.typing import TWriteDisposition, TColumnSchema, TTableSchemaColumns @@ -13,7 +14,7 @@ from dlt.destinations.impl.weaviate.weaviate_client import WeaviateClient from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema.utils import new_table +from dlt.common.schema.utils import new_table, normalize_table_identifiers from tests.load.utils import ( TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE, @@ -58,11 +59,11 @@ def make_client(naming_convention: str) -> Iterator[WeaviateClient]: "test_schema", {"names": f"dlt.destinations.impl.weaviate.{naming_convention}", "json": None}, ) - _client = get_client_instance(schema) - try: - yield _client - finally: - _client.drop_storage() + with get_client_instance(schema) as _client: + try: + yield _client + finally: + _client.drop_storage() @pytest.fixture @@ -114,11 +115,18 @@ def test_case_sensitive_properties_create(client: WeaviateClient) -> None: {"name": "coL1", "data_type": "double", "nullable": False}, ] client.schema.update_table( - client.schema.normalize_table_identifiers(new_table(class_name, columns=table_create)) + normalize_table_identifiers( + new_table(class_name, columns=table_create), client.schema.naming + ) ) client.schema._bump_version() - with pytest.raises(PropertyNameConflict): + with pytest.raises(SchemaIdentifierNormalizationCollision) as clash_ex: client.update_stored_schema() + assert clash_ex.value.identifier_type == "column" + assert clash_ex.value.identifier_name == "coL1" + assert clash_ex.value.conflict_identifier_name == "col1" + assert clash_ex.value.table_name == "ColClass" + assert clash_ex.value.naming_name == "dlt.destinations.impl.weaviate.naming" def test_case_insensitive_properties_create(ci_client: WeaviateClient) -> None: @@ 
-129,7 +137,9 @@ def test_case_insensitive_properties_create(ci_client: WeaviateClient) -> None: {"name": "coL1", "data_type": "double", "nullable": False}, ] ci_client.schema.update_table( - ci_client.schema.normalize_table_identifiers(new_table(class_name, columns=table_create)) + normalize_table_identifiers( + new_table(class_name, columns=table_create), ci_client.schema.naming + ) ) ci_client.schema._bump_version() ci_client.update_stored_schema() @@ -146,16 +156,20 @@ def test_case_sensitive_properties_add(client: WeaviateClient) -> None: {"name": "coL1", "data_type": "double", "nullable": False}, ] client.schema.update_table( - client.schema.normalize_table_identifiers(new_table(class_name, columns=table_create)) + normalize_table_identifiers( + new_table(class_name, columns=table_create), client.schema.naming + ) ) client.schema._bump_version() client.update_stored_schema() client.schema.update_table( - client.schema.normalize_table_identifiers(new_table(class_name, columns=table_update)) + normalize_table_identifiers( + new_table(class_name, columns=table_update), client.schema.naming + ) ) client.schema._bump_version() - with pytest.raises(PropertyNameConflict): + with pytest.raises(SchemaIdentifierNormalizationCollision): client.update_stored_schema() # _, table_columns = client.get_storage_table("ColClass") @@ -171,12 +185,13 @@ def test_load_case_sensitive_data(client: WeaviateClient, file_storage: FileStor client.schema.update_table(new_table(class_name, columns=[table_create["col1"]])) client.schema._bump_version() client.update_stored_schema() - # prepare a data item where is name clash due to Weaviate being CI + # prepare a data item where is name clash due to Weaviate being CS data_clash = {"col1": 72187328, "coL1": 726171} # write row with io.BytesIO() as f: write_dataset(client, f, [data_clash], table_create) query = f.getvalue().decode() + class_name = client.schema.naming.normalize_table_identifier(class_name) with 
pytest.raises(PropertyNameConflict): expect_load_file(client, file_storage, query, class_name) @@ -202,6 +217,7 @@ def test_load_case_sensitive_data_ci(ci_client: WeaviateClient, file_storage: Fi with io.BytesIO() as f: write_dataset(ci_client, f, [data_clash], table_create) query = f.getvalue().decode() + class_name = ci_client.schema.naming.normalize_table_identifier(class_name) expect_load_file(ci_client, file_storage, query, class_name) response = ci_client.query_class(class_name, ["col1"]).do() objects = response["data"]["Get"][ci_client.make_qualified_class_name(class_name)] diff --git a/tests/load/weaviate/utils.py b/tests/load/weaviate/utils.py index 1b2a74fcb8..b391c2fa38 100644 --- a/tests/load/weaviate/utils.py +++ b/tests/load/weaviate/utils.py @@ -22,53 +22,57 @@ def assert_class( expected_items_count: int = None, items: List[Any] = None, ) -> None: - client: WeaviateClient = pipeline.destination_client() # type: ignore[assignment] - vectorizer_name: str = client._vectorizer_config["vectorizer"] # type: ignore[assignment] - - # Check if class exists - schema = client.get_class_schema(class_name) - assert schema is not None - - columns = pipeline.default_schema.get_table_columns(class_name) - - properties = {prop["name"]: prop for prop in schema["properties"]} - assert set(properties.keys()) == set(columns.keys()) - - # make sure expected columns are vectorized - for column_name, column in columns.items(): - prop = properties[column_name] - assert prop["moduleConfig"][vectorizer_name]["skip"] == ( - not column.get(VECTORIZE_HINT, False) - ) - # tokenization - if TOKENIZATION_HINT in column: - assert prop["tokenization"] == column[TOKENIZATION_HINT] # type: ignore[literal-required] - - # if there's a single vectorize hint, class must have vectorizer enabled - if get_columns_names_with_prop(pipeline.default_schema.get_table(class_name), VECTORIZE_HINT): - assert schema["vectorizer"] == vectorizer_name - else: - assert schema["vectorizer"] == "none" - - # 
response = db_client.query.get(class_name, list(properties.keys())).do() - response = client.query_class(class_name, list(properties.keys())).do() - objects = response["data"]["Get"][client.make_qualified_class_name(class_name)] - - if expected_items_count is not None: - assert expected_items_count == len(objects) - - if items is None: - return - - # TODO: Remove this once we have a better way comparing the data - drop_keys = ["_dlt_id", "_dlt_load_id"] - objects_without_dlt_keys = [ - {k: v for k, v in obj.items() if k not in drop_keys} for obj in objects - ] - - # pytest compares content wise but ignores order of elements of dict - # assert sorted(objects_without_dlt_keys, key=lambda d: d['doc_id']) == sorted(data, key=lambda d: d['doc_id']) - assert_unordered_list_equal(objects_without_dlt_keys, items) + client: WeaviateClient + with pipeline.destination_client() as client: # type: ignore[assignment] + vectorizer_name: str = client._vectorizer_config["vectorizer"] # type: ignore[assignment] + + # Check if class exists + schema = client.get_class_schema(class_name) + assert schema is not None + + columns = pipeline.default_schema.get_table_columns(class_name) + + properties = {prop["name"]: prop for prop in schema["properties"]} + assert set(properties.keys()) == set(columns.keys()) + + # make sure expected columns are vectorized + for column_name, column in columns.items(): + prop = properties[column_name] + if client._is_collection_vectorized(class_name): + assert prop["moduleConfig"][vectorizer_name]["skip"] == ( + not column.get(VECTORIZE_HINT, False) + ) + # tokenization + if TOKENIZATION_HINT in column: + assert prop["tokenization"] == column[TOKENIZATION_HINT] # type: ignore[literal-required] + + # if there's a single vectorize hint, class must have vectorizer enabled + if get_columns_names_with_prop( + pipeline.default_schema.get_table(class_name), VECTORIZE_HINT + ): + assert schema["vectorizer"] == vectorizer_name + else: + assert schema["vectorizer"] 
== "none" + + # response = db_client.query.get(class_name, list(properties.keys())).do() + response = client.query_class(class_name, list(properties.keys())).do() + objects = response["data"]["Get"][client.make_qualified_class_name(class_name)] + + if expected_items_count is not None: + assert expected_items_count == len(objects) + + if items is None: + return + + # TODO: Remove this once we have a better way comparing the data + drop_keys = ["_dlt_id", "_dlt_load_id"] + objects_without_dlt_keys = [ + {k: v for k, v in obj.items() if k not in drop_keys} for obj in objects + ] + + # pytest compares content wise but ignores order of elements of dict + # assert sorted(objects_without_dlt_keys, key=lambda d: d['doc_id']) == sorted(data, key=lambda d: d['doc_id']) + assert_unordered_list_equal(objects_without_dlt_keys, items) def delete_classes(p, class_list): @@ -87,10 +91,9 @@ def schema_has_classes(client): if Container()[PipelineContext].is_active(): # take existing pipeline p = dlt.pipeline() - client = p.destination_client() - - if schema_has_classes(client): - client.drop_storage() + with p.destination_client() as client: + if schema_has_classes(client): + client.drop_storage() p._wipe_working_folder() # deactivate context diff --git a/tests/normalize/test_max_nesting.py b/tests/normalize/test_max_nesting.py index 4015836232..5def1617dc 100644 --- a/tests/normalize/test_max_nesting.py +++ b/tests/normalize/test_max_nesting.py @@ -62,7 +62,7 @@ def bot_events(): pipeline = dlt.pipeline( pipeline_name=pipeline_name, destination=dummy(timeout=0.1), - full_refresh=True, + dev_mode=True, ) pipeline.run(bot_events) @@ -169,7 +169,7 @@ def some_data(): pipeline = dlt.pipeline( pipeline_name=pipeline_name, destination=dummy(timeout=0.1), - full_refresh=True, + dev_mode=True, ) pipeline.run(some_data(), write_disposition="append") diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 3891c667c3..7463184be7 100644 --- 
a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -16,6 +16,7 @@ from dlt.extract.extract import ExtractStorage from dlt.normalize import Normalize +from dlt.normalize.worker import group_worker_files from dlt.normalize.exceptions import NormalizeJobFailed from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES @@ -510,28 +511,28 @@ def test_collect_metrics_on_exception(raw_normalize: Normalize) -> None: def test_group_worker_files() -> None: files = ["f%03d" % idx for idx in range(0, 100)] - assert Normalize.group_worker_files([], 4) == [] - assert Normalize.group_worker_files(["f001"], 1) == [["f001"]] - assert Normalize.group_worker_files(["f001"], 100) == [["f001"]] - assert Normalize.group_worker_files(files[:4], 4) == [["f000"], ["f001"], ["f002"], ["f003"]] - assert Normalize.group_worker_files(files[:5], 4) == [ + assert group_worker_files([], 4) == [] + assert group_worker_files(["f001"], 1) == [["f001"]] + assert group_worker_files(["f001"], 100) == [["f001"]] + assert group_worker_files(files[:4], 4) == [["f000"], ["f001"], ["f002"], ["f003"]] + assert group_worker_files(files[:5], 4) == [ ["f000"], ["f001"], ["f002"], ["f003", "f004"], ] - assert Normalize.group_worker_files(files[:8], 4) == [ + assert group_worker_files(files[:8], 4) == [ ["f000", "f001"], ["f002", "f003"], ["f004", "f005"], ["f006", "f007"], ] - assert Normalize.group_worker_files(files[:8], 3) == [ + assert group_worker_files(files[:8], 3) == [ ["f000", "f001"], ["f002", "f003", "f006"], ["f004", "f005", "f007"], ] - assert Normalize.group_worker_files(files[:5], 3) == [ + assert group_worker_files(files[:5], 3) == [ ["f000"], ["f001", "f003"], ["f002", "f004"], @@ -539,7 +540,7 @@ def test_group_worker_files() -> None: # check if sorted files = ["tab1.1", "chd.3", "tab1.2", "chd.4", "tab1.3"] - assert Normalize.group_worker_files(files, 3) == [ + assert group_worker_files(files, 3) == [ ["chd.3"], ["chd.4", "tab1.2"], ["tab1.1", "tab1.3"], @@ 
-730,19 +731,22 @@ def test_removal_of_normalizer_schema_section_and_add_seen_data(raw_normalize: N extracted_schema.tables["event__random_table"] = new_table("event__random_table") # add x-normalizer info (and other block to control) - extracted_schema.tables["event"]["x-normalizer"] = {"evolve-columns-once": True} # type: ignore + extracted_schema.tables["event"]["x-normalizer"] = {"evolve-columns-once": True} extracted_schema.tables["event"]["x-other-info"] = "blah" # type: ignore - extracted_schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] = {"seen-data": True, "random-entry": 1234} # type: ignore - extracted_schema.tables["event__random_table"]["x-normalizer"] = {"evolve-columns-once": True} # type: ignore + extracted_schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] = { + "seen-data": True, + "random-entry": 1234, + } + extracted_schema.tables["event__random_table"]["x-normalizer"] = {"evolve-columns-once": True} normalize_pending(raw_normalize, extracted_schema) schema = raw_normalize.schema_storage.load_schema("event") # seen data gets added, schema settings get removed - assert schema.tables["event"]["x-normalizer"] == {"seen-data": True} # type: ignore - assert schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] == { # type: ignore + assert schema.tables["event"]["x-normalizer"] == {"seen-data": True} + assert schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] == { "seen-data": True, "random-entry": 1234, } # no data seen here, so seen-data is not set and evolve settings stays until first data is seen - assert schema.tables["event__random_table"]["x-normalizer"] == {"evolve-columns-once": True} # type: ignore + assert schema.tables["event__random_table"]["x-normalizer"] == {"evolve-columns-once": True} assert "x-other-info" in schema.tables["event"] diff --git a/tests/normalize/utils.py b/tests/normalize/utils.py index 0ce099d4b6..dffb3f1bb6 100644 --- a/tests/normalize/utils.py +++ 
b/tests/normalize/utils.py @@ -1,15 +1,10 @@ -from typing import Mapping, cast +from dlt.destinations import duckdb, redshift, postgres, bigquery, filesystem -from dlt.destinations.impl.duckdb import capabilities as duck_insert_caps -from dlt.destinations.impl.redshift import capabilities as rd_insert_caps -from dlt.destinations.impl.postgres import capabilities as pg_insert_caps -from dlt.destinations.impl.bigquery import capabilities as jsonl_caps -from dlt.destinations.impl.filesystem import capabilities as filesystem_caps - -DEFAULT_CAPS = pg_insert_caps -INSERT_CAPS = [duck_insert_caps, rd_insert_caps, pg_insert_caps] -JSONL_CAPS = [jsonl_caps, filesystem_caps] +# callables to capabilities +DEFAULT_CAPS = postgres().capabilities +INSERT_CAPS = [duckdb().capabilities, redshift().capabilities, DEFAULT_CAPS] +JSONL_CAPS = [bigquery().capabilities, filesystem().capabilities] ALL_CAPABILITIES = INSERT_CAPS + JSONL_CAPS diff --git a/tests/pipeline/cases/github_pipeline/github_pipeline.py b/tests/pipeline/cases/github_pipeline/github_pipeline.py index aa0f6d0e0e..f4cdc2bcf2 100644 --- a/tests/pipeline/cases/github_pipeline/github_pipeline.py +++ b/tests/pipeline/cases/github_pipeline/github_pipeline.py @@ -33,11 +33,21 @@ def load_issues( if __name__ == "__main__": - p = dlt.pipeline("dlt_github_pipeline", destination="duckdb", dataset_name="github_3") + # pick the destination name + if len(sys.argv) < 1: + raise RuntimeError(f"Please provide destination name in args ({sys.argv})") + dest_ = sys.argv[1] + if dest_ == "filesystem": + import os + from dlt.destinations import filesystem + + dest_ = filesystem(os.path.abspath(os.path.join("_storage", "data"))) # type: ignore + + p = dlt.pipeline("dlt_github_pipeline", destination=dest_, dataset_name="github_3") github_source = github() - if len(sys.argv) > 1: + if len(sys.argv) > 2: # load only N issues - limit = int(sys.argv[1]) + limit = int(sys.argv[2]) github_source.add_limit(limit) info = p.run(github_source) 
print(info) diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py index 0c03a8209d..4cdccb1e34 100644 --- a/tests/pipeline/test_arrow_sources.py +++ b/tests/pipeline/test_arrow_sources.py @@ -9,7 +9,11 @@ import dlt from dlt.common import json, Decimal from dlt.common.utils import uniq_id -from dlt.common.libs.pyarrow import NameNormalizationClash, remove_columns, normalize_py_arrow_item +from dlt.common.libs.pyarrow import ( + NameNormalizationCollision, + remove_columns, + normalize_py_arrow_item, +) from dlt.pipeline.exceptions import PipelineStepFailed @@ -17,8 +21,8 @@ arrow_table_all_data_types, prepare_shuffled_tables, ) +from tests.pipeline.utils import assert_only_table_columns, load_tables_to_dicts from tests.utils import ( - preserve_environ, TPythonTableFormat, arrow_item_from_pandas, arrow_item_from_table, @@ -223,7 +227,7 @@ def data_frames(): with pytest.raises(PipelineStepFailed) as py_ex: pipeline.extract(data_frames()) - assert isinstance(py_ex.value.__context__, NameNormalizationClash) + assert isinstance(py_ex.value.__context__, NameNormalizationCollision) @pytest.mark.parametrize("item_type", ["arrow-table", "arrow-batch"]) @@ -507,6 +511,48 @@ def test_empty_arrow(item_type: TPythonTableFormat) -> None: assert norm_info.row_counts["items"] == 0 +def test_import_file_with_arrow_schema() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + + # Define the schema based on the CSV input + schema = pa.schema( + [ + ("id", pa.int64()), + ("name", pa.string()), + ("description", pa.string()), + ("ordered_at", pa.date32()), + ("price", pa.float64()), + ] + ) + + # Create empty arrays for each field + empty_arrays = [ + pa.array([], type=pa.int64()), + pa.array([], type=pa.string()), + pa.array([], type=pa.string()), + pa.array([], type=pa.date32()), + pa.array([], type=pa.float64()), + ] + + # Create an empty table with the defined schema + empty_table 
= pa.Table.from_arrays(empty_arrays, schema=schema) + + # columns should be created from empty table + import_file = "tests/load/cases/loading/header.jsonl" + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2, hints=empty_table)], + table_name="no_header", + ) + info.raise_on_failed_jobs() + assert_only_table_columns(pipeline, "no_header", schema.names) + rows = load_tables_to_dicts(pipeline, "no_header") + assert len(rows["no_header"]) == 2 + + @pytest.mark.parametrize("item_type", ["pandas", "arrow-table", "arrow-batch"]) def test_extract_adds_dlt_load_id(item_type: TPythonTableFormat) -> None: os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index ccf926cc62..ba7c0b9db8 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -1,4 +1,5 @@ import sys +from subprocess import CalledProcessError import pytest import tempfile import shutil @@ -14,17 +15,19 @@ from dlt.common.storages import FileStorage from dlt.common.schema.typing import ( LOADS_TABLE_NAME, - STATE_TABLE_NAME, + PIPELINE_STATE_TABLE_NAME, VERSION_TABLE_NAME, TStoredSchema, ) from dlt.common.configuration.resolve import resolve_configuration +from dlt.destinations import duckdb, filesystem from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient +from tests.pipeline.utils import load_table_counts from tests.utils import TEST_STORAGE_ROOT, test_storage -if sys.version_info > (3, 11): +if sys.version_info >= (3, 12): pytest.skip("Does not run on Python 3.12 and later", allow_module_level=True) @@ -50,7 +53,9 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # load 20 issues print( venv.run_script( - "../tests/pipeline/cases/github_pipeline/github_pipeline.py", "20" + 
"../tests/pipeline/cases/github_pipeline/github_pipeline.py", + "duckdb", + "20", ) ) # load schema and check _dlt_loads definition @@ -66,20 +71,23 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: ) # check the dlt state table assert { - "version_hash" not in github_schema["tables"][STATE_TABLE_NAME]["columns"] + "version_hash" + not in github_schema["tables"][PIPELINE_STATE_TABLE_NAME]["columns"] } # check loads table without attaching to pipeline duckdb_cfg = resolve_configuration( DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb"), ) - with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: + with DuckDbSqlClient( + GITHUB_DATASET, duckdb_cfg.credentials, duckdb().capabilities() + ) as client: rows = client.execute_sql(f"SELECT * FROM {LOADS_TABLE_NAME}") # make sure we have just 4 columns assert len(rows[0]) == 4 rows = client.execute_sql("SELECT * FROM issues") assert len(rows) == 20 - rows = client.execute_sql(f"SELECT * FROM {STATE_TABLE_NAME}") + rows = client.execute_sql(f"SELECT * FROM {PIPELINE_STATE_TABLE_NAME}") # only 5 columns + 2 dlt columns assert len(rows[0]) == 5 + 2 # inspect old state @@ -99,7 +107,16 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # execute in current version venv = Venv.restore_current() # load all issues - print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_pipeline.py")) + try: + print( + venv.run_script( + "../tests/pipeline/cases/github_pipeline/github_pipeline.py", "duckdb" + ) + ) + except CalledProcessError as cpe: + print(f"script stdout: {cpe.stdout}") + print(f"script stderr: {cpe.stderr}") + raise # hash hash in schema github_schema = json.loads( test_storage.load( @@ -108,13 +125,16 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: ) assert github_schema["engine_version"] == 9 assert "schema_version_hash" in 
github_schema["tables"][LOADS_TABLE_NAME]["columns"] + # print(github_schema["tables"][PIPELINE_STATE_TABLE_NAME]) # load state state_dict = json.loads( test_storage.load(f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/state.json") ) assert "_version_hash" in state_dict - with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: + with DuckDbSqlClient( + GITHUB_DATASET, duckdb_cfg.credentials, duckdb().capabilities() + ) as client: rows = client.execute_sql( f"SELECT * FROM {LOADS_TABLE_NAME} ORDER BY inserted_at" ) @@ -131,7 +151,9 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # two schema versions rows = client.execute_sql(f"SELECT * FROM {VERSION_TABLE_NAME}") assert len(rows) == 2 - rows = client.execute_sql(f"SELECT * FROM {STATE_TABLE_NAME} ORDER BY version") + rows = client.execute_sql( + f"SELECT * FROM {PIPELINE_STATE_TABLE_NAME} ORDER BY version" + ) # we have hash columns assert len(rows[0]) == 6 + 2 assert len(rows) == 2 @@ -141,23 +163,82 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: assert rows[1][7] == state_dict["_version_hash"] # attach to existing pipeline - pipeline = dlt.attach(GITHUB_PIPELINE_NAME, credentials=duckdb_cfg.credentials) - created_at_value = pipeline.state["sources"]["github"]["resources"]["load_issues"][ - "incremental" - ]["created_at"]["last_value"] - assert isinstance(created_at_value, pendulum.DateTime) - assert created_at_value == pendulum.parse("2023-02-17T09:52:12Z") - pipeline = pipeline.drop() - # print(pipeline.working_dir) - assert pipeline.dataset_name == GITHUB_DATASET - assert pipeline.default_schema_name is None - # sync from destination - pipeline.sync_destination() - # print(pipeline.working_dir) - # we have updated schema - assert pipeline.default_schema.ENGINE_VERSION == 9 - # make sure that schema hash retrieved from the destination is exactly the same as the schema hash that was in storage before the schema was wiped - assert 
pipeline.default_schema.stored_version_hash == github_schema["version_hash"] + pipeline = dlt.attach( + GITHUB_PIPELINE_NAME, destination=duckdb(credentials=duckdb_cfg.credentials) + ) + assert_github_pipeline_end_state(pipeline, github_schema, 2) + + +def test_filesystem_pipeline_with_dlt_update(test_storage: FileStorage) -> None: + shutil.copytree("tests/pipeline/cases/github_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) + + # execute in test storage + with set_working_dir(TEST_STORAGE_ROOT): + # store dlt data in test storage (like patch_home_dir) + with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + # create virtual env with (0.4.9) where filesystem started to store state + with Venv.create(tempfile.mkdtemp(), ["dlt==0.4.9"]) as venv: + try: + print(venv.run_script("github_pipeline.py", "filesystem", "20")) + except CalledProcessError as cpe: + print(f"script stdout: {cpe.stdout}") + print(f"script stderr: {cpe.stderr}") + raise + # load all issues + venv = Venv.restore_current() + try: + print(venv.run_script("github_pipeline.py", "filesystem")) + except CalledProcessError as cpe: + print(f"script stdout: {cpe.stdout}") + print(f"script stderr: {cpe.stderr}") + raise + # hash hash in schema + github_schema = json.loads( + test_storage.load( + f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json" + ) + ) + # attach to existing pipeline + pipeline = dlt.attach(GITHUB_PIPELINE_NAME, destination=filesystem("_storage/data")) + # assert end state + assert_github_pipeline_end_state(pipeline, github_schema, 2) + # load new state + fs_client = pipeline._fs_client() + state_files = sorted(fs_client.list_table_files("_dlt_pipeline_state")) + # first file is in old format + state_1 = json.loads(fs_client.read_text(state_files[0], encoding="utf-8")) + assert "dlt_load_id" in state_1 + # seconds is new + state_2 = json.loads(fs_client.read_text(state_files[1], encoding="utf-8")) + assert "_dlt_load_id" in state_2 + + +def 
assert_github_pipeline_end_state( + pipeline: dlt.Pipeline, orig_schema: TStoredSchema, schema_updates: int +) -> None: + # get tables counts + table_counts = load_table_counts(pipeline, *pipeline.default_schema.data_table_names()) + assert table_counts == {"issues": 100, "issues__assignees": 31, "issues__labels": 34} + dlt_counts = load_table_counts(pipeline, *pipeline.default_schema.dlt_table_names()) + assert dlt_counts == {"_dlt_version": schema_updates, "_dlt_loads": 2, "_dlt_pipeline_state": 2} + + # check state + created_at_value = pipeline.state["sources"]["github"]["resources"]["load_issues"][ + "incremental" + ]["created_at"]["last_value"] + assert isinstance(created_at_value, pendulum.DateTime) + assert created_at_value == pendulum.parse("2023-02-17T09:52:12Z") + pipeline = pipeline.drop() + # print(pipeline.working_dir) + assert pipeline.dataset_name == GITHUB_DATASET + assert pipeline.default_schema_name is None + # sync from destination + pipeline.sync_destination() + # print(pipeline.working_dir) + # we have updated schema + assert pipeline.default_schema.ENGINE_VERSION == 9 + # make sure that schema hash retrieved from the destination is exactly the same as the schema hash that was in storage before the schema was wiped + assert pipeline.default_schema.stored_version_hash == orig_schema["version_hash"] def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: @@ -182,7 +263,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: ) print( venv.run_script( - "../tests/pipeline/cases/github_pipeline/github_normalize.py", + "../tests/pipeline/cases/github_pipeline/github_normalize.py" ) ) # switch to current version and make sure the load package loads and schema migrates @@ -192,7 +273,9 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb"), ) - with DuckDbSqlClient(GITHUB_DATASET, 
duckdb_cfg.credentials) as client: + with DuckDbSqlClient( + GITHUB_DATASET, duckdb_cfg.credentials, duckdb().capabilities() + ) as client: rows = client.execute_sql("SELECT * FROM issues") assert len(rows) == 70 github_schema = json.loads( @@ -201,7 +284,9 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: ) ) # attach to existing pipeline - pipeline = dlt.attach(GITHUB_PIPELINE_NAME, credentials=duckdb_cfg.credentials) + pipeline = dlt.attach( + GITHUB_PIPELINE_NAME, destination=duckdb(credentials=duckdb_cfg.credentials) + ) # get the schema from schema storage before we sync github_schema = json.loads( test_storage.load( @@ -217,7 +302,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: assert pipeline.state["_version_hash"] is not None # but in db there's no hash - we loaded an old package with backward compatible schema with pipeline.sql_client() as client: - rows = client.execute_sql(f"SELECT * FROM {STATE_TABLE_NAME}") + rows = client.execute_sql(f"SELECT * FROM {PIPELINE_STATE_TABLE_NAME}") # no hash assert len(rows[0]) == 5 + 2 assert len(rows) == 1 @@ -227,7 +312,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: # this will sync schema to destination pipeline.sync_schema() # we have hash now - rows = client.execute_sql(f"SELECT * FROM {STATE_TABLE_NAME}") + rows = client.execute_sql(f"SELECT * FROM {PIPELINE_STATE_TABLE_NAME}") assert len(rows[0]) == 6 + 2 diff --git a/tests/pipeline/test_import_export_schema.py b/tests/pipeline/test_import_export_schema.py index 6f40e1d1eb..eb36d36ba3 100644 --- a/tests/pipeline/test_import_export_schema.py +++ b/tests/pipeline/test_import_export_schema.py @@ -117,7 +117,7 @@ def test_import_schema_is_respected() -> None: destination=dummy(completed_prob=1), import_schema_path=IMPORT_SCHEMA_PATH, export_schema_path=EXPORT_SCHEMA_PATH, - full_refresh=True, + dev_mode=True, ) p.extract(EXAMPLE_DATA, table_name="person") # starts with import 
schema v 1 that is dirty -> 2 @@ -153,7 +153,7 @@ def resource(): destination=dummy(completed_prob=1), import_schema_path=IMPORT_SCHEMA_PATH, export_schema_path=EXPORT_SCHEMA_PATH, - full_refresh=True, + dev_mode=True, ) p.run(source()) diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index f838f31333..95b97c7666 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -7,7 +7,7 @@ import random import threading from time import sleep -from typing import Any, Tuple, cast +from typing import Any, List, Tuple, cast from tenacity import retry_if_exception, Retrying, stop_after_attempt import pytest @@ -19,6 +19,7 @@ from dlt.common.configuration.specs.aws_credentials import AwsCredentials from dlt.common.configuration.specs.exceptions import NativeValueError from dlt.common.configuration.specs.gcp_credentials import GcpOAuthCredentials +from dlt.common.data_writers.exceptions import FileImportNotFound, SpecLookupFailed from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import WithStateSync from dlt.common.destination.exceptions import ( @@ -32,6 +33,8 @@ from dlt.common.exceptions import PipelineStateNotAvailable from dlt.common.pipeline import LoadInfo, PipelineContext from dlt.common.runtime.collector import LogCollector +from dlt.common.schema.exceptions import TableIdentifiersFrozen +from dlt.common.schema.typing import TColumnSchema from dlt.common.schema.utils import new_column, new_table from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id @@ -44,9 +47,11 @@ from dlt.extract import DltResource, DltSource from dlt.extract.extractors import MaterializedEmptyList from dlt.load.exceptions import LoadClientJobFailed +from dlt.normalize.exceptions import NormalizeJobFailed from dlt.pipeline.exceptions import InvalidPipelineName, PipelineNotActive, PipelineStepFailed from dlt.pipeline.helpers import retry_load +from 
dlt.pipeline.pipeline import Pipeline from tests.common.utils import TEST_SENTRY_DSN from tests.common.configuration.utils import environment from tests.utils import TEST_STORAGE_ROOT, skipifnotwindows @@ -55,7 +60,9 @@ assert_data_table_counts, assert_load_info, airtable_emojis, + assert_only_table_columns, load_data_table_counts, + load_tables_to_dicts, many_delayed, ) @@ -201,7 +208,8 @@ def test_pipeline_context() -> None: assert ctx.pipeline() is p3 assert p3.is_active is True assert p2.is_active is False - assert Container()[DestinationCapabilitiesContext].naming_convention == "snake_case" + # no default naming convention + assert Container()[DestinationCapabilitiesContext].naming_convention is None # restore previous p2 = dlt.attach("another pipeline") @@ -1539,10 +1547,13 @@ def autodetect(): pipeline = pipeline.drop() source = autodetect() + assert "timestamp" in source.schema.settings["detections"] source.schema.remove_type_detection("timestamp") + assert "timestamp" not in source.schema.settings["detections"] pipeline = dlt.pipeline(destination="duckdb") pipeline.run(source) + assert "timestamp" not in pipeline.default_schema.settings["detections"] assert pipeline.default_schema.get_table("numbers")["columns"]["value"]["data_type"] == "bigint" @@ -1969,7 +1980,7 @@ def source(): assert len(load_info.loads_ids) == 1 -def test_pipeline_load_info_metrics_schema_is_not_chaning() -> None: +def test_pipeline_load_info_metrics_schema_is_not_changing() -> None: """Test if load info schema is idempotent throughout multiple load cycles ## Setup @@ -2025,7 +2036,6 @@ def demand_map(): pipeline_name="quick_start", destination="duckdb", dataset_name="mydata", - # export_schema_path="schemas", ) taxi_load_info = pipeline.run( @@ -2243,7 +2253,7 @@ def test_data(): pipeline = dlt.pipeline( pipeline_name="test_staging_cleared", destination="duckdb", - full_refresh=True, + dev_mode=True, ) info = pipeline.run(test_data, table_name="staging_cleared") @@ -2260,3 +2270,198 
@@ def test_data(): with client.execute_query(f"SELECT * FROM {pipeline.dataset_name}.staging_cleared") as cur: assert len(cur.fetchall()) == 3 + + +def test_change_naming_convention_name_collision() -> None: + duck_ = dlt.destinations.duckdb(naming_convention="duck_case", recommended_file_size=120000) + caps = duck_.capabilities() + assert caps.naming_convention == "duck_case" + assert caps.recommended_file_size == 120000 + + # use duck case to load data into duckdb so casing and emoji are preserved + pipeline = dlt.pipeline("test_change_naming_convention_name_collision", destination=duck_) + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + assert_load_info(info) + # make sure that emojis got in + assert "🦚Peacock" in pipeline.default_schema.tables + assert "🔑id" in pipeline.default_schema.tables["🦚Peacock"]["columns"] + assert load_data_table_counts(pipeline) == { + "📆 Schedule": 3, + "🦚Peacock": 1, + "🦚WidePeacock": 1, + "🦚Peacock__peacock": 3, + "🦚WidePeacock__Peacock": 3, + } + with pipeline.sql_client() as client: + rows = client.execute_sql("SELECT 🔑id FROM 🦚Peacock") + # 🔑id value is 1 + assert rows[0][0] == 1 + + # change naming convention and run pipeline again so we generate name clashes + os.environ["SOURCES__AIRTABLE_EMOJIS__SCHEMA__NAMING"] = "sql_ci_v1" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run(airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock")) + assert isinstance(pip_ex.value.__cause__, TableIdentifiersFrozen) + + # all good if we drop tables + # info = pipeline.run( + # airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), + # refresh="drop_resources", + # ) + # assert_load_info(info) + # assert load_data_table_counts(pipeline) == { + # "📆 Schedule": 3, + # "🦚Peacock": 1, + # "🦚WidePeacock": 1, + # "🦚Peacock__peacock": 3, + # "🦚WidePeacock__Peacock": 3, + # } + + +def test_change_naming_convention_column_collision() -> 
None: + duck_ = dlt.destinations.duckdb(naming_convention="duck_case") + + data = {"Col": "A"} + pipeline = dlt.pipeline("test_change_naming_convention_column_collision", destination=duck_) + info = pipeline.run([data], table_name="data") + assert_load_info(info) + + os.environ["SCHEMA__NAMING"] = "sql_ci_v1" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run([data], table_name="data") + assert isinstance(pip_ex.value.__cause__, TableIdentifiersFrozen) + + +def test_import_jsonl_file() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + columns: List[TColumnSchema] = [ + {"name": "id", "data_type": "bigint", "nullable": False}, + {"name": "name", "data_type": "text"}, + {"name": "description", "data_type": "text"}, + {"name": "ordered_at", "data_type": "date"}, + {"name": "price", "data_type": "decimal"}, + ] + import_file = "tests/load/cases/loading/header.jsonl" + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2)], + table_name="no_header", + loader_file_format="jsonl", + columns=columns, + ) + info.raise_on_failed_jobs() + print(info) + assert_imported_file(pipeline, "no_header", columns, 2) + + # use hints to infer + hints = dlt.mark.make_hints(columns=columns) + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2, hints=hints)], + table_name="no_header_2", + ) + info.raise_on_failed_jobs() + assert_imported_file(pipeline, "no_header_2", columns, 2, expects_state=False) + + +def test_import_file_without_sniff_schema() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + import_file = "tests/load/cases/loading/header.jsonl" + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2)], + table_name="no_header", + ) + assert info.has_failed_jobs + print(info) + + +def test_import_non_existing_file() -> None: + pipeline = dlt.pipeline( + 
pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + # this file does not exist + import_file = "tests/load/cases/loading/X_header.jsonl" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2)], + table_name="no_header", + ) + inner_ex = pip_ex.value.__cause__ + assert isinstance(inner_ex, FileImportNotFound) + assert inner_ex.import_file_path == import_file + + +def test_import_unsupported_file_format() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + # the file is imported as csv, which is expected to fail as unsupported + import_file = "tests/load/cases/loading/csv_no_header.csv" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run( + [dlt.mark.with_file_import(import_file, "csv", 2)], + table_name="no_header", + ) + inner_ex = pip_ex.value.__cause__ + assert isinstance(inner_ex, NormalizeJobFailed) + assert isinstance(inner_ex.__cause__, SpecLookupFailed) + + +def test_import_unknown_file_format() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + # "unknown" is not a recognized file format + import_file = "tests/load/cases/loading/csv_no_header.csv" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run( + [dlt.mark.with_file_import(import_file, "unknown", 2)], # type: ignore[arg-type] + table_name="no_header", + ) + inner_ex = pip_ex.value.__cause__ + assert isinstance(inner_ex, NormalizeJobFailed) + # can't figure format from extension + assert isinstance(inner_ex.__cause__, ValueError) + + +def assert_imported_file( + pipeline: Pipeline, + table_name: str, + columns: List[TColumnSchema], + expected_rows: int, + expects_state: bool = True, +) -> None: + assert_only_table_columns(pipeline, table_name, [col["name"] for col in columns]) + rows = load_tables_to_dicts(pipeline, table_name) + assert len(rows[table_name]) == expected_rows + # we should have two files 
loaded + jobs = pipeline.last_trace.last_load_info.load_packages[0].jobs["completed_jobs"] + job_extensions = [os.path.splitext(job.job_file_info.file_name())[1] for job in jobs] + assert ".jsonl" in job_extensions + if expects_state: + assert ".insert_values" in job_extensions + # check extract trace if jsonl is really there + extract_info = pipeline.last_trace.last_extract_info + jobs = extract_info.load_packages[0].jobs["new_jobs"] + # find jsonl job + jsonl_job = next(job for job in jobs if job.job_file_info.table_name == table_name) + assert jsonl_job.job_file_info.file_format == "jsonl" + # find metrics for table + assert ( + extract_info.metrics[extract_info.loads_ids[0]][0]["table_metrics"][table_name].items_count + == expected_rows + ) diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index 7208216c9f..308cdcd91d 100644 --- a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -40,7 +40,11 @@ class BaseModel: # type: ignore[no-redef] @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs( + default_sql_configs=True, default_vector_configs=True, local_filesystem_configs=True + ), + ids=lambda x: x.name, ) def test_create_pipeline_all_destinations(destination_config: DestinationTestConfiguration) -> None: # create pipelines, extract and normalize. 
that should be possible without installing any dependencies @@ -51,11 +55,11 @@ def test_create_pipeline_all_destinations(destination_config: DestinationTestCon ) # are capabilities injected caps = p._container[DestinationCapabilitiesContext] - print(caps.naming_convention) - # are right naming conventions created - assert p._default_naming.max_length == min( - caps.max_column_identifier_length, caps.max_identifier_length - ) + if caps.naming_convention: + assert p.naming.name() == caps.naming_convention + else: + assert p.naming.name() == "snake_case" + p.extract([1, "2", 3], table_name="data") # is default schema with right naming convention assert p.default_schema.naming.max_length == min( @@ -469,6 +473,61 @@ def users(): assert set(table.schema.names) == {"id", "name", "_dlt_load_id", "_dlt_id"} +def test_resource_file_format() -> None: + os.environ["RESTORE_FROM_DESTINATION"] = "False" + + def jsonl_data(): + yield [ + { + "id": 1, + "name": "item", + "description": "value", + "ordered_at": "2024-04-12", + "price": 128.4, + }, + { + "id": 1, + "name": "item", + "description": "value with space", + "ordered_at": "2024-04-12", + "price": 128.4, + }, + ] + + # preferred file format will use destination preferred format + jsonl_preferred = dlt.resource(jsonl_data, file_format="preferred", name="jsonl_preferred") + assert jsonl_preferred.compute_table_schema()["file_format"] == "preferred" + + jsonl_r = dlt.resource(jsonl_data, file_format="jsonl", name="jsonl_r") + assert jsonl_r.compute_table_schema()["file_format"] == "jsonl" + + jsonl_pq = dlt.resource(jsonl_data, file_format="parquet", name="jsonl_pq") + assert jsonl_pq.compute_table_schema()["file_format"] == "parquet" + + info = dlt.pipeline("example", destination="duckdb").run([jsonl_preferred, jsonl_r, jsonl_pq]) + info.raise_on_failed_jobs() + # check file types on load jobs + load_jobs = { + job.job_file_info.table_name: job.job_file_info + for job in info.load_packages[0].jobs["completed_jobs"] + } + 
assert load_jobs["jsonl_r"].file_format == "jsonl" + assert load_jobs["jsonl_pq"].file_format == "parquet" + assert load_jobs["jsonl_preferred"].file_format == "insert_values" + + # test not supported format + csv_r = dlt.resource(jsonl_data, file_format="csv", name="csv_r") + assert csv_r.compute_table_schema()["file_format"] == "csv" + info = dlt.pipeline("example", destination="duckdb").run(csv_r) + info.raise_on_failed_jobs() + # fallback to preferred + load_jobs = { + job.job_file_info.table_name: job.job_file_info + for job in info.load_packages[0].jobs["completed_jobs"] + } + assert load_jobs["csv_r"].file_format == "insert_values" + + def test_pick_matching_file_format(test_storage: FileStorage) -> None: from dlt.destinations import filesystem diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index 8cbc1ca516..11c45d72cc 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -1,20 +1,25 @@ import os import shutil +from typing_extensions import get_type_hints import pytest import dlt - +from dlt.common.pendulum import pendulum from dlt.common.exceptions import ( PipelineStateNotAvailable, ResourceNameNotAvailable, ) from dlt.common.schema import Schema +from dlt.common.schema.utils import pipeline_state_table from dlt.common.source import get_current_pipe_name from dlt.common.storages import FileStorage from dlt.common import pipeline as state_module +from dlt.common.storages.load_package import TPipelineStateDoc from dlt.common.utils import uniq_id -from dlt.common.destination.reference import Destination +from dlt.common.destination.reference import Destination, StateInfo +from dlt.common.validation import validate_dict +from dlt.destinations.utils import get_pipeline_state_query_columns from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException, PipelineStepFailed from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.state_sync import ( @@ -41,6 
+46,56 @@ def some_data_resource_state(): dlt.current.resource_state()["last_value"] = last_value + 1 +def test_state_repr() -> None: + """Verify that all possible state representations match""" + table = pipeline_state_table() + state_doc_hints = get_type_hints(TPipelineStateDoc) + sync_class_hints = get_type_hints(StateInfo) + info = StateInfo(1, 4, "pipeline", "compressed", pendulum.now(), "hash", "_load_id") + state_doc = info.as_doc() + # just in case hardcode column order + reference_cols = [ + "version", + "engine_version", + "pipeline_name", + "state", + "created_at", + "version_hash", + "_dlt_load_id", + ] + # doc and table must be in the same order with the same name + assert ( + len(table["columns"]) + == len(state_doc_hints) + == len(sync_class_hints) + == len(state_doc) + == len(reference_cols) + ) + for col, hint, class_hint, val, ref_col in zip( + table["columns"].values(), state_doc_hints, sync_class_hints, state_doc, reference_cols + ): + assert col["name"] == hint == class_hint == val == ref_col + + # validate info + validate_dict(TPipelineStateDoc, state_doc, "$") + + info = StateInfo(1, 4, "pipeline", "compressed", pendulum.now()) + state_doc = info.as_doc() + assert "_dlt_load_id" not in state_doc + assert "version_hash" not in state_doc + + # we drop hash in query + compat_table = get_pipeline_state_query_columns() + assert list(compat_table["columns"].keys()) == [ + "version", + "engine_version", + "pipeline_name", + "state", + "created_at", + "_dlt_load_id", + ] + + def test_restore_state_props() -> None: p = dlt.pipeline( pipeline_name="restore_state_props", diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index 7affcc5a81..c10618a7cc 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -52,7 +52,7 @@ def peacock(): @dlt.resource(name="🦚WidePeacock", selected=False) def wide_peacock(): - yield [{"peacock": [1, 2, 3]}] + yield [{"Peacock": [1, 2, 3]}] return budget, schedule, peacock, wide_peacock @@ -198,7 
+198,7 @@ def _load_tables_to_dicts_sql( for table_name in table_names: table_rows = [] columns = schema.get_table_columns(table_name).keys() - query_columns = ",".join(map(p.sql_client().capabilities.escape_identifier, columns)) + query_columns = ",".join(map(p.sql_client().escape_column_name, columns)) with p.sql_client() as c: query_columns = ",".join(map(c.escape_column_name, columns)) diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 7196ef3436..aa3f02e51d 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -234,7 +234,6 @@ def test_oauth2_client_credentials_flow_wrong_client_secret(self, rest_client: R assert e.type == HTTPError assert e.match("401 Client Error") - def test_oauth_token_expired_refresh(self, rest_client_immediate_oauth_expiry: RESTClient): rest_client = rest_client_immediate_oauth_expiry auth = cast(OAuth2ClientCredentials, rest_client.auth) diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index 9ca54e814c..e5d31c52d2 100644 --- a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -3,6 +3,7 @@ import pytest from requests.models import Response, Request +from requests import Session from dlt.sources.helpers.rest_client.paginators import ( SinglePagePaginator, @@ -157,6 +158,30 @@ def test_update_request(self, test_case): paginator.update_request(request) assert request.url == test_case["expected"] + def test_no_duplicate_params_on_update_request(self): + paginator = JSONResponsePaginator() + + request = Request( + method="GET", + url="http://example.com/api/resource", + params={"param1": "value1"}, + ) + + session = Session() + + response = Mock(Response, json=lambda: {"next": "/api/resource?page=2&param1=value1"}) + paginator.update_state(response) + 
paginator.update_request(request) + + assert request.url == "http://example.com/api/resource?page=2&param1=value1" + + # RESTClient._send_request() calls Session.prepare_request() which + # updates the URL with the query parameters from the request object. + prepared_request = session.prepare_request(request) + + # The next request should just use the "next" URL without any duplicate parameters. + assert prepared_request.url == "http://example.com/api/resource?page=2&param1=value1" + class TestSinglePagePaginator: def test_update_state(self): diff --git a/tests/utils.py b/tests/utils.py index 580c040706..bf3aafdb77 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,13 +45,22 @@ "motherduck", "mssql", "qdrant", + "lancedb", "destination", "synapse", "databricks", "clickhouse", "dremio", } -NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant", "destination"} +NON_SQL_DESTINATIONS = { + "filesystem", + "weaviate", + "dummy", + "motherduck", + "qdrant", + "lancedb", + "destination", +} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS # exclude destination configs (for now used for athena and athena iceberg separation) @@ -173,7 +182,7 @@ def unload_modules() -> Iterator[None]: @pytest.fixture(autouse=True) -def wipe_pipeline() -> Iterator[None]: +def wipe_pipeline(preserve_environ) -> Iterator[None]: """Wipes pipeline local state and deactivates it""" container = Container() if container[PipelineContext].is_active():