Skip to content

Commit

Permalink
Merge branch 'devel' into d#/data_contracts
Browse files Browse the repository at this point in the history
# Conflicts:
#	tests/common/schema/test_filtering.py
#	tests/common/schema/test_versioning.py
#	tests/common/test_typing.py
#	tests/load/test_job_client.py
#	tests/load/utils.py
  • Loading branch information
sh-rp committed Oct 2, 2023
2 parents b72a1a9 + c361034 commit f2abadf
Show file tree
Hide file tree
Showing 127 changed files with 4,837 additions and 4,532 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ dev: has-poetry

lint:
./check-package.sh
poetry run mypy --config-file mypy.ini dlt
poetry run mypy --config-file mypy.ini dlt tests
poetry run flake8 --max-line-length=200 dlt
poetry run flake8 --max-line-length=200 tests --exclude tests/reflection/module_cases
# $(MAKE) lint-security
Expand Down
13 changes: 8 additions & 5 deletions dlt/common/data_writers/buffered.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import gzip
from typing import List, IO, Any, Optional, Type
from typing import List, IO, Any, Optional, Type, TypeVar, Generic

from dlt.common.utils import uniq_id
from dlt.common.typing import TDataItem, TDataItems
Expand All @@ -12,7 +12,10 @@
from dlt.common.destination import DestinationCapabilitiesContext


class BufferedDataWriter:
TWriter = TypeVar("TWriter", bound=DataWriter)


class BufferedDataWriter(Generic[TWriter]):

@configspec
class BufferedDataWriterConfiguration(BaseConfiguration):
Expand Down Expand Up @@ -55,7 +58,7 @@ def __init__(
self._current_columns: TTableSchemaColumns = None
self._file_name: str = None
self._buffered_items: List[TDataItem] = []
self._writer: DataWriter = None
self._writer: TWriter = None
self._file: IO[Any] = None
self._closed = False
try:
Expand Down Expand Up @@ -104,7 +107,7 @@ def close(self) -> None:
def closed(self) -> bool:
return self._closed

def __enter__(self) -> "BufferedDataWriter":
def __enter__(self) -> "BufferedDataWriter[TWriter]":
return self

def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any) -> None:
Expand All @@ -123,7 +126,7 @@ def _flush_items(self, allow_empty_file: bool = False) -> None:
self._file = self.open(self._file_name, "wb") # type: ignore
else:
self._file = self.open(self._file_name, "wt", encoding="utf-8") # type: ignore
self._writer = DataWriter.from_file_format(self.file_format, self._file, caps=self._caps)
self._writer = DataWriter.from_file_format(self.file_format, self._file, caps=self._caps) # type: ignore[assignment]
self._writer.write_header(self._current_columns)
# write buffer
if self._buffered_items:
Expand Down
8 changes: 4 additions & 4 deletions dlt/common/json/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def dumps(self, obj: Any, sort_keys: bool = False, pretty:bool = False) -> str:
def dumpb(self, obj: Any, sort_keys: bool = False, pretty:bool = False) -> bytes:
...

def load(self, fp: IO[bytes]) -> Any:
def load(self, fp: Union[IO[bytes], IO[str]]) -> Any:
...

def loads(self, s: str) -> Any:
Expand Down Expand Up @@ -185,11 +185,11 @@ def custom_pua_remove(obj: Any) -> Any:
json: SupportsJson = None
if os.environ.get("DLT_USE_JSON") == "simplejson":
from dlt.common.json import _simplejson as _json_d
json = _json_d
json = _json_d # type: ignore[assignment]
else:
try:
from dlt.common.json import _orjson as _json_or
json = _json_or
json = _json_or # type: ignore[assignment]
except ImportError:
from dlt.common.json import _simplejson as _json_simple
json = _json_simple
json = _json_simple # type: ignore[assignment]
16 changes: 7 additions & 9 deletions dlt/common/normalizers/naming/duck_case.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
import re
from functools import lru_cache

from dlt.common.normalizers.naming.snake_case import NamingConvention as BaseNamingConvention
from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention


class NamingConvention(BaseNamingConvention):
class NamingConvention(SnakeCaseNamingConvention):

_RE_NON_ALPHANUMERIC = re.compile(r"[^a-zA-Z\d_+-]+")
_REDUCE_ALPHABET = ("*@|", "xal")
_TR_REDUCE_ALPHABET = str.maketrans(_REDUCE_ALPHABET[0], _REDUCE_ALPHABET[1])
_CLEANUP_TABLE = str.maketrans("\n\r\"", "___")
_RE_LEADING_DIGITS = None # do not remove leading digits

