Merge branch 'devel' into d#/destination_config_updates
# Conflicts:
#	Makefile
#	dlt/pipeline/pipeline.py
#	dlt/pipeline/state_sync.py
#	tests/pipeline/test_pipeline_state.py
sh-rp committed Nov 30, 2023
2 parents 8b9d309 + 1f94a3b commit 5d05aab
Showing 94 changed files with 6,797 additions and 5,272 deletions.
6 changes: 3 additions & 3 deletions Makefile
@@ -51,12 +51,12 @@ lint:
poetry run mypy --config-file mypy.ini dlt tests
poetry run flake8 --max-line-length=200 dlt
poetry run flake8 --max-line-length=200 tests --exclude tests/reflection/module_cases
# $(MAKE) lint-security
poetry run black dlt tests docs --diff --exclude=".*syntax_error.py|\.venv.*"
poetry run black dlt docs tests --diff --exclude=".*syntax_error.py|\.venv.*|_storage/.*"
# poetry run isort ./ --diff
# $(MAKE) lint-security

format:
poetry run black dlt tests docs --exclude=".*syntax_error.py|\.venv.*"
poetry run black dlt docs tests --exclude=".*syntax_error.py|\.venv.*|_storage/.*"
# poetry run isort ./

test-and-lint-snippets:
26 changes: 15 additions & 11 deletions dlt/cli/pipeline_command.py
@@ -9,7 +9,7 @@
from dlt.common.runners import Venv
from dlt.common.runners.stdout import iter_stdout
from dlt.common.schema.utils import group_tables_by_resource, remove_defaults
from dlt.common.storages import FileStorage, LoadStorage
from dlt.common.storages import FileStorage, PackageStorage
from dlt.pipeline.helpers import DropCommand
from dlt.pipeline.exceptions import CannotRestorePipelineException

@@ -72,28 +72,30 @@ def pipeline_command(
return # No need to sync again

def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]:
extracted_files = p.list_extracted_resources()
if extracted_files:
extracted_packages = p.list_extracted_load_packages()
if extracted_packages:
fmt.echo(
"Has %s extracted files ready to be normalized"
% fmt.bold(str(len(extracted_files)))
"Has %s extracted packages ready to be normalized with following load ids:"
% fmt.bold(str(len(extracted_packages)))
)
for load_id in extracted_packages:
fmt.echo(load_id)
norm_packages = p.list_normalized_load_packages()
if norm_packages:
fmt.echo(
"Has %s load packages ready to be loaded with following load ids:"
"Has %s normalized packages ready to be loaded with following load ids:"
% fmt.bold(str(len(norm_packages)))
)
for load_id in norm_packages:
fmt.echo(load_id)
# load first (oldest) package
first_package_info = p.get_load_package_info(norm_packages[0])
if LoadStorage.is_package_partially_loaded(first_package_info):
if PackageStorage.is_package_partially_loaded(first_package_info):
fmt.warning(
"This package is partially loaded. Data in the destination may be modified."
)
fmt.echo()
return extracted_files, norm_packages
return extracted_packages, norm_packages

fmt.echo("Found pipeline %s in %s" % (fmt.bold(p.pipeline_name), fmt.bold(p.pipelines_dir)))

@@ -209,8 +211,8 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]:
fmt.echo("No failed jobs found")

if operation == "drop-pending-packages":
extracted_files, norm_packages = _display_pending_packages()
if len(extracted_files) == 0 and len(norm_packages) == 0:
extracted_packages, norm_packages = _display_pending_packages()
if len(extracted_packages) == 0 and len(norm_packages) == 0:
fmt.echo("No pending packages found")
if fmt.confirm("Delete the above packages?", default=False):
p.drop_pending_packages(with_partial_loads=True)
@@ -230,7 +232,9 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]:
if operation == "load-package":
load_id = command_kwargs.get("load_id")
if not load_id:
packages = sorted(p.list_normalized_load_packages())
packages = sorted(p.list_extracted_load_packages())
if not packages:
packages = sorted(p.list_normalized_load_packages())
if not packages:
packages = sorted(p.list_completed_load_packages())
if not packages:
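For illustration, a hedged sketch of the same extracted → normalized → completed fallback outside the CLI, using the pipeline API shown in this hunk (the pipeline name is hypothetical):

import dlt

p = dlt.attach(pipeline_name="my_pipeline")  # hypothetical pipeline name
# mirror the new load-package fallback order
packages = sorted(p.list_extracted_load_packages())
if not packages:
    packages = sorted(p.list_normalized_load_packages())
if not packages:
    packages = sorted(p.list_completed_load_packages())
print(packages[-1] if packages else "no load packages found")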
4 changes: 2 additions & 2 deletions dlt/common/configuration/accessors.py
@@ -106,7 +106,7 @@ def writable_provider(self) -> ConfigProvider:
if p.is_writable and not p.supports_secrets
)

value: ClassVar[None] = ConfigValue
value: ClassVar[Any] = ConfigValue
"A placeholder that tells dlt to replace it with actual config value during the call to a source or resource decorated function."


@@ -129,7 +129,7 @@ def writable_provider(self) -> ConfigProvider:
p for p in self._get_providers_from_context() if p.is_writable and p.supports_secrets
)

value: ClassVar[None] = ConfigValue
value: ClassVar[Any] = ConfigValue
"A placeholder that tells dlt to replace it with actual secret during the call to a source or resource decorated function."


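The widened ClassVar[Any] typing matters because these placeholders get assigned to parameters of arbitrary declared types; a minimal usage sketch (source, resource, and parameter names are made up):

import dlt

@dlt.source
def my_source(api_key: str = dlt.secrets.value, page_size: int = dlt.config.value):
    # both defaults are placeholders; dlt injects the real values when the source is called
    @dlt.resource
    def items():
        yield [{"id": 1, "page_size": page_size}]

    return items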
6 changes: 0 additions & 6 deletions dlt/common/configuration/specs/gcp_credentials.py
@@ -1,6 +1,5 @@
import sys
from typing import Any, Final, List, Tuple, Union
from deprecated import deprecated

from dlt.common import json, pendulum
from dlt.common.configuration.specs.api_credentials import OAuth2Credentials
@@ -89,13 +88,8 @@ def on_resolved(self) -> None:
# must end with new line, otherwise won't be parsed by Crypto
self.private_key = TSecretValue(self.private_key + "\n")

@deprecated(reason="Use 'to_native_credentials' method instead")
def to_service_account_credentials(self) -> Any:
return self.to_native_credentials()

def to_native_credentials(self) -> Any:
"""Returns google.oauth2.service_account.Credentials"""

from google.oauth2.service_account import Credentials as ServiceAccountCredentials

if isinstance(self.private_key, ServiceAccountCredentials):
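Callers of the removed deprecated alias can switch to the remaining method; a hedged migration sketch (credentials object setup omitted):

# before (removed alias): service_creds = credentials.to_service_account_credentials()
# after:                  service_creds = credentials.to_native_credentials()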
2 changes: 2 additions & 0 deletions dlt/common/configuration/specs/run_configuration.py
@@ -29,6 +29,8 @@ class RunConfiguration(BaseConfiguration):
request_max_retry_delay: float = 300
"""Maximum delay between http request retries"""
config_files_storage_path: str = "/run/config/"
"""Platform connection"""
dlthub_dsn: Optional[TSecretStrValue] = None

__section__ = "runtime"

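A hedged sketch of supplying the new field, assuming the usual section-based environment resolution for RunConfiguration (__section__ = "runtime"); the env var name is derived from that assumption and the DSN value is a placeholder:

import os

# assumed name: section "runtime" + field "dlthub_dsn"
os.environ["RUNTIME__DLTHUB_DSN"] = "https://key@platform.example.com/1"  # placeholder DSN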
6 changes: 4 additions & 2 deletions dlt/common/data_writers/buffered.py
@@ -1,5 +1,4 @@
import gzip
from functools import reduce
from typing import List, IO, Any, Optional, Type, TypeVar, Generic

from dlt.common.utils import uniq_id
@@ -75,7 +74,9 @@ def __init__(
raise InvalidFileNameTemplateException(file_name_template)

def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> int:
self._ensure_open()
if self._closed:
self._rotate_file()
self._closed = False
# rotate file if columns changed and writer does not allow for that
# as the only allowed change is to add new column (no updates/deletes), we detect the change by comparing lengths
if (
@@ -175,6 +176,7 @@ def _flush_and_close_file(self) -> None:
self.closed_files.append(self._file_name)
self._writer = None
self._file = None
self._file_name = None

def _ensure_open(self) -> None:
if self._closed:
27 changes: 27 additions & 0 deletions dlt/common/managed_thread_pool.py
@@ -0,0 +1,27 @@
from typing import Optional

import atexit
from concurrent.futures import ThreadPoolExecutor


class ManagedThreadPool:
def __init__(self, max_workers: int = 1) -> None:
self._max_workers = max_workers
self._thread_pool: Optional[ThreadPoolExecutor] = None

def _create_thread_pool(self) -> None:
assert not self._thread_pool, "Thread pool already created"
self._thread_pool = ThreadPoolExecutor(self._max_workers)
# flush pool on exit
atexit.register(self.stop)

@property
def thread_pool(self) -> ThreadPoolExecutor:
if not self._thread_pool:
self._create_thread_pool()
return self._thread_pool

def stop(self, wait: bool = True) -> None:
if self._thread_pool:
self._thread_pool.shutdown(wait=wait)
self._thread_pool = None
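A short usage sketch of the new helper: the executor is created lazily on first access to thread_pool and shut down via stop(), which is also registered with atexit.

pool = ManagedThreadPool(max_workers=2)
future = pool.thread_pool.submit(lambda: 40 + 2)  # first access creates the executor
assert future.result() == 42
pool.stop(wait=True)  # safe to call again; it is a no-op once the pool is gone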
12 changes: 10 additions & 2 deletions dlt/common/pipeline.py
@@ -15,6 +15,7 @@
TYPE_CHECKING,
Tuple,
TypedDict,
Mapping,
)
from typing_extensions import NotRequired

@@ -72,7 +73,7 @@ def asdict(self) -> DictStrAny:
"""A dictionary representation of NormalizeInfo that can be loaded with `dlt`"""
d = self._asdict()
# list representation creates a nice table
d["row_counts"] = [(k, v) for k, v in self.row_counts.items()]
d["row_counts"] = [{"table_name": k, "count": v} for k, v in self.row_counts.items()]
return d

def asstr(self, verbosity: int = 0) -> str:
@@ -178,7 +179,9 @@ class TPipelineLocalState(TypedDict, total=False):
first_run: bool
"""Indicates a first run of the pipeline, where run ends with successful loading of data"""
_last_extracted_at: datetime.datetime
"""Timestamp indicating when the state was synced with the destination. Lack of timestamp means not synced state."""
"""Timestamp indicating when the state was synced with the destination."""
_last_extracted_hash: str
"""Hash of state that was recently synced with destination"""


class TPipelineState(TypedDict, total=False):
@@ -197,6 +200,7 @@ class TPipelineState(TypedDict, total=False):

# properties starting with _ are not automatically applied to pipeline object when state is restored
_state_version: int
_version_hash: str
_state_engine_version: int
_local: TPipelineLocalState
"""A section of state that is not synchronized with the destination and does not participate in change merging and version control"""
@@ -232,6 +236,10 @@ class SupportsPipeline(Protocol):
def state(self) -> TPipelineState:
"""Returns dictionary with pipeline state"""

@property
def schemas(self) -> Mapping[str, Schema]:
"""Mapping of all pipeline schemas"""

def set_local_state_val(self, key: str, value: Any) -> None:
"""Sets value in local state. Local state is not synchronized with destination."""

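For clarity, a small sketch of the new row_counts representation produced by NormalizeInfo.asdict() (table names and counts are illustrative):

row_counts = {"users": 10, "orders": 3}
# previously: [("users", 10), ("orders", 3)]
as_rows = [{"table_name": k, "count": v} for k, v in row_counts.items()]
# now: [{"table_name": "users", "count": 10}, {"table_name": "orders", "count": 3}]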
33 changes: 19 additions & 14 deletions dlt/common/runtime/exec_info.py
@@ -1,23 +1,16 @@
import io
import os
import contextlib
import sys
import multiprocessing
import platform

from dlt.common.runtime.typing import TExecutionContext, TVersion, TExecInfoNames
from dlt.common.typing import StrStr, StrAny, Literal, List
from dlt.common.utils import filter_env_vars
from dlt.version import __version__


TExecInfoNames = Literal[
"kubernetes",
"docker",
"codespaces",
"github_actions",
"airflow",
"notebook",
"colab",
"aws_lambda",
"gcp_cloud_function",
]
from dlt.version import __version__, DLT_PKG_NAME


# if one of these environment variables is set, we assume to be running in CI env
CI_ENVIRONMENT_TELL = [
"bamboo.buildKey",
@@ -174,3 +167,15 @@ def is_aws_lambda() -> bool:
def is_gcp_cloud_function() -> bool:
"Return True if the process is running in the serverless platform GCP Cloud Functions"
return os.environ.get("FUNCTION_NAME") is not None


def get_execution_context() -> TExecutionContext:
"Get execution context information"
return TExecutionContext(
ci_run=in_continuous_integration(),
python=sys.version.split(" ")[0],
cpu=multiprocessing.cpu_count(),
exec_info=exec_info_names(),
os=TVersion(name=platform.system(), version=platform.release()),
library=TVersion(name=DLT_PKG_NAME, version=__version__),
)
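For reference, the context built by get_execution_context() has roughly the following shape; the concrete values below are examples, not guaranteed output.

example_context = {
    "ci_run": False,
    "python": "3.10.12",
    "cpu": 8,
    "exec_info": ["docker"],
    "os": {"name": "Linux", "version": "5.15.0"},
    "library": {"name": "dlt", "version": "0.3.25"},  # version shown here is illustrative
}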
38 changes: 14 additions & 24 deletions dlt/common/runtime/segment.py
@@ -2,32 +2,31 @@

# several code fragments come from https://github.com/RasaHQ/rasa/blob/main/rasa/telemetry.py
import os
import sys
import multiprocessing

import atexit
import base64
import requests
import platform
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Optional
from dlt.common.configuration.paths import get_dlt_data_dir

from dlt.common.runtime import logger
from dlt.common.managed_thread_pool import ManagedThreadPool

from dlt.common.configuration.specs import RunConfiguration
from dlt.common.runtime.exec_info import exec_info_names, in_continuous_integration
from dlt.common.runtime.exec_info import get_execution_context, TExecutionContext
from dlt.common.typing import DictStrAny, StrAny
from dlt.common.utils import uniq_id
from dlt.version import __version__, DLT_PKG_NAME
from dlt.version import __version__

TEventCategory = Literal["pipeline", "command", "helper"]

_THREAD_POOL: ThreadPoolExecutor = None
_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1)
_SESSION: requests.Session = None
_WRITE_KEY: str = None
_SEGMENT_REQUEST_TIMEOUT = (1.0, 1.0) # short connect & send timeouts
_SEGMENT_ENDPOINT = "https://api.segment.io/v1/track"
_SEGMENT_CONTEXT: DictStrAny = None
_SEGMENT_CONTEXT: TExecutionContext = None


def init_segment(config: RunConfiguration) -> None:
@@ -36,9 +35,8 @@ def init_segment(config: RunConfiguration) -> None:
), "dlthub_telemetry_segment_write_key not present in RunConfiguration"

# create thread pool to send telemetry to segment
global _THREAD_POOL, _WRITE_KEY, _SESSION
if not _THREAD_POOL:
_THREAD_POOL = ThreadPoolExecutor(1)
global _WRITE_KEY, _SESSION
if not _SESSION:
_SESSION = requests.Session()
# flush pool on exit
atexit.register(_at_exit_cleanup)
@@ -81,10 +79,9 @@ def before_send(event: DictStrAny) -> Optional[DictStrAny]:


def _at_exit_cleanup() -> None:
global _THREAD_POOL, _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT
if _THREAD_POOL:
_THREAD_POOL.shutdown(wait=True)
_THREAD_POOL = None
global _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT
if _SESSION:
_THREAD_POOL.stop(True)
_SESSION.close()
_SESSION = None
_WRITE_KEY = None
@@ -141,7 +138,7 @@ def _segment_request_payload(event_name: str, properties: StrAny, context: StrAn
}


def _default_context_fields() -> DictStrAny:
def _default_context_fields() -> TExecutionContext:
"""Return a dictionary that contains the default context values.
Return:
@@ -152,14 +149,7 @@ def _default_context_fields() -> DictStrAny:
if not _SEGMENT_CONTEXT:
# Make sure to update the example in docs/docs/telemetry/telemetry.mdx
# if you change / add context
_SEGMENT_CONTEXT = {
"os": {"name": platform.system(), "version": platform.release()},
"ci_run": in_continuous_integration(),
"python": sys.version.split(" ")[0],
"library": {"name": DLT_PKG_NAME, "version": __version__},
"cpu": multiprocessing.cpu_count(),
"exec_info": exec_info_names(),
}
_SEGMENT_CONTEXT = get_execution_context()

# avoid returning the cached dict --> caller could modify the dictionary...
# usually we would use `lru_cache`, but that doesn't return a dict copy and
@@ -207,4 +197,4 @@ def _future_send() -> None:
if not data.get("success"):
logger.debug(f"Segment telemetry request returned a failure. Response: {data}")

_THREAD_POOL.submit(_future_send)
_THREAD_POOL.thread_pool.submit(_future_send)