Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prototype platform connection #727

Merged
merged 14 commits into from
Nov 24, 2023
2 changes: 2 additions & 0 deletions dlt/common/configuration/specs/run_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class RunConfiguration(BaseConfiguration):
request_max_retry_delay: float = 300
"""Maximum delay between http request retries"""
config_files_storage_path: str = "/run/config/"
"""Platform connection"""
dlthub_dsn: Optional[TSecretStrValue] = None

__section__ = "runtime"

Expand Down
27 changes: 27 additions & 0 deletions dlt/common/managed_thread_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Optional

import atexit
from concurrent.futures import ThreadPoolExecutor

class ManagedThreadPool:
    """A lazily created ThreadPoolExecutor that flushes itself at interpreter exit.

    The underlying pool is only created on first access of `thread_pool` and may be
    stopped and transparently re-created any number of times via `stop()`.
    """

    def __init__(self, max_workers: int = 1) -> None:
        self._max_workers = max_workers
        self._thread_pool: Optional[ThreadPoolExecutor] = None
        # tracks whether the atexit hook was installed so stop/recreate cycles
        # do not accumulate duplicate registrations
        self._atexit_registered = False

    def _create_thread_pool(self) -> None:
        """Create the executor and make sure it is flushed on process exit."""
        assert self._thread_pool is None, "Thread pool already created"
        self._thread_pool = ThreadPoolExecutor(self._max_workers)
        if not self._atexit_registered:
            # flush pool on exit; register only once even if the pool is
            # stopped and re-created later
            atexit.register(self.stop)
            self._atexit_registered = True

    @property
    def thread_pool(self) -> ThreadPoolExecutor:
        """Return the executor, creating it on first use."""
        if self._thread_pool is None:
            self._create_thread_pool()
        assert self._thread_pool is not None  # narrow Optional for type checkers
        return self._thread_pool

    def stop(self, wait: bool = True) -> None:
        """Shut the executor down (optionally waiting for pending work) and drop it.

        Safe to call multiple times; a later `thread_pool` access re-creates the pool.
        """
        if self._thread_pool is not None:
            self._thread_pool.shutdown(wait=wait)
            self._thread_pool = None
8 changes: 6 additions & 2 deletions dlt/common/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime # noqa: 251
import humanize
import contextlib
from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Protocol, Sequence, TYPE_CHECKING, Tuple, TypedDict
from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Protocol, Sequence, TYPE_CHECKING, Tuple, TypedDict, Mapping
from typing_extensions import NotRequired

from dlt.common import pendulum, logger
Expand Down Expand Up @@ -54,7 +54,7 @@ def asdict(self) -> DictStrAny:
"""A dictionary representation of NormalizeInfo that can be loaded with `dlt`"""
d = self._asdict()
# list representation creates a nice table
d["row_counts"] = [(k, v) for k, v in self.row_counts.items()]
d["row_counts"] = [{"table_name": k, "count": v} for k, v in self.row_counts.items()]
rudolfix marked this conversation as resolved.
Show resolved Hide resolved
return d

def asstr(self, verbosity: int = 0) -> str:
Expand Down Expand Up @@ -194,6 +194,10 @@ class SupportsPipeline(Protocol):
def state(self) -> TPipelineState:
"""Returns dictionary with pipeline state"""

@property
def schemas(self) -> Mapping[str, Schema]:
"""Mapping of all pipeline schemas"""

def set_local_state_val(self, key: str, value: Any) -> None:
"""Sets value in local state. Local state is not synchronized with destination."""

Expand Down
21 changes: 18 additions & 3 deletions dlt/common/runtime/exec_info.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import io
import os
import contextlib
import sys
import multiprocessing
import platform

from dlt.common.runtime.typing import TExecutionContext, TVersion, TExecInfoNames
from dlt.common.typing import StrStr, StrAny, Literal, List
from dlt.common.utils import filter_env_vars
from dlt.version import __version__
from dlt.version import __version__, DLT_PKG_NAME