@staticmethod
@lru_cache(maxsize=None)
def _normalize_identifier(identifier: str, max_length: int) -> str:
"""Normalizes the identifier according to naming convention represented by this function"""
# all characters that are not letters digits or a few special chars are replaced with underscore
normalized_ident = identifier.translate(NamingConvention._TR_REDUCE_ALPHABET)
normalized_ident = NamingConvention._RE_NON_ALPHANUMERIC.sub("_", normalized_ident)

normalized_ident = identifier.translate(NamingConvention._CLEANUP_TABLE)

# shorten identifier
return NamingConvention.shorten_identifier(
NamingConvention._to_snake_case(normalized_ident),
NamingConvention._RE_UNDERSCORES.sub("_", normalized_ident),
identifier,
max_length
)
16 changes: 8 additions & 8 deletions dlt/common/normalizers/naming/snake_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,21 @@ def _normalize_identifier(identifier: str, max_length: int) -> str:
max_length
)

@staticmethod
def _to_snake_case(identifier: str) -> str:
@classmethod
def _to_snake_case(cls, identifier: str) -> str:
# then convert to snake case
identifier = NamingConvention._SNAKE_CASE_BREAK_1.sub(r'\1_\2', identifier)
identifier = NamingConvention._SNAKE_CASE_BREAK_2.sub(r'\1_\2', identifier).lower()
identifier = cls._SNAKE_CASE_BREAK_1.sub(r'\1_\2', identifier)
identifier = cls._SNAKE_CASE_BREAK_2.sub(r'\1_\2', identifier).lower()

# leading digits will be prefixed
if NamingConvention._RE_LEADING_DIGITS.match(identifier):
# leading digits will be prefixed (if regex is defined)
if cls._RE_LEADING_DIGITS and cls._RE_LEADING_DIGITS.match(identifier):
identifier = "_" + identifier

# replace trailing _ with x
stripped_ident = identifier.rstrip("_")
strip_count = len(identifier) - len(stripped_ident)
stripped_ident += "x" * strip_count

# identifier = NamingConvention._RE_ENDING_UNDERSCORES.sub("x", identifier)
# identifier = cls._RE_ENDING_UNDERSCORES.sub("x", identifier)
# replace consecutive underscores with single one to prevent name clashes with PATH_SEPARATOR
return NamingConvention._RE_UNDERSCORES.sub("_", stripped_ident)
return cls._RE_UNDERSCORES.sub("_", stripped_ident)
5 changes: 4 additions & 1 deletion dlt/common/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import humanize
import contextlib
from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Protocol, Sequence, TYPE_CHECKING, Tuple, TypedDict
from typing_extensions import NotRequired

from dlt.common import pendulum, logger
from dlt.common.configuration import configspec
Expand Down Expand Up @@ -163,9 +164,11 @@ class TPipelineState(TypedDict, total=False):
_local: TPipelineLocalState
"""A section of state that is not synchronized with the destination and does not participate in change merging and version control"""

sources: NotRequired[Dict[str, Dict[str, Any]]]


class TSourceState(TPipelineState):
sources: Dict[str, Dict[str, Any]]
sources: Dict[str, Dict[str, Any]] # type: ignore[misc]


class SupportsPipeline(Protocol):
Expand Down
2 changes: 1 addition & 1 deletion dlt/common/runtime/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(self, single_bar: bool = False, **tqdm_kwargs: Any) -> None:
except ModuleNotFoundError:
raise MissingDependencyException("TqdmCollector", ["tqdm"], "We need tqdm to display progress bars.")
self.single_bar = single_bar
self._bars: Dict[str, tqdm] = {}
self._bars: Dict[str, tqdm[None]] = {}
self.tqdm_kwargs = tqdm_kwargs or {}

def update(self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "") -> None:
Expand Down
4 changes: 2 additions & 2 deletions dlt/common/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import hashlib

from copy import deepcopy, copy
from typing import Dict, List, Sequence, Tuple, Type, Any, cast, Iterable, Optional
from typing import Dict, List, Sequence, Tuple, Type, Any, cast, Iterable, Optional, Union

from dlt.common import json
from dlt.common.data_types import TDataType
Expand Down Expand Up @@ -484,7 +484,7 @@ def hint_to_column_prop(h: TColumnHint) -> TColumnProp:
return h


def get_columns_names_with_prop(table: TTableSchema, column_prop: TColumnProp, include_incomplete: bool = False) -> List[str]:
def get_columns_names_with_prop(table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False) -> List[str]:
# column_prop: TColumnProp = hint_to_column_prop(hint_type)
# default = column_prop != "nullable" # default is true, only for nullable false
return [c["name"] for c in table["columns"].values() if bool(c.get(column_prop, False)) is True and (include_incomplete or is_complete_column(c))]
Expand Down
8 changes: 4 additions & 4 deletions dlt/common/storages/data_item_storage.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from typing import Dict, Any, List
from typing import Dict, Any, List, Generic
from abc import ABC, abstractmethod

