Merge pull request #661 from dlt-hub/sthor/mypy-in-tests
Typing fixes and enable mypy in tests
rudolfix authored Sep 29, 2023
2 parents 6a9a363 + 7f2b684 commit 4ec9ab8
Showing 110 changed files with 4,377 additions and 4,466 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -46,7 +46,7 @@ dev: has-poetry

lint:
./check-package.sh
-poetry run mypy --config-file mypy.ini dlt
+poetry run mypy --config-file mypy.ini dlt tests
poetry run flake8 --max-line-length=200 dlt
poetry run flake8 --max-line-length=200 tests --exclude tests/reflection/module_cases
# $(MAKE) lint-security
13 changes: 8 additions & 5 deletions dlt/common/data_writers/buffered.py
@@ -1,5 +1,5 @@
import gzip
-from typing import List, IO, Any, Optional, Type
+from typing import List, IO, Any, Optional, Type, TypeVar, Generic

from dlt.common.utils import uniq_id
from dlt.common.typing import TDataItem, TDataItems
@@ -12,7 +12,10 @@
from dlt.common.destination import DestinationCapabilitiesContext


-class BufferedDataWriter:
+TWriter = TypeVar("TWriter", bound=DataWriter)
+
+
+class BufferedDataWriter(Generic[TWriter]):

@configspec
class BufferedDataWriterConfiguration(BaseConfiguration):
@@ -55,7 +58,7 @@ def __init__(
self._current_columns: TTableSchemaColumns = None
self._file_name: str = None
self._buffered_items: List[TDataItem] = []
-self._writer: DataWriter = None
+self._writer: TWriter = None
self._file: IO[Any] = None
self._closed = False
try:
@@ -104,7 +107,7 @@ def close(self) -> None:
def closed(self) -> bool:
return self._closed

def __enter__(self) -> "BufferedDataWriter":
def __enter__(self) -> "BufferedDataWriter[TWriter]":
return self

def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any) -> None:
@@ -123,7 +126,7 @@ def _flush_items(self, allow_empty_file: bool = False) -> None:
self._file = self.open(self._file_name, "wb") # type: ignore
else:
self._file = self.open(self._file_name, "wt", encoding="utf-8") # type: ignore
-self._writer = DataWriter.from_file_format(self.file_format, self._file, caps=self._caps)
+self._writer = DataWriter.from_file_format(self.file_format, self._file, caps=self._caps) # type: ignore[assignment]
self._writer.write_header(self._current_columns)
# write buffer
if self._buffered_items:
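
This change makes BufferedDataWriter generic over the concrete writer it wraps, so call sites can keep the writer's precise type. A minimal sketch of the same pattern (the DataWriter stand-in and JsonlWriter below are illustrative, not dlt's real classes):

from typing import Generic, Optional, TypeVar

class DataWriter:  # illustrative stand-in for dlt's DataWriter base
    def write(self, row: dict) -> None: ...

class JsonlWriter(DataWriter):  # hypothetical concrete writer
    def write(self, row: dict) -> None:
        print(row)

TWriter = TypeVar("TWriter", bound=DataWriter)

class Buffered(Generic[TWriter]):
    def __init__(self) -> None:
        # typed with the TypeVar, so Buffered[JsonlWriter] exposes JsonlWriter
        self._writer: Optional[TWriter] = None

buf: "Buffered[JsonlWriter]" = Buffered()
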
8 changes: 4 additions & 4 deletions dlt/common/json/__init__.py
@@ -48,7 +48,7 @@ def dumps(self, obj: Any, sort_keys: bool = False, pretty:bool = False) -> str:
def dumpb(self, obj: Any, sort_keys: bool = False, pretty:bool = False) -> bytes:
...

-def load(self, fp: IO[bytes]) -> Any:
+def load(self, fp: Union[IO[bytes], IO[str]]) -> Any:
...

def loads(self, s: str) -> Any:
@@ -185,11 +185,11 @@ def custom_pua_remove(obj: Any) -> Any:
json: SupportsJson = None
if os.environ.get("DLT_USE_JSON") == "simplejson":
from dlt.common.json import _simplejson as _json_d
-json = _json_d
+json = _json_d # type: ignore[assignment]
else:
try:
from dlt.common.json import _orjson as _json_or
-json = _json_or
+json = _json_or # type: ignore[assignment]
except ImportError:
from dlt.common.json import _simplejson as _json_simple
-json = _json_simple
+json = _json_simple # type: ignore[assignment]
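The ignores added here exist because a module object is assigned to a variable annotated with the SupportsJson protocol, which mypy does not accept without help. A rough self-contained illustration of the idiom, using the standard library json module instead of dlt's backends:

from typing import IO, Any, Protocol, Union

class SupportsJson(Protocol):
    def loads(self, s: str) -> Any: ...
    def load(self, fp: Union[IO[bytes], IO[str]]) -> Any: ...

import json as _stdlib_json

# structurally compatible, but mypy flags module-to-protocol assignments
json_impl: SupportsJson = _stdlib_json  # type: ignore[assignment]
print(json_impl.loads('{"a": 1}'))  # {'a': 1}
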
5 changes: 4 additions & 1 deletion dlt/common/pipeline.py
@@ -3,6 +3,7 @@
import humanize
import contextlib
from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Protocol, Sequence, TYPE_CHECKING, Tuple, TypedDict
+from typing_extensions import NotRequired

from dlt.common import pendulum, logger
from dlt.common.configuration import configspec
@@ -163,9 +164,11 @@ class TPipelineState(TypedDict, total=False):
_local: TPipelineLocalState
"""A section of state that is not synchronized with the destination and does not participate in change merging and version control"""

+sources: NotRequired[Dict[str, Dict[str, Any]]]


class TSourceState(TPipelineState):
-sources: Dict[str, Dict[str, Any]]
+sources: Dict[str, Dict[str, Any]] # type: ignore[misc]


class SupportsPipeline(Protocol):
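
The NotRequired key plus the type: ignore[misc] on TSourceState lets the base state optionally carry sources while the subclass treats the key as required; mypy normally forbids redeclaring an inherited TypedDict key. A condensed sketch of the pattern:

from typing import Any, Dict
from typing_extensions import NotRequired, TypedDict

class PipelineState(TypedDict, total=False):
    pipeline_name: str
    # optionally present on any pipeline state
    sources: NotRequired[Dict[str, Dict[str, Any]]]

class SourceState(PipelineState):
    # redeclaring an inherited key triggers mypy's "misc" error, hence the ignore
    sources: Dict[str, Dict[str, Any]]  # type: ignore[misc]

state: SourceState = {"pipeline_name": "p", "sources": {"my_source": {}}}
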
2 changes: 1 addition & 1 deletion dlt/common/runtime/collector.py
@@ -220,7 +220,7 @@ def __init__(self, single_bar: bool = False, **tqdm_kwargs: Any) -> None:
except ModuleNotFoundError:
raise MissingDependencyException("TqdmCollector", ["tqdm"], "We need tqdm to display progress bars.")
self.single_bar = single_bar
-self._bars: Dict[str, tqdm] = {}
+self._bars: Dict[str, tqdm[None]] = {}
self.tqdm_kwargs = tqdm_kwargs or {}

def update(self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "") -> None:
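
tqdm's type stubs make the bar generic in the iterable it wraps, so a bar created with only total= is, per those stubs, a tqdm[None]; the annotation change reflects that. A sketch, with the parameter quoted to avoid subscripting tqdm at runtime on older versions:

from typing import Dict
from tqdm import tqdm

bars: Dict[str, "tqdm[None]"] = {}
bars["load"] = tqdm(total=100, desc="load")
bars["load"].update(10)
bars["load"].close()
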
4 changes: 2 additions & 2 deletions dlt/common/schema/utils.py
@@ -3,7 +3,7 @@
import hashlib

from copy import deepcopy, copy
-from typing import Dict, List, Sequence, Tuple, Type, Any, cast, Iterable, Optional
+from typing import Dict, List, Sequence, Tuple, Type, Any, cast, Iterable, Optional, Union

from dlt.common import json
from dlt.common.data_types import TDataType
@@ -476,7 +476,7 @@ def hint_to_column_prop(h: TColumnHint) -> TColumnProp:
return h


-def get_columns_names_with_prop(table: TTableSchema, column_prop: TColumnProp, include_incomplete: bool = False) -> List[str]:
+def get_columns_names_with_prop(table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False) -> List[str]:
# column_prop: TColumnProp = hint_to_column_prop(hint_type)
# default = column_prop != "nullable" # default is true, only for nullable false
return [c["name"] for c in table["columns"].values() if bool(c.get(column_prop, False)) is True and (include_incomplete or is_complete_column(c))]
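
Widening column_prop to Union[TColumnProp, str] lets callers query custom hint keys, such as the x-prefixed hints used by the Weaviate destination further down, without a cast or ignore. A simplified stand-in for the function, with an illustrative hint name:

from typing import Any, Dict, List

def columns_with_prop(table: Dict[str, Any], column_prop: str) -> List[str]:
    # simplified: dlt's version also filters incomplete columns
    return [c["name"] for c in table["columns"].values() if bool(c.get(column_prop, False))]

table = {"columns": {
    "body": {"name": "body", "x-vectorize": True},  # hypothetical custom hint
    "id": {"name": "id", "primary_key": True},
}}
print(columns_with_prop(table, "x-vectorize"))  # ['body']
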
8 changes: 4 additions & 4 deletions dlt/common/storages/data_item_storage.py
@@ -1,19 +1,19 @@
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Generic
from abc import ABC, abstractmethod

from dlt.common import logger
from dlt.common.schema import TTableSchemaColumns
from dlt.common.typing import TDataItems
-from dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter
+from dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter, DataWriter


class DataItemStorage(ABC):
def __init__(self, load_file_type: TLoaderFileFormat, *args: Any) -> None:
self.loader_file_format = load_file_type
-self.buffered_writers: Dict[str, BufferedDataWriter] = {}
+self.buffered_writers: Dict[str, BufferedDataWriter[DataWriter]] = {}
super().__init__(*args)

-def get_writer(self, load_id: str, schema_name: str, table_name: str) -> BufferedDataWriter:
+def get_writer(self, load_id: str, schema_name: str, table_name: str) -> BufferedDataWriter[DataWriter]:
# unique writer id
writer_id = f"{load_id}.{schema_name}.{table_name}"
writer = self.buffered_writers.get(writer_id, None)
2 changes: 1 addition & 1 deletion dlt/common/storages/file_storage.py
@@ -6,7 +6,7 @@
import tempfile
import shutil
import pathvalidate
-from typing import IO, Any, Optional, List, cast
+from typing import IO, Any, Optional, List, cast, overload
from dlt.common.typing import AnyFun

from dlt.common.utils import encoding_for_mode, uniq_id
6 changes: 5 additions & 1 deletion dlt/common/storages/versioned_storage.py
@@ -1,3 +1,5 @@
+from typing import Union
+
import semver

from dlt.common.storages.file_storage import FileStorage
@@ -8,7 +10,9 @@ class VersionedStorage:

VERSION_FILE = ".version"

-def __init__(self, version: semver.VersionInfo, is_owner: bool, storage: FileStorage) -> None:
+def __init__(self, version: Union[semver.VersionInfo, str], is_owner: bool, storage: FileStorage) -> None:
+    if isinstance(version, str):
+        version = semver.VersionInfo.parse(version)
self.storage = storage
# read current version
if self.storage.has_file(VersionedStorage.VERSION_FILE):
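
Accepting either a VersionInfo or a raw string and normalizing up front saves every caller from parsing. A sketch of just that constructor logic:

from typing import Union
import semver

class Versioned:
    def __init__(self, version: Union[semver.VersionInfo, str]) -> None:
        if isinstance(version, str):
            # normalize so the rest of the class only ever sees VersionInfo
            version = semver.VersionInfo.parse(version)
        self.version = version

print(Versioned("1.2.3").version.minor)  # 2
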
2 changes: 1 addition & 1 deletion dlt/common/typing.py
@@ -151,4 +151,4 @@ def get_generic_type_argument_from_instance(instance: Any, sample_value: Optiona
orig_param_type = get_args(instance.__orig_class__)[0]
if orig_param_type is Any and sample_value is not None:
orig_param_type = type(sample_value)
-return orig_param_type # type: ignore
+return orig_param_type  # type: ignore
2 changes: 1 addition & 1 deletion dlt/common/utils.py
@@ -80,7 +80,7 @@ def str2bool(v: str) -> bool:
# return o


-def flatten_list_of_str_or_dicts(seq: Sequence[Union[StrAny, str]]) -> StrAny:
+def flatten_list_of_str_or_dicts(seq: Sequence[Union[StrAny, str]]) -> DictStrAny:
"""
Transforms a list of objects or strings [{K: {...}}, L, ...] -> {K: {...}, L: None, ...}
"""
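
The return annotation fix matters because the function builds and returns a mutable dict (DictStrAny), not the read-only StrAny mapping. A rough reimplementation of the documented behavior (dlt's own version may add validation):

from typing import Any, Dict, Sequence, Union

def flatten_list_of_str_or_dicts(seq: Sequence[Union[Dict[str, Any], str]]) -> Dict[str, Any]:
    # [{K: {...}}, L, ...] -> {K: {...}, L: None, ...}
    out: Dict[str, Any] = {}
    for item in seq:
        if isinstance(item, str):
            out[item] = None
        else:
            out.update(item)
    return out

print(flatten_list_of_str_or_dicts([{"a": {"x": 1}}, "b"]))  # {'a': {'x': 1}, 'b': None}
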
21 changes: 19 additions & 2 deletions dlt/destinations/duckdb/configuration.py
@@ -1,13 +1,13 @@
import os
import threading
from pathvalidate import is_valid_filepath
-from typing import Any, ClassVar, Final, List, Optional, Tuple
+from typing import Any, ClassVar, Final, List, Optional, Tuple, TYPE_CHECKING, Union

from dlt.common import logger
from dlt.common.configuration import configspec
from dlt.common.configuration.specs import ConnectionStringCredentials
from dlt.common.configuration.specs.exceptions import InvalidConnectionString
-from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration
+from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration, DestinationClientStagingConfiguration
from dlt.common.typing import TSecretValue

DUCK_DB_NAME = "%s.duckdb"
@@ -180,3 +180,20 @@ class DuckDbClientConfiguration(DestinationClientDwhWithStagingConfiguration):
credentials: DuckDbCredentials

create_indexes: bool = False # should unique indexes be created, this slows loading down massively

+if TYPE_CHECKING:
+    try:
+        from duckdb import DuckDBPyConnection
+    except ModuleNotFoundError:
+        DuckDBPyConnection = Any # type: ignore[assignment,misc]
+
+    def __init__(
+        self,
+        destination_name: str = None,
+        credentials: Union[DuckDbCredentials, str, DuckDBPyConnection] = None,
+        dataset_name: str = None,
+        default_schema_name: Optional[str] = None,
+        create_indexes: bool = False,
+        staging_config: Optional[DestinationClientStagingConfiguration] = None
+    ) -> None:
+        ...
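
These TYPE_CHECKING-only __init__ definitions (repeated below for the dummy and postgres configurations) never run; they only tell mypy what signature the @configspec decorator synthesizes at runtime, much like dataclass stubs. A generic sketch of the idiom, with a toy decorator standing in for dlt's real configspec:

from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar

T = TypeVar("T")

def configspec(cls: Type[T]) -> Type[T]:
    # toy stand-in: dlt's real decorator derives __init__ from the spec's fields
    def __init__(self: Any, **kwargs: Any) -> None:
        for key, value in kwargs.items():
            setattr(self, key, value)
    setattr(cls, "__init__", __init__)
    return cls

@configspec
class ClientConfiguration:
    destination_name: Optional[str] = None
    create_indexes: bool = False

    if TYPE_CHECKING:
        # dead code at runtime; exists so the checker knows the synthesized signature
        def __init__(self, destination_name: str = None, create_indexes: bool = False) -> None:
            ...

cfg = ClientConfiguration(destination_name="duckdb", create_indexes=True)
print(cfg.destination_name)  # duckdb
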
17 changes: 17 additions & 0 deletions dlt/destinations/dummy/configuration.py
@@ -1,3 +1,5 @@
+from typing import TYPE_CHECKING, Optional
+
from dlt.common.configuration import configspec
from dlt.common.destination import TLoaderFileFormat
from dlt.common.destination.reference import DestinationClientConfiguration, CredentialsConfiguration
@@ -22,3 +24,18 @@ class DummyClientConfiguration(DestinationClientConfiguration):
fail_in_init: bool = True

credentials: DummyClientCredentials = None

+if TYPE_CHECKING:
+    def __init__(
+        self,
+        destination_name: str = None,
+        credentials: Optional[CredentialsConfiguration] = None,
+        loader_file_format: TLoaderFileFormat = None,
+        fail_schema_update: bool = None,
+        fail_prob: float = None,
+        retry_prob: float = None,
+        completed_prob: float = None,
+        timeout: float = None,
+        fail_in_init: bool = None,
+    ) -> None:
+        ...
13 changes: 12 additions & 1 deletion dlt/destinations/postgres/configuration.py
@@ -1,4 +1,4 @@
-from typing import Final, ClassVar, Any, List
+from typing import Final, ClassVar, Any, List, TYPE_CHECKING
from sqlalchemy.engine import URL

from dlt.common.configuration import configspec
@@ -46,3 +46,14 @@ def fingerprint(self) -> str:
if self.credentials and self.credentials.host:
return digest128(self.credentials.host)
return ""

+if TYPE_CHECKING:
+    def __init__(
+        self,
+        destination_name: str = None,
+        credentials: PostgresCredentials = None,
+        dataset_name: str = None,
+        default_schema_name: str = None,
+        create_indexes: bool = True
+    ) -> None:
+        ...
1 change: 1 addition & 0 deletions dlt/destinations/postgres/sql_client.py
@@ -19,6 +19,7 @@
from dlt.destinations.postgres.configuration import PostgresCredentials
from dlt.destinations.postgres import capabilities


class Psycopg2SqlClient(SqlClientBase["psycopg2.connection"], DBTransaction):

dbapi: ClassVar[DBApi] = psycopg2
2 changes: 1 addition & 1 deletion dlt/destinations/weaviate/weaviate_client.py
@@ -588,7 +588,7 @@ def make_weaviate_class_schema(self, table_name: str) -> Dict[str, Any]:
}

# check if any column requires vectorization
-if get_columns_names_with_prop(self.schema.get_table(table_name), VECTORIZE_HINT): # type: ignore
+if get_columns_names_with_prop(self.schema.get_table(table_name), VECTORIZE_HINT):
class_schema.update(self._vectorizer_config)
else:
class_schema.update(NON_VECTORIZED_CLASS)
10 changes: 5 additions & 5 deletions dlt/extract/decorators.py
@@ -202,7 +202,7 @@ def resource(
merge_key: TTableHintTemplate[TColumnNames] = None,
selected: bool = True,
spec: Type[BaseConfiguration] = None
-) -> Callable[TResourceFunParams, DltResource]:
+) -> DltResource:
...

@overload
@@ -388,7 +388,7 @@ def transformer(
merge_key: TTableHintTemplate[TColumnNames] = None,
selected: bool = True,
spec: Type[BaseConfiguration] = None
-) -> Callable[[Callable[Concatenate[TDataItem, TResourceFunParams], Any]], Callable[TResourceFunParams, DltResource]]:
+) -> Callable[[Callable[Concatenate[TDataItem, TResourceFunParams], Any]], DltResource]:
...

@overload
@@ -404,10 +404,10 @@ def transformer(
merge_key: TTableHintTemplate[TColumnNames] = None,
selected: bool = True,
spec: Type[BaseConfiguration] = None
-) -> Callable[TResourceFunParams, DltResource]:
+) -> DltResource:
...

-def transformer( # type: ignore
+def transformer(
f: Optional[Callable[Concatenate[TDataItem, TResourceFunParams], Any]] = None,
/,
data_from: TUnboundDltResource = DltResource.Empty,
@@ -419,7 +419,7 @@ def transformer( # type: ignore
merge_key: TTableHintTemplate[TColumnNames] = None,
selected: bool = True,
spec: Type[BaseConfiguration] = None
-) -> Callable[[Callable[Concatenate[TDataItem, TResourceFunParams], Any]], Callable[TResourceFunParams, DltResource]]:
+) -> Callable[[Callable[Concatenate[TDataItem, TResourceFunParams], Any]], DltResource]:
"""A form of `dlt resource` that takes input from other resources via `data_from` argument in order to enrich or transform the data.
The decorated function `f` must take at least one argument of type TDataItems (a single item or list of items depending on the resource `data_from`). `dlt` will pass
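
The overload changes mean a decorated function is now typed directly as DltResource rather than as a callable that returns one, matching what the decorators actually produce. A usage sketch, assuming dlt's decorator API as of this commit:

import dlt

@dlt.resource
def numbers():
    yield from range(3)

# takes items from `numbers` and enriches them; now typed as DltResource
@dlt.transformer(data_from=numbers)
def doubled(n):
    yield n * 2

print(list(doubled))  # [0, 2, 4]
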
8 changes: 7 additions & 1 deletion dlt/extract/pipe.py
@@ -57,12 +57,18 @@ class SourcePipeItem(NamedTuple):


# pipeline step may be iterator of data items or mapping function that returns data item or another iterator
+from dlt.common.typing import TDataItem
TPipeStep = Union[
Iterable[TPipedDataItems],
Iterator[TPipedDataItems],
# Callable with meta
Callable[[TDataItems, Optional[Any]], TPipedDataItems],
Callable[[TDataItems, Optional[Any]], Iterator[TPipedDataItems]],
-Callable[[TDataItems, Optional[Any]], Iterator[ResolvablePipeItem]]
+Callable[[TDataItems, Optional[Any]], Iterator[ResolvablePipeItem]],
+# Callable without meta
+Callable[[TDataItems], TPipedDataItems],
+Callable[[TDataItems], Iterator[TPipedDataItems]],
+Callable[[TDataItems], Iterator[ResolvablePipeItem]],
]

TPipeNextItemMode = Union[Literal["fifo"], Literal["round_robin"]]
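
The widened TPipeStep union means a step callable may, but no longer must, accept the meta argument. A sketch of both shapes:

from typing import Any, Iterator, Optional

def with_meta(item: Any, meta: Optional[Any]) -> Any:
    # "Callable with meta": receives the item plus pipeline metadata
    return item

def without_meta(item: Any) -> Iterator[Any]:
    # "Callable without meta": only the item
    yield item
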
2 changes: 1 addition & 1 deletion dlt/helpers/dbt/dbt_utils.py
@@ -25,7 +25,7 @@

try:
# dbt <1.5
-from dbt.main import handle_and_check
+from dbt.main import handle_and_check # type: ignore[import]
except ImportError:
# dbt >=1.5
from dbt.cli.main import dbtRunner
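
The ignore is needed because only one of the two import paths exists for any installed dbt version, so mypy cannot resolve the other. The surrounding pattern, sketched with the public entry points of each dbt generation:

from typing import Any

try:
    # dbt <1.5: module-level entry point
    from dbt.main import handle_and_check  # type: ignore[import]

    def run_dbt(args: Any) -> Any:
        return handle_and_check(args)
except ImportError:
    # dbt >=1.5: programmatic runner object
    from dbt.cli.main import dbtRunner

    def run_dbt(args: Any) -> Any:
        return dbtRunner().invoke(args)
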
2 changes: 1 addition & 1 deletion dlt/pipeline/helpers.py
@@ -138,7 +138,7 @@ def _list_state_paths(self, source_state: Dict[str, Any]) -> List[str]:
return resolve_paths(self.state_paths_to_drop, source_state)

def _create_modified_state(self) -> Dict[str, Any]:
-state: TSourceState = self.pipeline.state # type: ignore[assignment]
+state = self.pipeline.state
if not self.drop_state:
return state # type: ignore[return-value]
source_states = _sources_state(state).items()
2 changes: 1 addition & 1 deletion dlt/pipeline/pipeline.py
@@ -149,7 +149,7 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:
class Pipeline(SupportsPipeline):

STATE_FILE: ClassVar[str] = "state.json"
-STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineState).keys())
+STATE_PROPS: ClassVar[List[str]] = list(set(get_type_hints(TPipelineState).keys()) - {"sources"})
LOCAL_STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineLocalState).keys())
DEFAULT_DATASET_SUFFIX: ClassVar[str] = "_dataset"

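
Because sources is now declared on TPipelineState itself (as NotRequired), get_type_hints would report it among the state props; the set difference keeps it out of the list of properties synchronized with the destination. A small demonstration of why the exclusion is needed:

from typing import Any, Dict, get_type_hints
from typing_extensions import NotRequired, TypedDict

class State(TypedDict, total=False):
    pipeline_name: str
    sources: NotRequired[Dict[str, Any]]

# NotRequired keys are still reported by get_type_hints...
print(sorted(get_type_hints(State)))                      # ['pipeline_name', 'sources']
# ...so they must be excluded explicitly
print(sorted(set(get_type_hints(State)) - {"sources"}))   # ['pipeline_name']
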