TExecInfoNames = Literal["kubernetes", "docker", "codespaces", "github_actions", "airflow", "notebook", "colab","aws_lambda","gcp_cloud_function"]
# if one of these environment variables is set, we assume to be running in CI env
CI_ENVIRONMENT_TELL = [
"bamboo.buildKey",
Expand Down Expand Up @@ -163,4 +166,16 @@ def is_aws_lambda() -> bool:

def is_gcp_cloud_function() -> bool:
    """Return True if the process is running in the serverless platform GCP Cloud Functions.

    Detection relies on the FUNCTION_NAME environment variable which the GCP
    Cloud Functions runtime sets for deployed functions.
    """
    # diff artifact removed: the return statement appeared twice in the scraped text
    return os.environ.get("FUNCTION_NAME") is not None


def get_execution_context() -> TExecutionContext:
    """Collect anonymous information about the runtime environment of this process."""
    # gather the two versioned components first, then assemble the context
    os_version = TVersion(name=platform.system(), version=platform.release())
    library_version = TVersion(name=DLT_PKG_NAME, version=__version__)
    interpreter_version = sys.version.split(" ")[0]
    return TExecutionContext(
        ci_run=in_continuous_integration(),
        python=interpreter_version,
        cpu=multiprocessing.cpu_count(),
        exec_info=exec_info_names(),
        os=os_version,
        library=library_version,
    )
38 changes: 14 additions & 24 deletions dlt/common/runtime/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,39 @@

# several code fragments come from https://github.com/RasaHQ/rasa/blob/main/rasa/telemetry.py
import os
import sys
import multiprocessing

import atexit
import base64
import requests
import platform
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Optional
from dlt.common.configuration.paths import get_dlt_data_dir

from dlt.common.runtime import logger
from dlt.common.managed_thread_pool import ManagedThreadPool

from dlt.common.configuration.specs import RunConfiguration
from dlt.common.runtime.exec_info import exec_info_names, in_continuous_integration
from dlt.common.runtime.exec_info import get_execution_context, TExecutionContext
from dlt.common.typing import DictStrAny, StrAny
from dlt.common.utils import uniq_id
from dlt.version import __version__, DLT_PKG_NAME
from dlt.version import __version__

TEventCategory = Literal["pipeline", "command", "helper"]

_THREAD_POOL: ThreadPoolExecutor = None
_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1)
_SESSION: requests.Session = None
_WRITE_KEY: str = None
_SEGMENT_REQUEST_TIMEOUT = (1.0, 1.0) # short connect & send timeouts
_SEGMENT_ENDPOINT = "https://api.segment.io/v1/track"
_SEGMENT_CONTEXT: DictStrAny = None
_SEGMENT_CONTEXT: TExecutionContext = None


def init_segment(config: RunConfiguration) -> None:
assert config.dlthub_telemetry_segment_write_key, "dlthub_telemetry_segment_write_key not present in RunConfiguration"

# create thread pool to send telemetry to segment
global _THREAD_POOL, _WRITE_KEY, _SESSION
if not _THREAD_POOL:
_THREAD_POOL = ThreadPoolExecutor(1)
global _WRITE_KEY, _SESSION
if not _SESSION:
_SESSION = requests.Session()
# flush pool on exit
atexit.register(_at_exit_cleanup)
Expand Down Expand Up @@ -86,10 +84,9 @@ def before_send(event: DictStrAny) -> Optional[DictStrAny]:


def _at_exit_cleanup() -> None:
global _THREAD_POOL, _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT
if _THREAD_POOL:
_THREAD_POOL.shutdown(wait=True)
_THREAD_POOL = None
global _SESSION, _WRITE_KEY, _SEGMENT_CONTEXT
if _SESSION:
_THREAD_POOL.stop(True)
_SESSION.close()
_SESSION = None
_WRITE_KEY = None
Expand Down Expand Up @@ -150,7 +147,7 @@ def _segment_request_payload(
}


