Skip to content

Commit

Permalink
pr fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Nov 21, 2023
1 parent 03fcc21 commit 4b9adae
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 92 deletions.
2 changes: 1 addition & 1 deletion dlt/common/configuration/specs/run_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class RunConfiguration(BaseConfiguration):
"""Maximum delay between http request retries"""
config_files_storage_path: str = "/run/config/"
"""Platform connection"""
platform_dsn: Optional[str] = None
dlthub_dsn: Optional[TSecretStrValue] = None

__section__ = "runtime"

Expand Down
27 changes: 27 additions & 0 deletions dlt/common/managed_thread_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Optional

import atexit
from concurrent.futures import ThreadPoolExecutor

class ManagedThreadPool:

def __init__(self, max_workers: int = 1) -> None:
self._max_workers = max_workers
self._thread_pool: Optional[ThreadPoolExecutor] = None

def _create_thread_pool(self) -> None:
assert not self._thread_pool, "Thread pool already created"
self._thread_pool = ThreadPoolExecutor(self._max_workers)
# flush pool on exit
atexit.register(self.stop)

@property
def thread_pool(self) -> ThreadPoolExecutor:
if not self._thread_pool:
self._create_thread_pool()
return self._thread_pool

def stop(self, wait: bool = True) -> None:
if self._thread_pool:
self._thread_pool.shutdown(wait=wait)
self._thread_pool = None
6 changes: 5 additions & 1 deletion dlt/common/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime # noqa: 251
import humanize
import contextlib
from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Protocol, Sequence, TYPE_CHECKING, Tuple, TypedDict
from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Protocol, Sequence, TYPE_CHECKING, Tuple, TypedDict, Mapping
from typing_extensions import NotRequired

from dlt.common import pendulum, logger
Expand Down Expand Up @@ -194,6 +194,10 @@ class SupportsPipeline(Protocol):
def state(self) -> TPipelineState:
"""Returns dictionary with pipeline state"""

@property
def schemas(self) -> Mapping[str, Schema]:
"""Mapping of all pipeline schemas"""

def set_local_state_val(self, key: str, value: Any) -> None:
"""Sets value in local state. Local state is not synchronized with destination."""

Expand Down
17 changes: 8 additions & 9 deletions dlt/common/runtime/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dlt.common.configuration.paths import get_dlt_data_dir

from dlt.common.runtime import logger
from dlt.common.managed_thread_pool import ManagedThreadPool

from dlt.common.configuration.specs import RunConfiguration
from dlt.common.runtime.exec_info import get_execution_context, TExecutionContext
Expand All @@ -20,7 +21,7 @@

TEventCategory = Literal["pipeline", "command", "helper"]

_THREAD_POOL: ThreadPoolExecutor = None
_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1)
_SESSION: requests.Session = None
_WRITE_KEY: str = None
_SEGMENT_REQUEST_TIMEOUT = (1.0, 1.0) # short connect & send timeouts
Expand All @@ -32,9 +33,8 @@ def init_segment(config: RunConfiguration) -> None:
assert config.dlthub_telemetry_segment_write_key, "dlthub_telemetry_segment_write_key not present in RunConfiguration"

# create thread pool to send telemetry to segment
global _THREAD_POOL, _WRITE_KEY, _SESSION
if not _THREAD_POOL:
_THREAD_POOL = ThreadPoolExecutor(1)
global _WRITE_KEY, _SESSION
if not _SESSION:
_SESSION = requests.Session()
# flush pool on exit
atexit.register(_at_exit_cleanup)
Expand Down Expand Up @@ -84,10 +84,9 @@ def before_send(event: DictStrAny) -> Optional[DictStrAny]:


def _at_exit_cleanup() -> None:
global _THREAD_POOL, _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT
if _THREAD_POOL:
_THREAD_POOL.shutdown(wait=True)
_THREAD_POOL = None
global _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT
if _SESSION:
_THREAD_POOL.stop(True)
_SESSION.close()
_SESSION = None
_WRITE_KEY = None
Expand Down Expand Up @@ -211,4 +210,4 @@ def _future_send() -> None:
f"Segment telemetry request returned a failure. Response: {data}"
)

_THREAD_POOL.submit(_future_send)
_THREAD_POOL.thread_pool.submit(_future_send)
12 changes: 11 additions & 1 deletion dlt/common/runtime/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Set, Type, TypedDict, NewType, Union, get_args


TExecInfoNames = Literal["kubernetes", "docker", "codespaces", "github_actions", "airflow", "notebook", "colab","aws_lambda","gcp_cloud_function"]
TExecInfoNames = Literal[
"kubernetes",
"docker",
"codespaces",
"github_actions",
"airflow",
"notebook",
"colab",
"aws_lambda",
"gcp_cloud_function"
]

class TVersion(TypedDict):
"""TypeDict representing a library version"""
Expand Down
6 changes: 3 additions & 3 deletions dlt/common/storages/load_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class LoadPackageInfo(NamedTuple):
package_path: str
state: TLoadPackageState
schema_name: str
schema: TStoredSchema
schema_hash: str
schema_update: TSchemaTables
completed_at: datetime.datetime
jobs: Dict[TJobState, List[LoadJobInfo]]
Expand All @@ -112,7 +112,7 @@ def asdict(self) -> DictStrAny:
table["columns"] = columns
d.pop("schema_update")
d["tables"] = tables
d["schema"] = self.schema
d["schema_hash"] = self.schema_hash
return d

def asstr(self, verbosity: int = 0) -> str:
Expand Down Expand Up @@ -293,7 +293,7 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo:
jobs.append(self._read_job_file_info(state, file, package_created_at))
all_jobs[state] = jobs

