From b13c664778247028fb1ab1a7aa745a471b8740be Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 4 Sep 2024 20:44:53 +0900 Subject: [PATCH] fix: rm update_custom_operator_name --- tests/utils/test_decorators.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 6e37bd6ce093..82c04bcd535b 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest @@ -117,13 +117,10 @@ def parse_python_source(task: Task, custom_operator_name: str | None = None) -> operator = task().operator if not hasattr(operator, "get_python_source"): pytest.skip(f"Operator {operator} does not have get_python_source method") + if custom_operator_name: - update_custom_operator_name(operator, custom_operator_name) + custom_operator_name = ( + custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" + ) + operator.__dict__["custom_operator_name"] = custom_operator_name return operator.get_python_source() - - -def update_custom_operator_name(operator: Any, custom_operator_name: str): - custom_operator_name = ( - custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" - ) - operator.__dict__["custom_operator_name"] = custom_operator_name