diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 6272c048c9ef3..f7644e379d003 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
@@ -25,6 +26,7 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
+from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -241,6 +243,7 @@ class EmrContainerSensor(BaseSensorOperator):
     :param aws_conn_id: aws connection to use, defaults to 'aws_default'
     :param poll_interval: Time in seconds to wait between two consecutive calls to
         check the job run status on EMR Containers, defaults to 10
+    :param deferrable: Run the sensor in deferrable mode.
     """
 
     INTERMEDIATE_STATES = (
@@ -267,6 +270,7 @@ def __init__(
         max_retries: int | None = None,
         aws_conn_id: str = "aws_default",
         poll_interval: int = 10,
+        deferrable: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -275,6 +279,11 @@ def __init__(
         self.job_id = job_id
         self.poll_interval = poll_interval
         self.max_retries = max_retries
+        self.deferrable = deferrable
+
+    @cached_property
+    def hook(self) -> EmrContainerHook:
+        return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
 
     def poke(self, context: Context) -> bool:
         state = self.hook.poll_query_status(
@@ -290,10 +299,31 @@ def poke(self, context: Context) -> bool:
             return False
         return True
 
-    @cached_property
-    def hook(self) -> EmrContainerHook:
-        """Create and return an EmrContainerHook."""
-        return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
+    def execute(self, context: Context):
+        if not self.deferrable:
+            super().execute(context=context)
+        else:
+            timeout = (
+                timedelta(seconds=self.max_retries * self.poll_interval + 60)
+                if self.max_retries
+                else self.execution_timeout
+            )
+            self.defer(
+                timeout=timeout,
+                trigger=EmrContainerSensorTrigger(
+                    virtual_cluster_id=self.virtual_cluster_id,
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        else:
+            self.log.info("Job completed successfully.")
 
 
 class EmrNotebookExecutionSensor(EmrBaseSensor):
diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py
index 1c3c8bb8338b2..6ea2c3b25cd61 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -17,12 +17,13 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Any
+from functools import cached_property
+from typing import Any, AsyncIterator
 
 from botocore.exceptions import WaiterError
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.emr import EmrHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.helpers import prune_dict
 
@@ -246,3 +247,73 @@ async def run(self):
                     "message": "JobFlow terminated successfully",
                 }
             )
+
+
+class EmrContainerSensorTrigger(BaseTrigger):
+    """
+    Poll the status of an EMR Containers job run until it reaches a terminal state.
+
+    :param virtual_cluster_id: The EMR virtual cluster ID the job run belongs to
+    :param job_id: The ID of the job run to check the state of
+    :param aws_conn_id: Reference to the AWS connection ID
+    :param poll_interval: Polling period in seconds to check for the status
+    """
+
+    def __init__(
+        self,
+        virtual_cluster_id: str,
+        job_id: str,
+        aws_conn_id: str = "aws_default",
+        poll_interval: int = 30,
+        **kwargs: Any,
+    ):
+        self.virtual_cluster_id = virtual_cluster_id
+        self.job_id = job_id
+        self.aws_conn_id = aws_conn_id
+        self.poll_interval = poll_interval
+        super().__init__(**kwargs)
+
+    @cached_property
+    def hook(self) -> EmrContainerHook:
+        return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes EmrContainerSensorTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
+            {
+                "virtual_cluster_id": self.virtual_cluster_id,
+                "job_id": self.job_id,
+                "aws_conn_id": self.aws_conn_id,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        async with self.hook.async_conn as client:
+            waiter = self.hook.get_waiter("container_job_complete", deferrable=True, client=client)
+            attempt = 0
+            while True:
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        id=self.job_id,
+                        virtualClusterId=self.virtual_cluster_id,
+                        WaiterConfig={
+                            "Delay": self.poll_interval,
+                            "MaxAttempts": 1,
+                        },
+                    )
+                    break
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        yield TriggerEvent({"status": "failure", "message": f"Job Failed: {error}"})
+                        return
+                    self.log.info(
+                        "Job status is %s. Retrying attempt %s",
Retrying attempt %s", + error.last_response["jobRun"]["state"], + attempt, + ) + await asyncio.sleep(int(self.poll_interval)) + + yield TriggerEvent({"status": "success", "job_id": self.job_id}) diff --git a/airflow/providers/amazon/aws/waiters/emr-containers.json b/airflow/providers/amazon/aws/waiters/emr-containers.json new file mode 100644 index 0000000000000..a4174b0536e50 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/emr-containers.json @@ -0,0 +1,30 @@ +{ + "version": 2, + "waiters": { + "container_job_complete": { + "operation": "DescribeJobRun", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "COMPLETED", + "state": "success" + }, + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "FAILED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "CANCELLED", + "state": "failure" + } + ] + } + } +} diff --git a/tests/providers/amazon/aws/sensors/test_emr_containers.py b/tests/providers/amazon/aws/sensors/test_emr_containers.py index 38d7688f6617d..0df3657288b9b 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_containers.py +++ b/tests/providers/amazon/aws/sensors/test_emr_containers.py @@ -21,9 +21,10 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor +from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger class TestEmrContainerSensor: @@ -73,3 +74,13 @@ def test_poke_cancel_pending(self, mock_check_query_status): with pytest.raises(AirflowException) as ctx: self.sensor.poke(None) assert "EMR Containers sensor failed" in str(ctx.value) + + @mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor.poke") + def test_sensor_defer(self, mock_poke): + self.sensor.deferrable = True + mock_poke.return_value = False + with pytest.raises(TaskDeferred) as exc: + self.sensor.execute(context=None) + assert isinstance( + exc.value.trigger, EmrContainerSensorTrigger + ), "Trigger is not a EmrContainerSensorTrigger" diff --git a/tests/providers/amazon/aws/triggers/test_emr.py b/tests/providers/amazon/aws/triggers/test_emr.py index 5d599801a280a..86e54cb94ae23 100644 --- a/tests/providers/amazon/aws/triggers/test_emr.py +++ b/tests/providers/amazon/aws/triggers/test_emr.py @@ -24,13 +24,21 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.emr import EmrHook -from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger, EmrTerminateJobFlowTrigger +from airflow.providers.amazon.aws.triggers.emr import ( + EmrContainerSensorTrigger, + EmrCreateJobFlowTrigger, + EmrTerminateJobFlowTrigger, +) from airflow.triggers.base import TriggerEvent TEST_JOB_FLOW_ID = "test-job-flow-id" TEST_POLL_INTERVAL = 10 TEST_MAX_ATTEMPTS = 10 TEST_AWS_CONN_ID = "test-aws-id" +VIRTUAL_CLUSTER_ID = "vzwemreks" +JOB_ID = "job-1234" +AWS_CONN_ID = "aws_emr_conn" +POLL_INTERVAL = 60 class TestEmrCreateJobFlowTrigger: @@ -350,3 +358,105 @@ async def test_emr_terminate_job_flow_trigger_run_attempts_failed( assert str(exc.value) == f"JobFlow termination failed: {error_failed}" assert mock_get_waiter().wait.call_count == 3 + + +class TestEmrContainerSensorTrigger: + def test_emr_container_sensor_trigger_serialize(self): + emr_trigger = EmrContainerSensorTrigger( + 
+            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+            job_id=JOB_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poll_interval=POLL_INTERVAL,
+        )
+        class_path, args = emr_trigger.serialize()
+        assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger"
+        assert args["virtual_cluster_id"] == VIRTUAL_CLUSTER_ID
+        assert args["job_id"] == JOB_ID
+        assert args["aws_conn_id"] == AWS_CONN_ID
+        assert args["poll_interval"] == POLL_INTERVAL
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
+    async def test_emr_container_trigger_run(self, mock_async_conn, mock_get_waiter):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        mock_get_waiter().wait = AsyncMock()
+
+        emr_trigger = EmrContainerSensorTrigger(
+            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+            job_id=JOB_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poll_interval=POLL_INTERVAL,
+        )
+
+        generator = emr_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "job_id": JOB_ID})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
+    async def test_emr_trigger_run_multiple_attempts(self, mock_async_conn, mock_get_waiter, mock_sleep):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"jobRun": {"state": "RUNNING"}},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
+        mock_sleep.return_value = True
+
+        emr_trigger = EmrContainerSensorTrigger(
+            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+            job_id=JOB_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poll_interval=POLL_INTERVAL,
+        )
+
+        generator = emr_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 3
+        assert response == TriggerEvent({"status": "success", "job_id": JOB_ID})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
+    async def test_emr_trigger_run_attempts_failed(self, mock_async_conn, mock_get_waiter, mock_sleep):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        error_available = WaiterError(
+            name="test_name",
+            reason="Max attempts exceeded",
+            last_response={"jobRun": {"state": "FAILED"}},
+        )
+        error_failed = WaiterError(
+            name="test_name",
+            reason="Waiter encountered a terminal failure state",
+            last_response={"jobRun": {"state": "FAILED"}},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(
+            side_effect=[error_available, error_available, error_failed]
+        )
+        mock_sleep.return_value = True
+
+        emr_trigger = EmrContainerSensorTrigger(
+            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+            job_id=JOB_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poll_interval=POLL_INTERVAL,
+        )
+
+        generator = emr_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 3
+        assert response == TriggerEvent({"status": "failure", "message": f"Job Failed: {error_failed}"})
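
Usage sketch (illustrative, not part of the patch): a minimal DAG showing how the new deferrable mode can be wired up. The DAG id, virtual cluster id, and job id below are placeholder values.

# Minimal sketch of the deferrable sensor added by this patch; the DAG id,
# virtual cluster id, and job id are placeholders, not values from the patch.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor

with DAG(dag_id="example_emr_containers", start_date=datetime(2023, 1, 1), schedule=None):
    # With deferrable=True, execute() hands polling off to EmrContainerSensorTrigger
    # on the triggerer and frees the worker slot; when max_retries is set, the defer
    # timeout is bounded by max_retries * poll_interval + 60 seconds.
    watch_job = EmrContainerSensor(
        task_id="watch_emr_containers_job",
        virtual_cluster_id="vzwemreks",  # placeholder EMR virtual cluster id
        job_id="job-1234",  # placeholder job run id
        poll_interval=60,
        max_retries=30,
        deferrable=True,
    )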