feat(pipeline): implement normalize progress tracking for JSON normalizer #1639

Closed · wants to merge 9 commits into from
18 changes: 18 additions & 0 deletions dlt/common/pipeline.py
@@ -212,6 +212,22 @@ class _ExtractInfo(NamedTuple):
class ExtractInfo(StepInfo[ExtractMetrics], _ExtractInfo): # type: ignore[misc]
"""A tuple holding information on extracted data items. Returned by pipeline `extract` method."""

@property
def total_rows_count(self) -> int:
"""Return the total extracted rows count from all the jobs.

Returns:
int: Total extracted rows count.
"""
count = 0

for _, metrics_list in self.metrics.items():
for metrics in metrics_list:
for _, value in metrics["job_metrics"].items():
count += value.items_count

return count

def asdict(self) -> DictStrAny:
"""A dictionary representation of ExtractInfo that can be loaded with `dlt`"""
d = super().asdict()
@@ -459,6 +475,8 @@ class TPipelineLocalState(TypedDict, total=False):
"""Timestamp indicating when the state was synced with the destination."""
_last_extracted_hash: str
"""Hash of state that was recently synced with destination"""
_last_extracted_count: int
"""Number of extracted rows in the last run"""


class TPipelineState(TVersionedState, total=False):
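For context on what the new property computes, here is a minimal, self-contained sketch of the same aggregation over a hand-built metrics dict shaped like `ExtractInfo.metrics` (the `JobMetrics` tuple below is a hypothetical stand-in for dlt's data writer metrics, which expose an `items_count` field):

```python
from typing import Dict, List, NamedTuple


class JobMetrics(NamedTuple):
    # hypothetical stand-in: only the items_count field matters for the aggregation
    items_count: int


def total_rows_count(metrics: Dict[str, List[dict]]) -> int:
    """Sum items_count over every job of every metrics entry, mirroring the new property."""
    count = 0
    for metrics_list in metrics.values():
        for m in metrics_list:
            for job_metrics in m["job_metrics"].values():
                count += job_metrics.items_count
    return count


# two load ids, three jobs in total -> 60 extracted rows
fake_metrics = {
    "1700000000.001": [{"job_metrics": {"users.1.jsonl": JobMetrics(25)}}],
    "1700000000.002": [
        {"job_metrics": {"users.2.jsonl": JobMetrics(20), "orders.1.jsonl": JobMetrics(15)}}
    ],
}
assert total_rows_count(fake_metrics) == 60
```

The pipeline then stashes this value in local state as `_last_extracted_count` so the normalize step can pick it up as the progress total.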
7 changes: 6 additions & 1 deletion dlt/normalize/items_normalizers.py
@@ -8,6 +8,7 @@
from dlt.common.json import custom_pua_decode, may_have_pua
from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer
from dlt.common.runtime import signals
from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns, TSchemaContractDict
from dlt.common.schema.utils import has_table_seen_data
from dlt.common.storages import NormalizeStorage
@@ -36,12 +37,14 @@ def __init__(
schema: Schema,
load_id: str,
config: NormalizeConfiguration,
collector: Collector = NULL_COLLECTOR,
) -> None:
self.item_storage = item_storage
self.normalize_storage = normalize_storage
self.schema = schema
self.load_id = load_id
self.config = config
self.collector = collector

@abstractmethod
def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSchemaUpdate]: ...
@@ -55,8 +58,9 @@ def __init__(
schema: Schema,
load_id: str,
config: NormalizeConfiguration,
collector: Collector = NULL_COLLECTOR,
) -> None:
super().__init__(item_storage, normalize_storage, schema, load_id, config)
super().__init__(item_storage, normalize_storage, schema, load_id, config, collector)
self._table_contracts: Dict[str, TSchemaContractDict] = {}
self._filtered_tables: Set[str] = set()
self._filtered_tables_columns: Dict[str, Dict[str, TSchemaEvolutionMode]] = {}
@@ -86,6 +90,7 @@ def _normalize_chunk(
for item in items:
items_gen = normalize_data_fun(item, self.load_id, root_table_name)
try:
self.collector.update("Items")
should_descend: bool = None
# use send to prevent descending into child rows when row was discarded
while row_info := items_gen.send(should_descend):
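The normalizer change is mostly plumbing: the items normalizer now holds a collector (defaulting to the no-op `NULL_COLLECTOR`) and ticks the "Items" counter once per top-level extracted item inside `_normalize_chunk`. A simplified sketch of that pattern, with hypothetical stand-in classes rather than dlt's real `Collector` API:

```python
from typing import Dict, List, Optional


class NullCollector:
    """No-op collector: update() is safe to call unconditionally in hot loops."""

    def update(self, name: str, inc: int = 1, total: Optional[int] = None) -> None:
        pass


class CountingCollector(NullCollector):
    """Counts ticks per counter name, roughly what a progress-bar collector renders."""

    def __init__(self) -> None:
        self.counters: Dict[str, int] = {}

    def update(self, name: str, inc: int = 1, total: Optional[int] = None) -> None:
        self.counters[name] = self.counters.get(name, 0) + inc


NULL_COLLECTOR = NullCollector()


class ToyItemsNormalizer:
    def __init__(self, collector: NullCollector = NULL_COLLECTOR) -> None:
        self.collector = collector

    def normalize_chunk(self, items: List[dict]) -> None:
        for _item in items:
            # one "Items" tick per extracted item, before descending into child rows
            self.collector.update("Items")


collector = CountingCollector()
ToyItemsNormalizer(collector).normalize_chunk([{"a": 1}, {"a": 2}, {"a": 3}])
assert collector.counters["Items"] == 3
```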
13 changes: 6 additions & 7 deletions dlt/normalize/normalize.py
@@ -49,6 +49,7 @@ class Normalize(Runnable[Executor], WithStepInfo[NormalizeMetrics, NormalizeInfo
@with_config(spec=NormalizeConfiguration, sections=(known_sections.NORMALIZE,))
def __init__(
self,
extracted_count: Optional[int] = None,
collector: Collector = NULL_COLLECTOR,
schema_storage: SchemaStorage = None,
config: NormalizeConfiguration = config.value,
@@ -59,6 +60,7 @@ def __init__(
self.pool = NullExecutor()
self.load_storage: LoadStorage = None
self.schema_storage: SchemaStorage = None
self.extracted_count = extracted_count

# setup storages
self.create_storages()
@@ -102,6 +104,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW
schema_dict,
load_id,
files,
self.collector,
)
for files in chunk_files
]
@@ -128,9 +131,6 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW
summary.file_metrics.extend(result.file_metrics)
# update metrics
self.collector.update("Files", len(result.file_metrics))
self.collector.update(
"Items", sum(result.file_metrics, EMPTY_DATA_WRITER_METRICS).items_count
)
except CannotCoerceColumnException as exc:
# schema conflicts resulting from parallel executing
logger.warning(
@@ -161,12 +161,10 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWor
schema.to_dict(),
load_id,
files,
self.collector,
)
self.update_schema(schema, result.schema_updates)
self.collector.update("Files", len(result.file_metrics))
self.collector.update(
"Items", sum(result.file_metrics, EMPTY_DATA_WRITER_METRICS).items_count
)
return result

def spool_files(
@@ -296,7 +294,8 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics:
continue
with self.collector(f"Normalize {schema.name} in {load_id}"):
self.collector.update("Files", 0, len(schema_files))
self.collector.update("Items", 0)
self.collector.update("Items", 0, total=self.extracted_count)

self._step_info_start_load_id(load_id)
self.spool_schema_files(load_id, schema, schema_files)

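The visible effect of threading `extracted_count` into `Normalize.run()` is that the "Items" counter starts with a known total, so a progress collector can render `n/total` instead of an open-ended count. A hedged illustration using tqdm directly (dlt's own collectors wrap progress libraries differently; `simulate_normalize` and its arguments are made up for this example):

```python
from typing import List, Optional

from tqdm import tqdm


def simulate_normalize(files: List[range], extracted_count: Optional[int] = None) -> None:
    # with a known extracted_count tqdm shows a bounded bar, e.g. "Items: 60/60";
    # with None it falls back to an unbounded counter, the pre-change behaviour
    items_bar = tqdm(desc="Items", total=extracted_count)
    files_bar = tqdm(desc="Files", total=len(files))
    for file_items in files:
        for _ in file_items:
            items_bar.update(1)
        files_bar.update(1)
    items_bar.close()
    files_bar.close()


simulate_normalize([range(25), range(20), range(15)], extracted_count=60)
```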
15 changes: 8 additions & 7 deletions dlt/normalize/worker.py
@@ -10,7 +10,7 @@
get_best_writer_spec,
is_native_writer,
)
from dlt.common.utils import chunks
from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
from dlt.common.schema.typing import TStoredSchema, TTableSchema
from dlt.common.storages import (
NormalizeStorage,
@@ -20,6 +20,7 @@
ParsedLoadJobFileName,
)
from dlt.common.schema import TSchemaUpdate, Schema
from dlt.common.utils import chunks

from dlt.normalize.configuration import NormalizeConfiguration
from dlt.normalize.exceptions import NormalizeJobFailed
@@ -61,6 +62,7 @@ def w_normalize_files(
stored_schema: TStoredSchema,
load_id: str,
extracted_items_files: Sequence[str],
collector: Collector = NULL_COLLECTOR,
) -> TWorkerRV:
destination_caps = config.destination_capabilities
schema_updates: List[TSchemaUpdate] = []
@@ -82,7 +84,9 @@
load_storage = LoadStorage(False, supported_file_formats, loader_storage_config)

def _get_items_normalizer(
parsed_file_name: ParsedLoadJobFileName, table_schema: TTableSchema
parsed_file_name: ParsedLoadJobFileName,
table_schema: TTableSchema,
collector: Collector = NULL_COLLECTOR,
) -> ItemsNormalizer:
item_format = DataWriter.item_format_from_file_extension(parsed_file_name.file_format)

@@ -183,11 +187,7 @@ def _get_items_normalizer(
f" format {item_storage.writer_spec.file_format}"
)
norm = item_normalizers[table_name] = cls(
item_storage,
normalize_storage,
schema,
load_id,
config,
item_storage, normalize_storage, schema, load_id, config, collector
)
return norm

@@ -236,6 +236,7 @@ def _gather_metrics_and_close(
normalizer = _get_items_normalizer(
parsed_file_name,
stored_schema["tables"].get(root_table_name, {"name": root_table_name}),
collector,
)
logger.debug(
f"Processing extracted items in {extracted_items_file} in load_id"
15 changes: 11 additions & 4 deletions dlt/pipeline/pipeline.py
@@ -445,7 +445,12 @@ def extract(
)
# commit load packages with state
extract_step.commit_packages()
return self._get_step_info(extract_step)

info = self._get_step_info(extract_step)
state = self._container[StateInjectableContext].state
state["_local"]["_last_extracted_count"] = info.total_rows_count
return info

except Exception as exc:
# emit step info
step_info = self._get_step_info(extract_step)
@@ -493,7 +498,9 @@ def _verify_destination_capabilities(
@with_schemas_sync
@with_config_section((known_sections.NORMALIZE,))
def normalize(
self, workers: int = 1, loader_file_format: TLoaderFileFormat = None
self,
workers: int = 1,
loader_file_format: TLoaderFileFormat = None,
) -> NormalizeInfo:
"""Normalizes the data prepared with `extract` method, infers the schema and creates load packages for the `load` method. Requires `destination` to be known."""
if is_interactive():
@@ -525,6 +532,7 @@ def normalize(
collector=self.collector,
config=normalize_config,
schema_storage=self._schema_storage,
extracted_count=self._get_state()["_local"].get("_last_extracted_count"),
)
try:
with signals.delayed_signals():
@@ -719,8 +727,7 @@ def run(
)
self.normalize(loader_file_format=loader_file_format)
return self.load(destination, dataset_name, credentials=credentials)
else:
return None
return None

@with_schemas_sync
def sync_destination(
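Taken together, the pipeline changes mean `extract` records its row count in local state and `normalize` reuses it as the progress denominator. A rough usage sketch (the resource and pipeline names are made up; `duckdb` is just an example destination):

```python
import dlt


@dlt.resource
def users():
    for i in range(1000):
        yield {"id": i, "name": f"user_{i}"}


pipeline = dlt.pipeline(pipeline_name="progress_demo", destination="duckdb", progress="tqdm")

extract_info = pipeline.extract(users)
print(extract_info.total_rows_count)  # 1000; also stored in local state as _last_extracted_count

# the Items counter is now seeded with that total, so the bar can show 1000/1000
pipeline.normalize()
pipeline.load()
```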
15 changes: 15 additions & 0 deletions tests/pipeline/test_pipeline_extra.py
@@ -2,6 +2,7 @@
import importlib.util
from typing import Any, ClassVar, Dict, Iterator, List, Optional
import pytest
from unittest import mock

from dlt.pipeline.exceptions import PipelineStepFailed

@@ -101,6 +102,20 @@ def test_pipeline_progress(progress: TCollectorArg) -> None:
assert isinstance(collector, LogCollector)


@pytest.mark.parametrize("progress", ["tqdm", "enlighten", "log", "alive_progress"])
def test_pipeline_normalize_progress(progress: TCollectorArg) -> None:
os.environ["TIMEOUT"] = "3.0"

p = dlt.pipeline(destination="dummy", progress=progress)
p.extract(many_delayed(5, 10))

with mock.patch.object(p.collector, "update") as col_mock:
p.normalize()
assert col_mock.call_count == 54

p.run(dataset_name="dummy")


@pytest.mark.parametrize("method", ("extract", "run"))
def test_column_argument_pydantic(method: str) -> None:
"""Test columns schema is created from pydantic model"""
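For reference, the new test's technique is simply to patch the collector's `update` method and count calls; the exact expected count (54 here) depends on how many items, files, and counter initializations the run produces. A generic, standalone sketch of that mocking pattern (the toy collector below is not dlt's):

```python
from unittest import mock


class Collector:
    def update(self, name, inc=1, total=None):
        pass


collector = Collector()
with mock.patch.object(collector, "update") as update_mock:
    # simulate three item ticks and one file tick
    for _ in range(3):
        collector.update("Items")
    collector.update("Files", 1)

assert update_mock.call_count == 4
update_mock.assert_any_call("Items")
update_mock.assert_called_with("Files", 1)
```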