diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 6e37bd6ce0938..82c04bcd535b9 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest @@ -117,13 +117,10 @@ def parse_python_source(task: Task, custom_operator_name: str | None = None) -> operator = task().operator if not hasattr(operator, "get_python_source"): pytest.skip(f"Operator {operator} does not have get_python_source method") + if custom_operator_name: - update_custom_operator_name(operator, custom_operator_name) + custom_operator_name = ( + custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" + ) + operator.__dict__["custom_operator_name"] = custom_operator_name return operator.get_python_source() - - -def update_custom_operator_name(operator: Any, custom_operator_name: str): - custom_operator_name = ( - custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" - ) - operator.__dict__["custom_operator_name"] = custom_operator_name