diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py index 85a884a53e595..1b6466f7678b7 100644 --- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py @@ -149,10 +149,12 @@ def _repair_task( databricks_run_id, ) + run_data = hook.get_run(databricks_run_id) repair_json = { "run_id": databricks_run_id, "latest_repair_id": repair_history_id, "rerun_tasks": tasks_to_repair, + **run_data.get("overriding_parameters", {}), } return hook.repair_run(repair_json) diff --git a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py index c41bd9690b1f5..628fa43f613b9 100644 --- a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py @@ -96,6 +96,36 @@ def test_repair_task(mock_databricks_hook): mock_hook_instance.repair_run.assert_called_once() +@patch("airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook") +def test_repair_task_with_params(mock_databricks_hook): + mock_hook_instance = mock_databricks_hook.return_value + mock_hook_instance.get_latest_repair_id.return_value = 100 + mock_hook_instance.repair_run.return_value = 200 + mock_hook_instance.get_run.return_value = { + "overriding_parameters": { + "key1": "value1", + "key2": "value2", + } + } + + tasks_to_repair = ["task1", "task2"] + result = _repair_task(DATABRICKS_CONN_ID, DATABRICKS_RUN_ID, tasks_to_repair, LOG) + + expected_payload = { + "run_id": DATABRICKS_RUN_ID, + "rerun_tasks": tasks_to_repair, + "overriding_parameters": { + "key1": "value1", + "key2": "value2", + } + } + assert result == 200 + mock_hook_instance.get_latest_repair_id.assert_called_once_with(DATABRICKS_RUN_ID) + mock_hook_instance.get_run.assert_called_once_with(DATABRICKS_RUN_ID) + mock_hook_instance.repair_run.assert_called_once_with(expected_payload) + + + def test_get_launch_task_id_no_launch_task(): task_group = MagicMock(get_child_by_label=MagicMock(side_effect=KeyError)) task_group.parent_group = None