Skip to content

Commit

Permalink
Refactor DataFusionInstanceLink usage (apache#34514)
Browse files Browse the repository at this point in the history
  • Loading branch information
moiseenkov authored Oct 13, 2023
1 parent 0e5890b commit d27d0bb
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 11 deletions.
21 changes: 15 additions & 6 deletions airflow/providers/google/cloud/operators/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from googleapiclient.errors import HttpError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, DataFusionHook, PipelineStates
from airflow.providers.google.cloud.links.datafusion import (
DataFusionInstanceLink,
Expand All @@ -34,16 +34,25 @@
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.datafusion import DataFusionStartPipelineTrigger
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType
from airflow.providers.google.cloud.utils.helpers import resource_path_to_dict

if TYPE_CHECKING:
from airflow.utils.context import Context


class DataFusionPipelineLinkHelper:
"""Helper class for Pipeline links."""
"""
Helper class for Pipeline links.
.. warning::
This class is deprecated. Consider using ``resource_path_to_dict()`` instead.
"""

@staticmethod
def get_project_id(instance):
raise AirflowProviderDeprecationWarning(
"DataFusionPipelineLinkHelper is deprecated. Consider using resource_path_to_dict() instead."
)
instance = instance["name"]
project_id = next(x for x in instance.split("/") if x.startswith("airflow"))
return project_id
Expand Down Expand Up @@ -114,7 +123,7 @@ def execute(self, context: Context) -> None:
instance = hook.wait_for_operation(operation)
self.log.info("Instance %s restarted successfully", self.instance_name)

project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance)
project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
DataFusionInstanceLink.persist(
context=context,
task_instance=self,
Expand Down Expand Up @@ -272,7 +281,7 @@ def execute(self, context: Context) -> dict:
instance_name=self.instance_name, location=self.location, project_id=self.project_id
)

project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance)
project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
DataFusionInstanceLink.persist(
context=context,
task_instance=self,
Expand Down Expand Up @@ -361,7 +370,7 @@ def execute(self, context: Context) -> None:
instance = hook.wait_for_operation(operation)
self.log.info("Instance %s updated successfully", self.instance_name)

project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance)
project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
DataFusionInstanceLink.persist(
context=context,
task_instance=self,
Expand Down Expand Up @@ -432,7 +441,7 @@ def execute(self, context: Context) -> dict:
project_id=self.project_id,
)

project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance)
project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
DataFusionInstanceLink.persist(
context=context,
task_instance=self,
Expand Down
21 changes: 21 additions & 0 deletions airflow/providers/google/cloud/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,24 @@
def normalize_directory_path(source_object: str | None) -> str | None:
"""Makes sure dir path ends with a slash."""
return source_object + "/" if source_object and not source_object.endswith("/") else source_object


def resource_path_to_dict(resource_name: str) -> dict[str, str]:
"""Converts a path-like GCP resource name into a dictionary.
For example, the path `projects/my-project/locations/my-location/instances/my-instance` will be converted
to a dict:
`{"projects": "my-project",
"locations": "my-location",
"instances": "my-instance",}`
"""
if not resource_name:
return {}
path_items = resource_name.split("/")
if len(path_items) % 2:
raise ValueError(
"Invalid resource_name. Expected the path-like name consisting of key/value pairs "
"'key1/value1/key2/value2/...', for example 'projects/<project>/locations/<location>'."
)
iterator = iter(path_items)
return dict(zip(iterator, iterator))
17 changes: 13 additions & 4 deletions tests/providers/google/cloud/operators/test_datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

HOOK_STR = "airflow.providers.google.cloud.operators.datafusion.DataFusionHook"
RESOURCE_PATH_TO_DICT_STR = "airflow.providers.google.cloud.operators.datafusion.resource_path_to_dict"

TASK_ID = "test_task"
LOCATION = "test-location"
Expand All @@ -54,9 +55,11 @@


class TestCloudDataFusionUpdateInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_to_dict):
update_maks = "instance.name"
mock_resource_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionUpdateInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
Expand All @@ -78,8 +81,10 @@ def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):


class TestCloudDataFusionRestartInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict):
mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionRestartInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
Expand All @@ -95,8 +100,10 @@ def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):


class TestCloudDataFusionCreateInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict):
mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionCreateInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
Expand Down Expand Up @@ -133,8 +140,10 @@ def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):


class TestCloudDataFusionGetInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict):
mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionGetInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
Expand Down
19 changes: 18 additions & 1 deletion tests/providers/google/cloud/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,28 @@
# under the License.
from __future__ import annotations

from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
import pytest

from airflow.providers.google.cloud.utils.helpers import normalize_directory_path, resource_path_to_dict


class TestHelpers:
def test_normalize_directory_path(self):
assert normalize_directory_path("dir_path") == "dir_path/"
assert normalize_directory_path("dir_path/") == "dir_path/"
assert normalize_directory_path(None) is None

def test_resource_path_to_dict(self):
resource_name = "key1/value1/key2/value2"
expected_dict = {"key1": "value1", "key2": "value2"}
actual_dict = resource_path_to_dict(resource_name=resource_name)
assert set(actual_dict.items()) == set(expected_dict.items())

def test_resource_path_to_dict_empty(self):
resource_name = ""
expected_dict = {}
assert resource_path_to_dict(resource_name=resource_name) == expected_dict

def test_resource_path_to_dict_fail(self):
with pytest.raises(ValueError):
resource_path_to_dict(resource_name="key/value/key")

0 comments on commit d27d0bb

Please sign in to comment.