def _default_context_fields() -> DictStrAny:
def _default_context_fields() -> TExecutionContext:
"""Return a dictionary that contains the default context values.

Return:
Expand All @@ -161,14 +158,7 @@ def _default_context_fields() -> DictStrAny:
if not _SEGMENT_CONTEXT:
# Make sure to update the example in docs/docs/telemetry/telemetry.mdx
# if you change / add context
_SEGMENT_CONTEXT = {
"os": {"name": platform.system(), "version": platform.release()},
"ci_run": in_continuous_integration(),
"python": sys.version.split(" ")[0],
"library": {"name": DLT_PKG_NAME, "version": __version__},
"cpu": multiprocessing.cpu_count(),
"exec_info": exec_info_names()
}
_SEGMENT_CONTEXT = get_execution_context()

# avoid returning the cached dict --> caller could modify the dictionary...
# usually we would use `lru_cache`, but that doesn't return a dict copy and
Expand Down Expand Up @@ -220,4 +210,4 @@ def _future_send() -> None:
f"Segment telemetry request returned a failure. Response: {data}"
)

_THREAD_POOL.submit(_future_send)
_THREAD_POOL.thread_pool.submit(_future_send)
28 changes: 28 additions & 0 deletions dlt/common/runtime/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Set, Type, TypedDict, NewType, Union, get_args


# Known execution environments that dlt detects and reports (see exec_info_names())
TExecInfoNames = Literal[
    "kubernetes",
    "docker",
    "codespaces",
    "github_actions",
    "airflow",
    "notebook",
    "colab",
    "aws_lambda",
    "gcp_cloud_function",
]

class TVersion(TypedDict):
    """TypedDict representing a library version (also used for the OS name/release pair)."""
    name: str  # component name, e.g. the package name or platform.system()
    version: str  # version string, e.g. the package version or platform.release()

class TExecutionContext(TypedDict):
    """TypedDict representing the anonymous runtime context info sent with telemetry."""
    ci_run: bool  # True when a known CI environment variable is present
    python: str  # interpreter version string, e.g. "3.10.4"
    cpu: int  # number of CPUs reported by multiprocessing.cpu_count()
    exec_info: List[TExecInfoNames]  # detected execution environments
    library: TVersion  # dlt package name and version
    os: TVersion  # operating system name and release
5 changes: 4 additions & 1 deletion dlt/common/storages/load_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dlt.common.configuration.accessors import config
from dlt.common.exceptions import TerminalValueError
from dlt.common.schema import Schema, TSchemaTables, TTableSchemaColumns
from dlt.common.schema.typing import TStoredSchema
from dlt.common.storages.configuration import LoadStorageConfiguration
from dlt.common.storages.versioned_storage import VersionedStorage
from dlt.common.storages.data_item_storage import DataItemStorage
Expand Down Expand Up @@ -87,6 +88,7 @@ class LoadPackageInfo(NamedTuple):
package_path: str
state: TLoadPackageState
schema_name: str
schema_hash: str
schema_update: TSchemaTables
completed_at: datetime.datetime
jobs: Dict[TJobState, List[LoadJobInfo]]
Expand All @@ -110,6 +112,7 @@ def asdict(self) -> DictStrAny:
table["columns"] = columns
d.pop("schema_update")
d["tables"] = tables
d["schema_hash"] = self.schema_hash
return d

def asstr(self, verbosity: int = 0) -> str:
Expand Down Expand Up @@ -290,7 +293,7 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo:
jobs.append(self._read_job_file_info(state, file, package_created_at))
all_jobs[state] = jobs

return LoadPackageInfo(load_id, self.storage.make_full_path(package_path), package_state, schema.name, applied_update, package_created_at, all_jobs)
return LoadPackageInfo(load_id, self.storage.make_full_path(package_path), package_state, schema.name, schema.version_hash, applied_update, package_created_at, all_jobs)

def begin_schema_update(self, load_id: str) -> Optional[TSchemaTables]:
package_path = self.get_normalized_package_path(load_id)
Expand Down
4 changes: 2 additions & 2 deletions dlt/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ def run(
)

# plug default tracking module
from dlt.pipeline import trace, track
trace.TRACKING_MODULE = track
from dlt.pipeline import trace, track, platform
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok for platform but maybe we should just say opentelemetry?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment it is not in OpenTelemetry format at all. I can rename it, but I would suggest we switch to OpenTelemetry once the prototype is out and then also rename this file.

trace.TRACKING_MODULES = [track, platform]

# setup default pipeline in the container
Container()[PipelineContext] = PipelineContext(pipeline)
Loading