return LoadPackageInfo(load_id, self.storage.make_full_path(package_path), package_state, schema.name, schema.to_dict(), applied_update, package_created_at, all_jobs)
return LoadPackageInfo(load_id, self.storage.make_full_path(package_path), package_state, schema.name, schema.version_hash, applied_update, package_created_at, all_jobs)

def begin_schema_update(self, load_id: str) -> Optional[TSchemaTables]:
package_path = self.get_normalized_package_path(load_id)
Expand Down
84 changes: 44 additions & 40 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import contextmanager
from functools import wraps
from collections.abc import Sequence as C_Sequence
from typing import Any, Callable, ClassVar, List, Iterator, Optional, Sequence, Tuple, cast, get_type_hints, ContextManager
from typing import Any, Callable, ClassVar, List, Iterator, Optional, Sequence, Tuple, cast, get_type_hints, ContextManager, Mapping
from concurrent.futures import Executor

from dlt import version
Expand Down Expand Up @@ -89,48 +89,52 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:
return _wrap # type: ignore


def with_runtime_trace(f: TFun) -> TFun:
def with_runtime_trace(send_state: bool = False) -> Callable[[TFun], TFun]:

@wraps(f)
def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:
trace: PipelineTrace = self._trace
trace_step: PipelineStepTrace = None
step_info: Any = None
is_new_trace = self._trace is None and self.config.enable_runtime_trace
def decorator(f: TFun) -> TFun:

# create a new trace if we enter a traced function and there's no current trace
if is_new_trace:
self._trace = trace = start_trace(cast(TPipelineStep, f.__name__), self)
@wraps(f)
def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:
trace: PipelineTrace = self._trace
trace_step: PipelineStepTrace = None
step_info: Any = None
is_new_trace = self._trace is None and self.config.enable_runtime_trace

try:
# start a trace step for wrapped function
if trace:
trace_step = start_trace_step(trace, cast(TPipelineStep, f.__name__), self)
# create a new trace if we enter a traced function and there's no current trace
if is_new_trace:
self._trace = trace = start_trace(cast(TPipelineStep, f.__name__), self)

step_info = f(self, *args, **kwargs)
return step_info
except Exception as ex:
step_info = ex # step info is an exception
raise
finally:
try:
if trace_step:
# if there was a step, finish it
end_trace_step(self._trace, trace_step, self, step_info)
if is_new_trace:
assert trace is self._trace, f"Messed up trace reference {id(self._trace)} vs {id(trace)}"
end_trace(trace, self, self._pipeline_storage.storage_path)
# start a trace step for wrapped function
if trace:
trace_step = start_trace_step(trace, cast(TPipelineStep, f.__name__), self)

step_info = f(self, *args, **kwargs)
return step_info
except Exception as ex:
step_info = ex # step info is an exception
raise
finally:
# always end trace
if is_new_trace:
assert self._trace == trace, f"Messed up trace reference {id(self._trace)} vs {id(trace)}"
# if we end new trace that had only 1 step, add it to previous trace
# this way we combine several separate calls to extract, normalize, load as single trace
# the trace of "run" has many steps and will not be merged
self._last_trace = merge_traces(self._last_trace, trace)
self._trace = None
try:
if trace_step:
# if there was a step, finish it
end_trace_step(self._trace, trace_step, self, step_info, send_state)
if is_new_trace:
assert trace is self._trace, f"Messed up trace reference {id(self._trace)} vs {id(trace)}"
end_trace(trace, self, self._pipeline_storage.storage_path, send_state)
finally:
# always end trace
if is_new_trace:
assert self._trace == trace, f"Messed up trace reference {id(self._trace)} vs {id(trace)}"
# if we end new trace that had only 1 step, add it to previous trace
# this way we combine several separate calls to extract, normalize, load as single trace
# the trace of "run" has many steps and will not be merged
self._last_trace = merge_traces(self._last_trace, trace)
self._trace = None

return _wrap # type: ignore
return _wrap # type: ignore

return decorator


def with_config_section(sections: Tuple[str, ...]) -> Callable[[TFun], TFun]:
Expand Down Expand Up @@ -253,7 +257,7 @@ def drop(self) -> "Pipeline":
self.runtime_config
)

@with_runtime_trace
@with_runtime_trace()
@with_schemas_sync # this must precede with_state_sync
@with_state_sync(may_extract_state=True)
@with_config_section((known_sections.EXTRACT,))
Expand Down Expand Up @@ -293,7 +297,7 @@ def extract(
# TODO: provide metrics from extractor
raise PipelineStepFailed(self, "extract", exc, ExtractInfo(describe_extract_data(data))) from exc

@with_runtime_trace
@with_runtime_trace()
@with_schemas_sync
@with_config_section((known_sections.NORMALIZE,))
def normalize(self, workers: int = 1, loader_file_format: TLoaderFileFormat = None) -> NormalizeInfo:
Expand Down Expand Up @@ -326,7 +330,7 @@ def normalize(self, workers: int = 1, loader_file_format: TLoaderFileFormat = No
except Exception as n_ex:
raise PipelineStepFailed(self, "normalize", n_ex, normalize.get_normalize_info()) from n_ex

@with_runtime_trace
@with_runtime_trace()
@with_schemas_sync
@with_state_sync()
@with_config_section((known_sections.LOAD,))
Expand Down Expand Up @@ -379,7 +383,7 @@ def load(
except Exception as l_ex:
raise PipelineStepFailed(self, "load", l_ex, self._get_load_info(load)) from l_ex

@with_runtime_trace
@with_runtime_trace(send_state=True)
@with_config_section(("run",))
def run(
self,
Expand Down
Loading

0 comments on commit 4b9adae

Please sign in to comment.