-
Notifications
You must be signed in to change notification settings - Fork 442
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SDK] fix grpc related bugs in Python SDK (#2398)
* fix: fix bugs in report_metrics. Signed-off-by: Electronic-Waste <[email protected]> * fix: fix bugs in tune. Signed-off-by: Electronic-Waste <[email protected]> * fix: fix bugs in get_trial_metrics. Signed-off-by: Electronic-Waste <[email protected]> * fix: update .gitignore and setup.py. Signed-off-by: Electronic-Waste <[email protected]> * fix: update Makefile. Signed-off-by: Electronic-Waste <[email protected]> * feat: add report_metrics_test.py. Signed-off-by: Electronic-Waste <[email protected]> * fix: fix lint error. Signed-off-by: Electronic-Waste <[email protected]> * feat: add UTs for get_trial_metrics. Signed-off-by: Electronic-Waste <[email protected]> * fix: update post_gen.py. Signed-off-by: Electronic-Waste <[email protected]> * refactor: rebase to master. Signed-off-by: Electronic-Waste <[email protected]> * test(sdk): use single katib_client. Signed-off-by: Electronic-Waste <[email protected]> * fix(sdk): add TODO for import rewrite. Signed-off-by: Electronic-Waste <[email protected]> * fix(sdk): fix lint error with black. Signed-off-by: Electronic-Waste <[email protected]> * fix(sdk): fix lint error with isort. Signed-off-by: Electronic-Waste <[email protected]> * fix(sdk): reformat import in katib_client_test.py. Signed-off-by: Electronic-Waste <[email protected]> --------- Signed-off-by: Electronic-Waste <[email protected]>
- Loading branch information
1 parent
0e2ba6e
commit a524f33
Showing
9 changed files
with
240 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ dist/ | |
|
||
# Katib gRPC APIs | ||
kubeflow/katib/katib_api_pb2.py | ||
kubeflow/katib/katib_api_pb2_grpc.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
from kubeflow.katib import report_metrics | ||
from kubeflow.katib.constants import constants | ||
|
||
TEST_RESULT_SUCCESS = "success" | ||
ENV_VARIABLE_EMPTY = True | ||
ENV_VARIABLE_NOT_EMPTY = False | ||
|
||
|
||
def report_observation_log_response(*args, **kwargs): | ||
if kwargs.get("timeout") == 0: | ||
raise TimeoutError | ||
|
||
|
||
test_report_metrics_data = [ | ||
( | ||
"valid metrics with float type", | ||
{"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT}, | ||
TEST_RESULT_SUCCESS, | ||
ENV_VARIABLE_NOT_EMPTY, | ||
), | ||
( | ||
"valid metrics with string type", | ||
{"metrics": {"result": "0.99"}, "timeout": constants.DEFAULT_TIMEOUT}, | ||
TEST_RESULT_SUCCESS, | ||
ENV_VARIABLE_NOT_EMPTY, | ||
), | ||
( | ||
"valid metrics with int type", | ||
{"metrics": {"result": 1}, "timeout": constants.DEFAULT_TIMEOUT}, | ||
TEST_RESULT_SUCCESS, | ||
ENV_VARIABLE_NOT_EMPTY, | ||
), | ||
( | ||
"ReportObservationLog timeout error", | ||
{"metrics": {"result": 0.99}, "timeout": 0}, | ||
RuntimeError, | ||
ENV_VARIABLE_NOT_EMPTY, | ||
), | ||
( | ||
"invalid metrics with type string", | ||
{"metrics": {"result": "abc"}, "timeout": constants.DEFAULT_TIMEOUT}, | ||
ValueError, | ||
ENV_VARIABLE_NOT_EMPTY, | ||
), | ||
( | ||
"Trial name is not passed to env variables", | ||
{"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT}, | ||
ValueError, | ||
ENV_VARIABLE_EMPTY, | ||
), | ||
] | ||
|
||
|
||
@pytest.fixture | ||
def mock_getenv(request): | ||
with patch("os.getenv") as mock: | ||
if request.param is ENV_VARIABLE_EMPTY: | ||
mock.side_effect = ValueError | ||
else: | ||
mock.return_value = "example" | ||
yield mock | ||
|
||
|
||
@pytest.fixture | ||
def mock_get_current_k8s_namespace(): | ||
with patch("kubeflow.katib.utils.utils.get_current_k8s_namespace") as mock: | ||
mock.return_value = "test" | ||
yield mock | ||
|
||
|
||
@pytest.fixture | ||
def mock_report_observation_log(): | ||
with patch("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub") as mock: | ||
mock_instance = mock.return_value | ||
mock_instance.ReportObservationLog.side_effect = report_observation_log_response | ||
yield mock_instance | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"test_name,kwargs,expected_output,mock_getenv", | ||
test_report_metrics_data, | ||
indirect=["mock_getenv"], | ||
) | ||
def test_report_metrics( | ||
test_name, | ||
kwargs, | ||
expected_output, | ||
mock_getenv, | ||
mock_get_current_k8s_namespace, | ||
mock_report_observation_log, | ||
): | ||
""" | ||
test report_metrics function | ||
""" | ||
print("\n\nExecuting test:", test_name) | ||
try: | ||
report_metrics(**kwargs) | ||
assert expected_output == TEST_RESULT_SUCCESS | ||
except Exception as e: | ||
assert type(e) is expected_output | ||
print("test execution complete") |
Oops, something went wrong.