from dlt.common import logger
from dlt.common.schema import TTableSchemaColumns
from dlt.common.typing import TDataItems
from dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter
from dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter, DataWriter


class DataItemStorage(ABC):
def __init__(self, load_file_type: TLoaderFileFormat, *args: Any) -> None:
self.loader_file_format = load_file_type
self.buffered_writers: Dict[str, BufferedDataWriter] = {}
self.buffered_writers: Dict[str, BufferedDataWriter[DataWriter]] = {}
super().__init__(*args)

def get_writer(self, load_id: str, schema_name: str, table_name: str) -> BufferedDataWriter:
def get_writer(self, load_id: str, schema_name: str, table_name: str) -> BufferedDataWriter[DataWriter]:
# unique writer id
writer_id = f"{load_id}.{schema_name}.{table_name}"
writer = self.buffered_writers.get(writer_id, None)
Expand Down
2 changes: 1 addition & 1 deletion dlt/common/storages/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile
import shutil
import pathvalidate
from typing import IO, Any, Optional, List, cast
from typing import IO, Any, Optional, List, cast, overload
from dlt.common.typing import AnyFun

from dlt.common.utils import encoding_for_mode, uniq_id
Expand Down
1 change: 1 addition & 0 deletions dlt/common/storages/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class FileItem(TypedDict):
file_name: str
mime_type: str
modification_date: pendulum.DateTime
size_in_bytes: int
file_content: Optional[Union[str, bytes]]


Expand Down
6 changes: 5 additions & 1 deletion dlt/common/storages/versioned_storage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union

import semver

from dlt.common.storages.file_storage import FileStorage
Expand All @@ -8,7 +10,9 @@ class VersionedStorage:

VERSION_FILE = ".version"

def __init__(self, version: semver.VersionInfo, is_owner: bool, storage: FileStorage) -> None:
def __init__(self, version: Union[semver.VersionInfo, str], is_owner: bool, storage: FileStorage) -> None:
if isinstance(version, str):
version = semver.VersionInfo.parse(version)
self.storage = storage
# read current version
if self.storage.has_file(VersionedStorage.VERSION_FILE):
Expand Down
2 changes: 1 addition & 1 deletion dlt/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,4 @@ def get_generic_type_argument_from_instance(instance: Any, sample_value: Optiona
orig_param_type = get_args(instance.__orig_class__)[0]
if orig_param_type is Any and sample_value is not None:
orig_param_type = type(sample_value)
return orig_param_type # type: ignore
return orig_param_type # type: ignore
2 changes: 1 addition & 1 deletion dlt/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def str2bool(v: str) -> bool:
# return o


def flatten_list_of_str_or_dicts(seq: Sequence[Union[StrAny, str]]) -> StrAny:
def flatten_list_of_str_or_dicts(seq: Sequence[Union[StrAny, str]]) -> DictStrAny:
"""
Transforms a list of objects or strings [{K: {...}}, L, ...] -> {K: {...}, L: None, ...}
"""
Expand Down
1 change: 0 additions & 1 deletion dlt/destinations/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def capabilities() -> DestinationCapabilitiesContext:
caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0)
caps.max_identifier_length = 65536
caps.max_column_identifier_length = 65536
caps.naming_convention = "duck_case"
caps.max_query_length = 32 * 1024 * 1024
caps.is_max_query_length_in_bytes = True
caps.max_text_data_type_length = 1024 * 1024 * 1024
Expand Down
21 changes: 19 additions & 2 deletions dlt/destinations/duckdb/configuration.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import threading
from pathvalidate import is_valid_filepath
from typing import Any, ClassVar, Final, List, Optional, Tuple
from typing import Any, ClassVar, Final, List, Optional, Tuple, TYPE_CHECKING, Union

from dlt.common import logger
from dlt.common.configuration import configspec
from dlt.common.configuration.specs import ConnectionStringCredentials
from dlt.common.configuration.specs.exceptions import InvalidConnectionString
from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration
from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration, DestinationClientStagingConfiguration
from dlt.common.typing import TSecretValue

DUCK_DB_NAME = "%s.duckdb"
Expand Down Expand Up @@ -180,3 +180,20 @@ class DuckDbClientConfiguration(DestinationClientDwhWithStagingConfiguration):
credentials: DuckDbCredentials

create_indexes: bool = False # should unique indexes be created, this slows loading down massively

# Typing-only section: typing.TYPE_CHECKING is False at runtime, so code guarded by it
# is seen by static type checkers but not executed on import.
if TYPE_CHECKING:
# Use the real connection type when duckdb can be imported; otherwise alias it to Any
# so the annotation below still resolves.
try:
from duckdb import DuckDBPyConnection
except ModuleNotFoundError:
DuckDBPyConnection = Any # type: ignore[assignment,misc]

