Add a more realistic data masking example #2116

Open
trymzet opened this issue Dec 3, 2024 · 0 comments

trymzet commented Dec 3, 2024

Documentation description

Currently, the data masking example at https://dlthub.com/docs/general-usage/customising-pipelines/pseudonymizing_columns is not very realistic, as it's unlikely anyone would want to use a function with hardcoded column names. Instead, the docs could show a function that takes the columns as a parameter. Such a function is more complicated than the dummy example, but a real-world example would help people who want to implement masking. This implementation also requires a closure, a software engineering concept that many non-professional software engineers might be unfamiliar with (so it could be hard for them to come up with the solution even with the help of a chatbot).
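
For readers unfamiliar with closures, here is a minimal sketch of the idea, reduced to plain dict rows and the standard library. The names `make_masker` and `mask_row` are illustrative only and are not part of dlt: the outer function receives the configuration (which columns to mask) and returns an inner function that remembers it.

```python
from typing import Any, Callable

Row = dict[str, Any]


def make_masker(columns: list[str], mask: str = "******") -> Callable[[Row], Row]:
    """Return a function that masks the given columns in a single row."""
    # Build the lookup set once, outside the inner function.
    masked = set(columns)

    def mask_row(row: Row) -> Row:
        # `masked` and `mask` are not parameters of `mask_row`; they are
        # captured from the enclosing call to `make_masker`. That capture
        # is what makes `mask_row` a closure.
        return {key: (mask if key in masked else value) for key, value in row.items()}

    return mask_row


mask_pii = make_masker(["email"])
result = mask_pii({"id": 1, "email": "jane@example.com"})
print(result)  # {'id': 1, 'email': '******'}
```

The returned function takes exactly one argument (the row), so the column list can be chosen at call time instead of being hardcoded in the function body.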

Here's the function I wrote (note that it's only tested with sql_database and sql_table sources).

from collections.abc import Callable
from enum import StrEnum  # Requires Python 3.11+.
from typing import Any

import dlt
import pandas as pd
import pyarrow as pa

# A chunk of data as yielded by the sql_database backends.
TableOrRow = pa.Table | pd.DataFrame | dict[str, Any]


class MaskingMethod(StrEnum):
    MASK = "mask"
    NULLIFY = "nullify"


def mask_sql_db_columns(
    columns: list[str],
    method: MaskingMethod | None = None,
    mask: str = "******",
) -> Callable[[TableOrRow], TableOrRow]:
    """Create a function that masks specified columns in a SQL Database table.

    All backends supported by the sql_database source, as of version 1.4.1, are
    supported. See https://github.com/dlt-hub/dlt/blob/devel/dlt/sources/sql_database/helpers.py#L50

    Args:
        columns (list[str]): The list of columns to mask.
        method (MaskingMethod | None): How to mask the values. Defaults to
            `MaskingMethod.MASK`.
        mask (str): The replacement value used by `MaskingMethod.MASK`.

    Returns:
        Callable[[TableOrRow], TableOrRow]: A function that takes a table or row
            and returns it with the specified columns masked.

    """
    if method is None:
        method = MaskingMethod.MASK

    # `mask_columns` is a closure: it captures `columns`, `method`, and `mask`.
    def mask_columns(table_or_row: TableOrRow) -> TableOrRow:
        # Handle `pyarrow` and `connectorx` backends.
        if isinstance(table_or_row, pa.Table):
            table = table_or_row
            for col in table.schema.names:
                if col in columns:
                    if method == MaskingMethod.MASK:
                        replace_with = pa.array([mask] * table.num_rows)
                    elif method == MaskingMethod.NULLIFY:
                        # Keep the original column type so the schema is unchanged.
                        replace_with = pa.nulls(
                            table.num_rows, type=table.schema.field(col).type
                        )
                    # Arrow tables are immutable; `set_column` returns a new table.
                    table = table.set_column(
                        table.schema.get_field_index(col),
                        col,
                        replace_with,
                    )
            return table

        # Handle `pandas` backend (the DataFrame is modified in place).
        if isinstance(table_or_row, pd.DataFrame):
            table = table_or_row

            for col in table.columns:
                if col in columns:
                    if method == MaskingMethod.MASK:
                        table[col] = mask
                    elif method == MaskingMethod.NULLIFY:
                        table[col] = None
            return table

        # Handle `sqlalchemy` backend (the row dict is modified in place).
        if isinstance(table_or_row, dict):
            row = table_or_row
            for col in row:
                if col in columns:
                    if method == MaskingMethod.MASK:
                        row[col] = mask
                    elif method == MaskingMethod.NULLIFY:
                        row[col] = None
            return row

        # Handle unsupported table types.
        msg = f"Unsupported table type: {type(table_or_row)}. Supported backends: (pyarrow, connectorx, pandas, sqlalchemy)."
        raise NotImplementedError(msg)

    # Return the inner function so it can be passed to `add_map`.
    return mask_columns

Then, it can be used like so:

table = sql_table_source(table="test_table")
table.add_map(mask_sql_db_columns(columns=["value"]))

pipeline = dlt.pipeline()
load_info = pipeline.run(table)

Are you a dlt user?

Yes, I'm already a dlt user.
