diff --git a/.changes/unreleased/Features-20241104-083421.yaml b/.changes/unreleased/Features-20241104-083421.yaml
new file mode 100644
index 000000000..c2292751c
--- /dev/null
+++ b/.changes/unreleased/Features-20241104-083421.yaml
@@ -0,0 +1,6 @@
+kind: Features
+body: Support Data Profiling in dbt
+time: 2024-11-04T08:34:21.596015+09:00
+custom:
+  Author: syou6162
+  Issue: "1330"
diff --git a/dbt/adapters/bigquery/dataplex.py b/dbt/adapters/bigquery/dataplex.py
new file mode 100644
index 000000000..cf595e075
--- /dev/null
+++ b/dbt/adapters/bigquery/dataplex.py
@@ -0,0 +1,170 @@
+from dataclasses import dataclass
+import hashlib
+from typing import Optional
+
+from dbt.adapters.bigquery import BigQueryConnectionManager
+from google.cloud import dataplex_v1
+from google.protobuf import field_mask_pb2
+
+
+@dataclass
+class DataProfileScanSetting:
+    location: str
+    scan_id: Optional[str]
+
+    project_id: str
+    dataset_id: str
+    table_id: str
+
+    sampling_percent: Optional[float]
+    row_filter: Optional[str]
+    cron: Optional[str]
+
+    def parent(self):
+        return f"projects/{self.project_id}/locations/{self.location}"
+
+    def data_scan_name(self):
+        return f"{self.parent()}/dataScans/{self.scan_id}"
+
+
+class DataProfileScan:
+    def __init__(self, connections: BigQueryConnectionManager):
+        self.connections = connections
+
+    # If the label `dataplex-dp-published-*` is not assigned, we cannot view the results of the Data Profile Scan from BigQuery
+    def _update_labels_with_data_profile_scan_labels(
+        self,
+        project_id: str,
+        dataset_id: str,
+        table_id: str,
+        location: str,
+        scan_id: str,
+    ):
+        table = self.connections.get_bq_table(project_id, dataset_id, table_id)
+        original_labels = table.labels
+        profile_scan_labels = {
+            "dataplex-dp-published-scan": scan_id,
+            "dataplex-dp-published-project": project_id,
+            "dataplex-dp-published-location": location,
+        }
+        table.labels = {**original_labels, **profile_scan_labels}
+        self.connections.get_thread_connection().handle.update_table(table, ["labels"])
+
+    # scan_id must be unique within the project and no longer than 63 characters,
+    # so generate an id that meets the constraints
+    def _generate_unique_scan_id(self, dataset_id: str, table_id: str) -> str:
+        md5 = hashlib.md5(f"{dataset_id}_{table_id}".encode("utf-8")).hexdigest()
+        return f"dbt-{table_id.replace('_', '-')}-{md5}"[:63]
+
+    def _create_or_update_data_profile_scan(
+        self,
+        client: dataplex_v1.DataScanServiceClient,
+        scan_setting: DataProfileScanSetting,
+    ):
+        data_profile_spec = dataplex_v1.DataProfileSpec(
+            sampling_percent=scan_setting.sampling_percent,
+            row_filter=scan_setting.row_filter,
+        )
+        display_name = (
+            f"Data Profile Scan for {scan_setting.table_id} in {scan_setting.dataset_id}"
+        )
+        description = f"This is a Data Profile Scan for {scan_setting.project_id}.{scan_setting.dataset_id}.{scan_setting.table_id}. Created by dbt."
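+        # These labels mark the scan itself as managed by dbt; the table-level
+        # `dataplex-dp-published-*` labels are applied separately in
+        # _update_labels_with_data_profile_scan_labels.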
+        labels = {
+            "managed_by": "dbt",
+        }
+
+        if scan_setting.cron:
+            trigger = dataplex_v1.Trigger(
+                schedule=dataplex_v1.Trigger.Schedule(cron=scan_setting.cron)
+            )
+        else:
+            trigger = dataplex_v1.Trigger(on_demand=dataplex_v1.Trigger.OnDemand())
+        execution_spec = dataplex_v1.DataScan.ExecutionSpec(trigger=trigger)
+
+        if all(
+            scan.name != scan_setting.data_scan_name()
+            for scan in client.list_data_scans(parent=scan_setting.parent())
+        ):
+            data_scan = dataplex_v1.DataScan(
+                data=dataplex_v1.DataSource(
+                    resource=f"//bigquery.googleapis.com/projects/{scan_setting.project_id}/datasets/{scan_setting.dataset_id}/tables/{scan_setting.table_id}"
+                ),
+                data_profile_spec=data_profile_spec,
+                execution_spec=execution_spec,
+                display_name=display_name,
+                description=description,
+                labels=labels,
+            )
+            request = dataplex_v1.CreateDataScanRequest(
+                parent=scan_setting.parent(),
+                data_scan_id=scan_setting.scan_id,
+                data_scan=data_scan,
+            )
+            client.create_data_scan(request=request).result()
+        else:
+            request = dataplex_v1.GetDataScanRequest(
+                name=scan_setting.data_scan_name(),
+            )
+            data_scan = client.get_data_scan(request=request)
+
+            data_scan.data_profile_spec = data_profile_spec
+            data_scan.execution_spec = execution_spec
+            data_scan.display_name = display_name
+            data_scan.description = description
+            data_scan.labels = labels
+
+            update_mask = field_mask_pb2.FieldMask(
+                paths=[
+                    "data_profile_spec",
+                    "execution_spec",
+                    "display_name",
+                    "description",
+                    "labels",
+                ]
+            )
+            request = dataplex_v1.UpdateDataScanRequest(
+                data_scan=data_scan,
+                update_mask=update_mask,
+            )
+            client.update_data_scan(request=request).result()
+
+    def create_or_update_data_profile_scan(self, config):
+        project_id = config.get("database")
+        dataset_id = config.get("schema")
+        table_id = config.get("name")
+
+        data_profile_config = config.get("config").get("data_profile_scan", {})
+
+        # Skip if data_profile_scan is not configured
+        if not data_profile_config:
+            return None
+
+        client = dataplex_v1.DataScanServiceClient()
+        scan_setting = DataProfileScanSetting(
+            location=data_profile_config["location"],
+            scan_id=data_profile_config.get(
+                "scan_id", self._generate_unique_scan_id(dataset_id, table_id)
+            ),
+            project_id=project_id,
+            dataset_id=dataset_id,
+            table_id=table_id,
+            sampling_percent=data_profile_config.get("sampling_percent", None),
+            row_filter=data_profile_config.get("row_filter", None),
+            cron=data_profile_config.get("cron", None),
+        )
+
+        # Delete existing data profile scan if it is disabled
+        if not data_profile_config.get("enabled", True):
+            client.delete_data_scan(name=scan_setting.data_scan_name())
+            return None
+
+        self._create_or_update_data_profile_scan(client, scan_setting)
+
+        if not scan_setting.cron:
+            client.run_data_scan(
+                request=dataplex_v1.RunDataScanRequest(name=scan_setting.data_scan_name())
+            )
+
+        self._update_labels_with_data_profile_scan_labels(
+            project_id, dataset_id, table_id, scan_setting.location, scan_setting.scan_id
+        )
diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py
index 51c457129..d7f4d1107 100644
--- a/dbt/adapters/bigquery/impl.py
+++ b/dbt/adapters/bigquery/impl.py
@@ -55,6 +55,7 @@
 from dbt.adapters.bigquery.column import BigQueryColumn, get_nested_column_data_types
 from dbt.adapters.bigquery.connections import BigQueryAdapterResponse, BigQueryConnectionManager
+from dbt.adapters.bigquery.dataplex import DataProfileScan
 from dbt.adapters.bigquery.dataset import add_access_entry_to_dataset, is_access_entry_in_dataset
 from dbt.adapters.bigquery.python_submissions import (
     ClusterDataprocHelper,
@@ -969,3 +970,7 @@ def validate_sql(self, sql: str) -> AdapterResponse:
         :param str sql: The sql to validate
         """
         return self.connections.dry_run(sql)
+
+    @available
+    def create_or_update_data_profile_scan(self, config):
+        DataProfileScan(self.connections).create_or_update_data_profile_scan(config)
diff --git a/dbt/include/bigquery/macros/materializations/incremental.sql b/dbt/include/bigquery/macros/materializations/incremental.sql
index 25a83b0c6..d48281189 100644
--- a/dbt/include/bigquery/macros/materializations/incremental.sql
+++ b/dbt/include/bigquery/macros/materializations/incremental.sql
@@ -170,6 +170,7 @@
   {% do apply_grants(target_relation, grant_config, should_revoke) %}
 
   {% do persist_docs(target_relation, model) %}
+  {% do adapter.create_or_update_data_profile_scan(model) %}
 
   {%- if tmp_relation_exists -%}
   {{ adapter.drop_relation(tmp_relation) }}
diff --git a/dbt/include/bigquery/macros/materializations/table.sql b/dbt/include/bigquery/macros/materializations/table.sql
index 41bb69770..2273d4794 100644
--- a/dbt/include/bigquery/macros/materializations/table.sql
+++ b/dbt/include/bigquery/macros/materializations/table.sql
@@ -40,6 +40,7 @@
   {% do apply_grants(target_relation, grant_config, should_revoke) %}
 
   {% do persist_docs(target_relation, model) %}
+  {% do adapter.create_or_update_data_profile_scan(model) %}
 
   {{ return({'relations': [target_relation]}) }}
diff --git a/pyproject.toml b/pyproject.toml
index b2d55b25f..d6685c7b9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "google-cloud-bigquery[pandas]>=3.0,<4.0",
     "google-cloud-storage~=2.4",
     "google-cloud-dataproc~=5.0",
+    "google-cloud-dataplex~=2.5",
     # ----
     # Expect compatibility with all new versions of these packages, so lower bounds only.
"google-api-core>=2.11.0", diff --git a/tests/functional/test_data_profile_scan.py b/tests/functional/test_data_profile_scan.py new file mode 100644 index 000000000..01a4c1839 --- /dev/null +++ b/tests/functional/test_data_profile_scan.py @@ -0,0 +1,438 @@ +import os +import pytest +import yaml +from unittest.mock import patch, MagicMock +from dbt.adapters.bigquery.relation import BigQueryRelation +from dbt.tests.util import ( + run_dbt, + get_connection, + relation_from_name, + read_file, + write_config_file, +) + +SCAN_LOCATION = "us-central1" +SCAN_ID = "bigquery_data_profile_scan_test" +MODEL_NAME = "test_model" + +ORIGINAL_LABELS = { + "my_label_key": "my_label_value", +} + +PROFILE_SCAN_LABELS = [ + "dataplex-dp-published-scan", + "dataplex-dp-published-project", + "dataplex-dp-published-location", +] + +SQL_CONTENT = """ +{{ + config( + materialized="table" + ) +}} + select 20 as id, cast('2020-01-01 01:00:00' as datetime) as date_hour union all + select 40 as id, cast('2020-01-01 02:00:00' as datetime) as date_hour +""" + +YAML_CONTENT = f"""version: 2 +models: + - name: {MODEL_NAME} +""" + +YAML_CONTENT_WITH_PROFILE_SCAN_SETTING = f"""version: 2 +models: + - name: {MODEL_NAME} + config: + data_profile_scan: + location: us-central1 + scan_id: {SCAN_ID} + sampling_percent: 10 + row_filter: "TRUE" + cron: "CRON_TZ=Asia/New_York 0 9 * * *" +""" + +INCREMENTAL_MODEL_CONTENT = """ +{{ + config( + materialized="incremental", + ) +}} + +{% if not is_incremental() %} + + select 10 as id, cast('2020-01-01 01:00:00' as datetime) as date_hour union all + select 30 as id, cast('2020-01-01 02:00:00' as datetime) as date_hour + +{% else %} + + select 20 as id, cast('2020-01-01 01:00:00' as datetime) as date_hour union all + select 40 as id, cast('2020-01-01 02:00:00' as datetime) as date_hour + +{% endif %} +""" + + +class TestDataProfileScanWithProjectProfileScanSetting: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+labels": ORIGINAL_LABELS, + "+data_profile_scan": { + "location": SCAN_LOCATION, + "sampling_percent": 10, + "row_filter": "TRUE", + }, + }, + } + + @pytest.fixture(scope="class") + def models(self): + return { + f"{MODEL_NAME}.sql": SQL_CONTENT, + f"{MODEL_NAME}.yml": YAML_CONTENT, + } + + def test_create_data_profile_scan(self, project): + with patch("google.cloud.dataplex_v1.DataScanServiceClient") as MockDataScanClient: + mock_data_scan_client = MockDataScanClient.return_value + + results = run_dbt() + assert len(results) == 1 + + mock_data_scan_client.create_data_scan.assert_called_once() + mock_data_scan_client.run_data_scan.assert_called_once() + + relation: BigQueryRelation = relation_from_name(project.adapter, MODEL_NAME) + with get_connection(project.adapter) as conn: + table = conn.handle.get_table( + project.adapter.connections.get_bq_table( + relation.database, relation.schema, relation.table + ) + ) + labels_to_be_created = PROFILE_SCAN_LABELS + list(ORIGINAL_LABELS.keys()) + assert set(table.labels.keys()) == set(labels_to_be_created) + + +class TestDataProfileScanWithProjectProfileScanSettingAndCron: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+labels": ORIGINAL_LABELS, + "+data_profile_scan": { + "location": SCAN_LOCATION, + "scan_id": SCAN_ID, + "sampling_percent": 10, + "row_filter": "TRUE", + "cron": "CRON_TZ=Asia/New_York 0 9 * * *", + }, + }, + } + + @pytest.fixture(scope="class") + def models(self): + return { + f"{MODEL_NAME}.sql": SQL_CONTENT, + 
f"{MODEL_NAME}.yml": YAML_CONTENT, + } + + def test_create_data_profile_scan(self, project): + with patch("google.cloud.dataplex_v1.DataScanServiceClient") as MockDataScanClient: + mock_data_scan_client = MockDataScanClient.return_value + + results = run_dbt() + assert len(results) == 1 + + mock_data_scan_client.create_data_scan.assert_called_once() + mock_data_scan_client.run_data_scan.assert_not_called() + + relation: BigQueryRelation = relation_from_name(project.adapter, MODEL_NAME) + with get_connection(project.adapter) as conn: + table = conn.handle.get_table( + project.adapter.connections.get_bq_table( + relation.database, relation.schema, relation.table + ) + ) + labels_to_be_created = PROFILE_SCAN_LABELS + list(ORIGINAL_LABELS.keys()) + assert set(table.labels.keys()) == set(labels_to_be_created) + + +class TestDataProfileScanWithModelProfileScanSetting: + @pytest.fixture(scope="class") + def models(self): + sql_content = f""" + {{{{ + config( + materialized="table", + labels={ORIGINAL_LABELS}, + ) + }}}} + select 20 as id, cast('2020-01-01 01:00:00' as datetime) as date_hour union all + select 40 as id, cast('2020-01-01 02:00:00' as datetime) as date_hour + """ + + return { + f"{MODEL_NAME}.sql": sql_content, + f"{MODEL_NAME}.yml": YAML_CONTENT_WITH_PROFILE_SCAN_SETTING, + } + + def test_create_data_profile_scan(self, project): + with patch("google.cloud.dataplex_v1.DataScanServiceClient") as MockDataScanClient: + mock_data_scan_client = MockDataScanClient.return_value + + results = run_dbt() + assert len(results) == 1 + + mock_data_scan_client.create_data_scan.assert_called_once() + mock_data_scan_client.run_data_scan.assert_not_called() + + relation: BigQueryRelation = relation_from_name(project.adapter, MODEL_NAME) + with get_connection(project.adapter) as conn: + table = conn.handle.get_table( + project.adapter.connections.get_bq_table( + relation.database, relation.schema, relation.table + ) + ) + labels_to_be_created = PROFILE_SCAN_LABELS + list(ORIGINAL_LABELS.keys()) + assert set(table.labels.keys()) == set(labels_to_be_created) + + +class TestDataProfileScanWithoutProfileScanSetting: + @pytest.fixture(scope="class") + def models(self): + return { + f"{MODEL_NAME}.sql": SQL_CONTENT, + f"{MODEL_NAME}.yml": YAML_CONTENT, + } + + def test_create_data_profile_scan(self, project): + with patch("google.cloud.dataplex_v1.DataScanServiceClient") as MockDataScanClient: + mock_data_scan_client = MockDataScanClient.return_value + + results = run_dbt() + assert len(results) == 1 + + mock_data_scan_client.create_data_scan.assert_not_called() + mock_data_scan_client.run_data_scan.assert_not_called() + + relation: BigQueryRelation = relation_from_name(project.adapter, MODEL_NAME) + with get_connection(project.adapter) as conn: + table = conn.handle.get_table( + project.adapter.connections.get_bq_table( + relation.database, relation.schema, relation.table + ) + ) + labels_to_be_created = [] + assert set(table.labels.keys()) == set(labels_to_be_created) + + +class TestDataProfileScanDisabledProfileScanSetting: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+data_profile_scan": { + "location": SCAN_LOCATION, + "scan_id": SCAN_ID, + "enabled": False, + }, + }, + } + + @pytest.fixture(scope="class") + def models(self): + return { + f"{MODEL_NAME}.sql": SQL_CONTENT, + f"{MODEL_NAME}.yml": YAML_CONTENT, + } + + def test_create_data_profile_scan(self, project): + with patch("google.cloud.dataplex_v1.DataScanServiceClient") as MockDataScanClient: + 
+            mock_data_scan_client = MockDataScanClient.return_value
+
+            results = run_dbt()
+            assert len(results) == 1
+
+            mock_data_scan_client.create_data_scan.assert_not_called()
+            mock_data_scan_client.run_data_scan.assert_not_called()
+
+            relation: BigQueryRelation = relation_from_name(project.adapter, MODEL_NAME)
+            with get_connection(project.adapter) as conn:
+                table = conn.handle.get_table(
+                    project.adapter.connections.get_bq_table(
+                        relation.database, relation.schema, relation.table
+                    )
+                )
+            labels_to_be_created = []
+            assert set(table.labels.keys()) == set(labels_to_be_created)
+
+
+class TestDataProfileScanUpdatedMidway:
+    project_name = "my-project"
+
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "models": {
+                "+database": self.project_name,
+                "+labels": ORIGINAL_LABELS,
+                "+data_profile_scan": {
+                    "location": SCAN_LOCATION,
+                    "scan_id": SCAN_ID,
+                    "sampling_percent": 10,
+                    "row_filter": "TRUE",
+                },
+            },
+        }
+
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            f"{MODEL_NAME}.sql": SQL_CONTENT,
+            f"{MODEL_NAME}.yml": YAML_CONTENT,
+        }
+
+    def test_create_data_profile_scan(self, project):
+        with patch("google.cloud.dataplex_v1.DataScanServiceClient") as MockDataScanClient:
+            mock_data_scan_client = MockDataScanClient.return_value
+
+            results = run_dbt()
+            assert len(results) == 1
+
+            mock_data_scan_client.create_data_scan.assert_called_once()
+            mock_data_scan_client.run_data_scan.assert_called_once()
+
+            def list_data_scans_mock(parent):
+                mock_scan = MagicMock()
+                # The existing-scan check compares full resource names, so the
+                # mock must return the fully qualified data scan name.
+                mock_scan.name = (
+                    f"projects/{self.project_name}/locations/{SCAN_LOCATION}/dataScans/{SCAN_ID}"
+                )
+                return [mock_scan]
+
+            mock_data_scan_client.list_data_scans.side_effect = list_data_scans_mock
+
+            project_yml = os.path.join(project.project_root, "dbt_project.yml")
+            config = yaml.safe_load(read_file(project_yml))
+            config["models"]["+data_profile_scan"]["sampling_percent"] = None
+            write_config_file(config, project_yml)
+
+            results = run_dbt()
+            assert len(results) == 1
+            mock_data_scan_client.update_data_scan.assert_called_once()
+
+            relation: BigQueryRelation = relation_from_name(project.adapter, MODEL_NAME)
+            with get_connection(project.adapter) as conn:
+                table = conn.handle.get_table(
+                    project.adapter.connections.get_bq_table(
+                        relation.database, relation.schema, relation.table
+                    )
+                )
+            labels_to_be_created = PROFILE_SCAN_LABELS + list(ORIGINAL_LABELS.keys())
+            assert set(table.labels.keys()) == set(labels_to_be_created)
+
+
+class TestDataProfileScanDisabledMidway:
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "models": {
+                "+labels": ORIGINAL_LABELS,
+                "+data_profile_scan": {
+                    "location": SCAN_LOCATION,
+                    "scan_id": SCAN_ID,
+                    "sampling_percent": 10,
+                    "row_filter": "TRUE",
+                },
+            },
+        }
+
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            f"{MODEL_NAME}.sql": SQL_CONTENT,
+            f"{MODEL_NAME}.yml": YAML_CONTENT,
+        }
+
+    def test_create_data_profile_scan(self, project):
+        with patch("google.cloud.dataplex_v1.DataScanServiceClient") as MockDataScanClient:
+            mock_data_scan_client = MockDataScanClient.return_value
+
+            results = run_dbt()
+            assert len(results) == 1
+
+            mock_data_scan_client.create_data_scan.assert_called_once()
+            mock_data_scan_client.run_data_scan.assert_called_once()
+
+            # Update the project to disable the data profile scan
+            project_yml = os.path.join(project.project_root, "dbt_project.yml")
+            config = yaml.safe_load(read_file(project_yml))
+            config["models"]["+data_profile_scan"]["enabled"] = False
+            write_config_file(config, project_yml)
+
+            results = run_dbt()
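+            # The second run should now route to the delete path.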
+            assert len(results) == 1
+            mock_data_scan_client.delete_data_scan.assert_called_once()
+
+            relation: BigQueryRelation = relation_from_name(project.adapter, MODEL_NAME)
+            with get_connection(project.adapter) as conn:
+                table = conn.handle.get_table(
+                    project.adapter.connections.get_bq_table(
+                        relation.database, relation.schema, relation.table
+                    )
+                )
+            labels_to_be_created = list(ORIGINAL_LABELS.keys())
+            assert set(table.labels.keys()) == set(labels_to_be_created)
+
+
+class TestDataProfileScanWithIncrementalModel:
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "models": {
+                "+labels": ORIGINAL_LABELS,
+                "+data_profile_scan": {
+                    "location": SCAN_LOCATION,
+                    "scan_id": SCAN_ID,
+                    "sampling_percent": 10,
+                    "row_filter": "TRUE",
+                },
+            },
+        }
+
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            f"{MODEL_NAME}.sql": INCREMENTAL_MODEL_CONTENT,
+            f"{MODEL_NAME}.yml": YAML_CONTENT,
+        }
+
+    def test_create_data_profile_scan(self, project):
+        with patch("google.cloud.dataplex_v1.DataScanServiceClient") as MockDataScanClient:
+            mock_data_scan_client = MockDataScanClient.return_value
+
+            results = run_dbt()
+            assert len(results) == 1
+
+            mock_data_scan_client.create_data_scan.assert_called_once()
+            mock_data_scan_client.run_data_scan.assert_called_once()
+
+            def list_data_scans_mock(parent):
+                mock_scan = MagicMock()
+                # The existing-scan check compares full resource names, so the
+                # mock must return the fully qualified data scan name.
+                mock_scan.name = (
+                    f"projects/{project.database}/locations/{SCAN_LOCATION}/dataScans/{SCAN_ID}"
+                )
+                return [mock_scan]
+
+            mock_data_scan_client.list_data_scans.side_effect = list_data_scans_mock
+
+            results = run_dbt()
+            assert len(results) == 1
+            mock_data_scan_client.update_data_scan.assert_called_once()
+
+            relation: BigQueryRelation = relation_from_name(project.adapter, MODEL_NAME)
+            with get_connection(project.adapter) as conn:
+                table = conn.handle.get_table(
+                    project.adapter.connections.get_bq_table(
+                        relation.database, relation.schema, relation.table
+                    )
+                )
+            labels_to_be_created = PROFILE_SCAN_LABELS + list(ORIGINAL_LABELS.keys())
+            assert set(table.labels.keys()) == set(labels_to_be_created)
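Usage sketch (illustrative, not part of the diff): the config keys below mirror what `create_or_update_data_profile_scan` reads and what the tests above exercise. `location` is required; `scan_id` falls back to a generated unique id, `cron` switches the scan from on-demand to scheduled, and `enabled: false` deletes an existing scan. The model name is a placeholder.

    version: 2
    models:
      - name: my_model
        config:
          data_profile_scan:
            location: us-central1
            scan_id: my_scan_id                         # optional; generated from dataset/table when omitted
            sampling_percent: 10                        # optional
            row_filter: "TRUE"                          # optional
            cron: "CRON_TZ=America/New_York 0 9 * * *"  # optional; the scan runs once on demand when omitted
            enabled: true                               # optional; false deletes an existing scan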