# Signature-only stub (the body is `...`): it declares the keyword arguments and the
# accepted credential forms (credentials object, string, or DuckDBPyConnection).
# NOTE(review): presumably still nested under the TYPE_CHECKING guard - indentation is not
# visible in this view; the runtime constructor is assumed to be provided elsewhere - confirm.
def __init__(
self,
destination_name: str = None,
credentials: Union[DuckDbCredentials, str, DuckDBPyConnection] = None,
dataset_name: str = None,
default_schema_name: Optional[str] = None,
create_indexes: bool = False,
staging_config: Optional[DestinationClientStagingConfiguration] = None
) -> None:
...
5 changes: 4 additions & 1 deletion dlt/destinations/duckdb/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ def fully_qualified_dataset_name(self, escape: bool = True) -> str:
@classmethod
def _make_database_exception(cls, ex: Exception) -> Exception:
if isinstance(ex, (duckdb.CatalogException)):
raise DatabaseUndefinedRelation(ex)
if "already exists" in str(ex):
raise DatabaseTerminalException(ex)
else:
raise DatabaseUndefinedRelation(ex)
elif isinstance(ex, duckdb.InvalidInputException):
if "Catalog Error" in str(ex):
raise DatabaseUndefinedRelation(ex)
Expand Down
17 changes: 17 additions & 0 deletions dlt/destinations/dummy/configuration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING, Optional

from dlt.common.configuration import configspec
from dlt.common.destination import TLoaderFileFormat
from dlt.common.destination.reference import DestinationClientConfiguration, CredentialsConfiguration
Expand All @@ -22,3 +24,18 @@ class DummyClientConfiguration(DestinationClientConfiguration):
fail_in_init: bool = True

credentials: DummyClientCredentials = None

# Typing-only section: typing.TYPE_CHECKING is False at runtime, so this constructor
# signature is seen by static type checkers but not executed on import.
if TYPE_CHECKING:
# Signature-only stub (the body is `...`): it lists the keyword arguments a type checker
# should accept when the configuration is constructed explicitly.
# NOTE(review): every parameter defaults to None here; the runtime defaults are assumed to
# come from the class-level field declarations, which are only partly visible - confirm.
def __init__(
self,
destination_name: str = None,
credentials: Optional[CredentialsConfiguration] = None,
loader_file_format: TLoaderFileFormat = None,
fail_schema_update: bool = None,
fail_prob: float = None,
retry_prob: float = None,
completed_prob: float = None,
timeout: float = None,
fail_in_init: bool = None,
) -> None:
...
1 change: 0 additions & 1 deletion dlt/destinations/motherduck/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def capabilities() -> DestinationCapabilitiesContext:
caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0)
caps.max_identifier_length = 65536
caps.max_column_identifier_length = 65536
caps.naming_convention = "duck_case"
caps.max_query_length = 512 * 1024
caps.is_max_query_length_in_bytes = True
caps.max_text_data_type_length = 1024 * 1024 * 1024
Expand Down
13 changes: 12 additions & 1 deletion dlt/destinations/postgres/configuration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Final, ClassVar, Any, List
from typing import Final, ClassVar, Any, List, TYPE_CHECKING
from sqlalchemy.engine import URL

from dlt.common.configuration import configspec
Expand Down Expand Up @@ -46,3 +46,14 @@ def fingerprint(self) -> str:
if self.credentials and self.credentials.host:
return digest128(self.credentials.host)
return ""

# Typing-only section: typing.TYPE_CHECKING is False at runtime, so this constructor
# signature is seen by static type checkers but not executed on import.
if TYPE_CHECKING:
# Signature-only stub (the body is `...`): it lists the keyword arguments a type checker
# should accept when the configuration is constructed explicitly.
# NOTE(review): the runtime constructor is assumed to be provided elsewhere (the enclosing
# class definition is outside this view) - confirm.
def __init__(
self,
destination_name: str = None,
credentials: PostgresCredentials = None,
dataset_name: str = None,
default_schema_name: str = None,
create_indexes: bool = True
) -> None:
...
1 change: 1 addition & 0 deletions dlt/destinations/postgres/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dlt.destinations.postgres.configuration import PostgresCredentials
from dlt.destinations.postgres import capabilities


class Psycopg2SqlClient(SqlClientBase["psycopg2.connection"], DBTransaction):

dbapi: ClassVar[DBApi] = psycopg2
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/weaviate/weaviate_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def make_weaviate_class_schema(self, table_name: str) -> Dict[str, Any]:
}

# check if any column requires vectorization
if get_columns_names_with_prop(self.schema.get_table(table_name), VECTORIZE_HINT): # type: ignore
if get_columns_names_with_prop(self.schema.get_table(table_name), VECTORIZE_HINT):
class_schema.update(self._vectorizer_config)
else:
class_schema.update(NON_VECTORIZED_CLASS)
Expand Down
Loading

0 comments on commit f2abadf

Please sign in to comment.