From 2ed1233372e378ea1608320f7f35b4876b13810e Mon Sep 17 00:00:00 2001
From: IlyaFaer
Date: Wed, 27 Mar 2024 10:09:49 +0400
Subject: [PATCH] feat(transform): implement columns pivot map function

---
Note: pivot() was revised after the original commit (output keys are built
once, docstring expanded), so the post-image blob id on the transform.py
"index" line below is stale.

 dlt/sources/helpers/transform.py      | 42 +++++++++++++++++++++++++++
 tests/pipeline/test_pipeline_state.py | 22 ++++++++++++++
 2 files changed, 64 insertions(+)

diff --git a/dlt/sources/helpers/transform.py b/dlt/sources/helpers/transform.py
index 3949823be7..11ed76d5a5 100644
--- a/dlt/sources/helpers/transform.py
+++ b/dlt/sources/helpers/transform.py
@@ -1,3 +1,5 @@
+from typing import List, Union
+
 from dlt.common.typing import TDataItem
 from dlt.extract.items import ItemTransformFunctionNoMeta
 
@@ -24,3 +26,43 @@ def _filter(_: TDataItem) -> bool:
         return count > max_items
 
     return _filter
+
+
+def pivot(columns: Union[str, List[str]], prefix: str) -> ItemTransformFunctionNoMeta[TDataItem]:
+    """Build a map function that turns positional items into prefixed-key dicts.
+
+    The returned function reads ``item[i]`` for every position ``i`` in
+    ``columns`` and stores it under ``prefix + columns[i]``. Positions beyond
+    ``len(columns)`` are not read, so extra values in an item are dropped; an
+    item with fewer values propagates the item's own lookup error (``IndexError``
+    for a list).
+
+    Args:
+        columns (Union[str, List[str]]): a column name or a list of column
+            names, in the positional order of the values in each item
+        prefix (str): prefix to add to the column names
+
+    Returns:
+        ItemTransformFunctionNoMeta[TDataItem]:
+            A function to pivot columns into a dict.
+    """
+    if isinstance(columns, str):
+        columns = [columns]
+
+    # Build the output keys once: keeps the string concatenation out of the
+    # per-item path and snapshots `columns`, so a caller mutating its list
+    # afterwards cannot change a transformer that was already created.
+    keys = tuple(prefix + col for col in columns)
+
+    def _transformer(item: TDataItem) -> TDataItem:
+        """Pivot columns into a dictionary.
+
+        Args:
+            item (TDataItem): a data item indexable by position.
+
+        Returns:
+            TDataItem: a dict mapping each prefixed column name to its value.
+        """
+        return {key: item[ind] for ind, key in enumerate(keys)}
+
+    return _transformer
diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py
index f0bcda2717..edda12cd4b 100644
--- a/tests/pipeline/test_pipeline_state.py
+++ b/tests/pipeline/test_pipeline_state.py
@@ -11,6 +11,7 @@
 from dlt.common import pipeline as state_module
 from dlt.common.utils import uniq_id
 from dlt.common.destination.reference import Destination
+from dlt.sources.helpers.transform import pivot
 
 from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException, PipelineStepFailed
 from dlt.pipeline.pipeline import Pipeline
@@ -486,6 +487,27 @@ def transform(item):
     )
 
 
+def test_transform_function_pivot() -> None:
+    @dlt.resource
+    def test_resource():
+        for row in (
+            [[1, 2, 3], [4, 5, 6]],
+            [[7, 8, 9], [10, 11, 12]],
+        ):
+            yield row
+
+    res = test_resource()
+    res.add_map(pivot(["a", "b", "c"], "prefix_"))
+
+    result = list(res)
+    assert result == [
+        {"prefix_a": 1, "prefix_b": 2, "prefix_c": 3},
+        {"prefix_a": 4, "prefix_b": 5, "prefix_c": 6},
+        {"prefix_a": 7, "prefix_b": 8, "prefix_c": 9},
+        {"prefix_a": 10, "prefix_b": 11, "prefix_c": 12},
+    ]
+
+
 def test_migrate_pipeline_state(test_storage: FileStorage) -> None:
     # test generation of version hash on migration to v3
     state_v1 = load_json_case("state/state.v1")