diff --git a/airflow/providers/google/cloud/operators/datafusion.py b/airflow/providers/google/cloud/operators/datafusion.py index b2149495f93d..4f62b8240740 100644 --- a/airflow/providers/google/cloud/operators/datafusion.py +++ b/airflow/providers/google/cloud/operators/datafusion.py @@ -24,7 +24,7 @@ from googleapiclient.errors import HttpError from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, DataFusionHook, PipelineStates from airflow.providers.google.cloud.links.datafusion import ( DataFusionInstanceLink, @@ -34,16 +34,25 @@ from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.cloud.triggers.datafusion import DataFusionStartPipelineTrigger from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType +from airflow.providers.google.cloud.utils.helpers import resource_path_to_dict if TYPE_CHECKING: from airflow.utils.context import Context class DataFusionPipelineLinkHelper: - """Helper class for Pipeline links.""" + """ + Helper class for Pipeline links. + + .. warning:: + This class is deprecated. Consider using ``resource_path_to_dict()`` instead. + """ @staticmethod def get_project_id(instance): + raise AirflowProviderDeprecationWarning( + "DataFusionPipelineLinkHelper is deprecated. Consider using resource_path_to_dict() instead." + ) instance = instance["name"] project_id = next(x for x in instance.split("/") if x.startswith("airflow")) return project_id @@ -114,7 +123,7 @@ def execute(self, context: Context) -> None: instance = hook.wait_for_operation(operation) self.log.info("Instance %s restarted successfully", self.instance_name) - project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance) + project_id = resource_path_to_dict(resource_name=instance["name"])["projects"] DataFusionInstanceLink.persist( context=context, task_instance=self, @@ -272,7 +281,7 @@ def execute(self, context: Context) -> dict: instance_name=self.instance_name, location=self.location, project_id=self.project_id ) - project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance) + project_id = resource_path_to_dict(resource_name=instance["name"])["projects"] DataFusionInstanceLink.persist( context=context, task_instance=self, @@ -361,7 +370,7 @@ def execute(self, context: Context) -> None: instance = hook.wait_for_operation(operation) self.log.info("Instance %s updated successfully", self.instance_name) - project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance) + project_id = resource_path_to_dict(resource_name=instance["name"])["projects"] DataFusionInstanceLink.persist( context=context, task_instance=self, @@ -432,7 +441,7 @@ def execute(self, context: Context) -> dict: project_id=self.project_id, ) - project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance) + project_id = resource_path_to_dict(resource_name=instance["name"])["projects"] DataFusionInstanceLink.persist( context=context, task_instance=self, diff --git a/airflow/providers/google/cloud/utils/helpers.py b/airflow/providers/google/cloud/utils/helpers.py index 72216ec20b06..a0ff3e58ef4f 100644 --- a/airflow/providers/google/cloud/utils/helpers.py +++ b/airflow/providers/google/cloud/utils/helpers.py @@ -21,3 +21,24 @@ def normalize_directory_path(source_object: str | None) -> str | None: """Makes sure dir path ends with a slash.""" return source_object + "/" if source_object and not source_object.endswith("/") else source_object + + +def resource_path_to_dict(resource_name: str) -> dict[str, str]: + """Converts a path-like GCP resource name into a dictionary. + + For example, the path `projects/my-project/locations/my-location/instances/my-instance` will be converted + to a dict: + `{"projects": "my-project", + "locations": "my-location", + "instances": "my-instance",}` + """ + if not resource_name: + return {} + path_items = resource_name.split("/") + if len(path_items) % 2: + raise ValueError( + "Invalid resource_name. Expected the path-like name consisting of key/value pairs " + "'key1/value1/key2/value2/...', for example 'projects//locations/'." + ) + iterator = iter(path_items) + return dict(zip(iterator, iterator)) diff --git a/tests/providers/google/cloud/operators/test_datafusion.py b/tests/providers/google/cloud/operators/test_datafusion.py index a06b019f5e27..2783d3fc626f 100644 --- a/tests/providers/google/cloud/operators/test_datafusion.py +++ b/tests/providers/google/cloud/operators/test_datafusion.py @@ -39,6 +39,7 @@ from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType HOOK_STR = "airflow.providers.google.cloud.operators.datafusion.DataFusionHook" +RESOURCE_PATH_TO_DICT_STR = "airflow.providers.google.cloud.operators.datafusion.resource_path_to_dict" TASK_ID = "test_task" LOCATION = "test-location" @@ -54,9 +55,11 @@ class TestCloudDataFusionUpdateInstanceOperator: + @mock.patch(RESOURCE_PATH_TO_DICT_STR) @mock.patch(HOOK_STR) - def test_execute_check_hook_call_should_execute_successfully(self, mock_hook): + def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_to_dict): update_maks = "instance.name" + mock_resource_to_dict.return_value = {"projects": PROJECT_ID} op = CloudDataFusionUpdateInstanceOperator( task_id="test_tasks", instance_name=INSTANCE_NAME, @@ -78,8 +81,10 @@ def test_execute_check_hook_call_should_execute_successfully(self, mock_hook): class TestCloudDataFusionRestartInstanceOperator: + @mock.patch(RESOURCE_PATH_TO_DICT_STR) @mock.patch(HOOK_STR) - def test_execute_check_hook_call_should_execute_successfully(self, mock_hook): + def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict): + mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID} op = CloudDataFusionRestartInstanceOperator( task_id="test_tasks", instance_name=INSTANCE_NAME, @@ -95,8 +100,10 @@ def test_execute_check_hook_call_should_execute_successfully(self, mock_hook): class TestCloudDataFusionCreateInstanceOperator: + @mock.patch(RESOURCE_PATH_TO_DICT_STR) @mock.patch(HOOK_STR) - def test_execute_check_hook_call_should_execute_successfully(self, mock_hook): + def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict): + mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID} op = CloudDataFusionCreateInstanceOperator( task_id="test_tasks", instance_name=INSTANCE_NAME, @@ -133,8 +140,10 @@ def test_execute_check_hook_call_should_execute_successfully(self, mock_hook): class TestCloudDataFusionGetInstanceOperator: + @mock.patch(RESOURCE_PATH_TO_DICT_STR) @mock.patch(HOOK_STR) - def test_execute_check_hook_call_should_execute_successfully(self, mock_hook): + def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict): + mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID} op = CloudDataFusionGetInstanceOperator( task_id="test_tasks", instance_name=INSTANCE_NAME, diff --git a/tests/providers/google/cloud/utils/test_helpers.py b/tests/providers/google/cloud/utils/test_helpers.py index 9055af89da63..c277b8470bb8 100644 --- a/tests/providers/google/cloud/utils/test_helpers.py +++ b/tests/providers/google/cloud/utils/test_helpers.py @@ -16,7 +16,9 @@ # under the License. from __future__ import annotations -from airflow.providers.google.cloud.utils.helpers import normalize_directory_path +import pytest + +from airflow.providers.google.cloud.utils.helpers import normalize_directory_path, resource_path_to_dict class TestHelpers: @@ -24,3 +26,18 @@ def test_normalize_directory_path(self): assert normalize_directory_path("dir_path") == "dir_path/" assert normalize_directory_path("dir_path/") == "dir_path/" assert normalize_directory_path(None) is None + + def test_resource_path_to_dict(self): + resource_name = "key1/value1/key2/value2" + expected_dict = {"key1": "value1", "key2": "value2"} + actual_dict = resource_path_to_dict(resource_name=resource_name) + assert set(actual_dict.items()) == set(expected_dict.items()) + + def test_resource_path_to_dict_empty(self): + resource_name = "" + expected_dict = {} + assert resource_path_to_dict(resource_name=resource_name) == expected_dict + + def test_resource_path_to_dict_fail(self): + with pytest.raises(ValueError): + resource_path_to_dict(resource_name="key/value/key")