Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK] fix grpc related bugs in Python SDK #2398

Merged
merged 16 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,17 @@ ifeq ("$(wildcard $(TEST_TENSORFLOW_EVENT_FILE_PATH))", "")
python examples/v1beta1/trial-images/tf-mnist-with-summaries/mnist.py --epochs 5 --batch-size 200 --log-path $(TEST_TENSORFLOW_EVENT_FILE_PATH)
endif

# TODO(Electronic-Waste): Remove the import rewrite when protobuf supports `python_package` option.
# REF: https://github.com/protocolbuffers/protobuf/issues/7061
pytest: prepare-pytest prepare-pytest-testdata
pytest ./test/unit/v1beta1/suggestion --ignore=./test/unit/v1beta1/suggestion/test_skopt_service.py
pytest ./test/unit/v1beta1/earlystopping
pytest ./test/unit/v1beta1/metricscollector
cp ./pkg/apis/manager/v1beta1/python/api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
cp ./pkg/apis/manager/v1beta1/python/api_pb2_grpc.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py
sed -i "s/api_pb2/kubeflow\.katib\.katib_api_pb2/g" ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py
pytest ./sdk/python/v1beta1/kubeflow/katib
rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py

# The skopt service doesn't work appropriately with Python 3.11.
# So, we need to run the test with Python 3.9.
Expand Down
4 changes: 2 additions & 2 deletions hack/gen-python-sdk/post_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def _rewrite_helper(input_file, output_file, rewrite_rules):
if output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py":
lines.append("# Import Katib API client.\n")
lines.append("from kubeflow.katib.api.katib_client import KatibClient\n")
lines.append("# Import Katib report metrics functions")
lines.append("from kubeflow.katib.api.report_metrics import report_metrics")
lines.append("# Import Katib report metrics functions\n")
lines.append("from kubeflow.katib.api.report_metrics import report_metrics\n")
lines.append("# Import Katib helper functions.\n")
lines.append("import kubeflow.katib.api.search as search\n")
lines.append("# Import Katib helper constants.\n")
Expand Down
1 change: 1 addition & 0 deletions sdk/python/v1beta1/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ dist/

# Katib gRPC APIs
kubeflow/katib/katib_api_pb2.py
kubeflow/katib/katib_api_pb2_grpc.py
4 changes: 3 additions & 1 deletion sdk/python/v1beta1/kubeflow/katib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@

# Import Katib API client.
from kubeflow.katib.api.katib_client import KatibClient
# Import Katib report metrics functionsfrom kubeflow.katib.api.report_metrics import report_metrics# Import Katib helper functions.
# Import Katib report metrics functions
from kubeflow.katib.api.report_metrics import report_metrics
# Import Katib helper functions.
import kubeflow.katib.api.search as search
# Import Katib helper constants.
from kubeflow.katib.constants.constants import BASE_IMAGE_TENSORFLOW
Expand Down
30 changes: 14 additions & 16 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import grpc
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc
from kubeflow.katib import models
from kubeflow.katib.api_client import ApiClient
from kubeflow.katib.constants import constants
Expand Down Expand Up @@ -1305,21 +1306,18 @@ def get_trial_metrics(

namespace = namespace or self.namespace

db_manager_address = db_manager_address.split(":")
channel = grpc.beta.implementations.insecure_channel(
db_manager_address[0], int(db_manager_address[1])
)
channel = grpc.insecure_channel(db_manager_address)

with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
try:
# When metric name is empty, we select all logs from the Katib DB.
observation_logs = client.GetObservationLog(
katib_api_pb2.GetObservationLogRequest(trial_name=name),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}"
)
client = katib_api_pb2_grpc.DBManagerStub(channel)
try:
# When metric name is empty, we select all logs from the Katib DB.
observation_logs = client.GetObservationLog(
katib_api_pb2.GetObservationLogRequest(trial_name=name),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}"
)

return observation_logs.observation_log.metric_logs
return observation_logs.observation_log.metric_logs
71 changes: 70 additions & 1 deletion sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional
from unittest.mock import Mock, patch

import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import pytest
from kubeflow.katib import (
KatibClient,
Expand Down Expand Up @@ -38,6 +39,24 @@ def create_namespaced_custom_object_response(*args, **kwargs):
return {"metadata": {"name": "12345-experiment-mnist-ci-test"}}


def get_observation_log_response(*args, **kwargs):
    """Stand-in for ``DBManagerStub.GetObservationLog`` used as a mock side effect.

    Raises:
        TimeoutError: when the ``timeout`` keyword argument equals 0.
        RuntimeError: when the request (first positional argument) has
            ``trial_name == "invalid"``.

    Returns:
        A ``GetObservationLogReply`` carrying a single fixed metric log.
    """
    # The timeout is checked first, before the request is looked at.
    if kwargs.get("timeout") == 0:
        raise TimeoutError

    request = args[0]
    if request.trial_name == "invalid":
        raise RuntimeError

    sample_metric = katib_api_pb2.Metric(name="result", value="0.99")
    sample_log = katib_api_pb2.MetricLog(
        time_stamp="2024-07-29T15:09:08Z",
        metric=sample_metric,
    )
    return katib_api_pb2.GetObservationLogReply(
        observation_log=katib_api_pb2.ObservationLog(metric_logs=[sample_log])
    )


def generate_trial_template() -> V1beta1TrialTemplate:
trial_spec = {
"apiVersion": "batch/v1",
Expand Down Expand Up @@ -223,6 +242,34 @@ def create_experiment(
]


# Parametrized cases for test_get_trial_metrics. Each entry is a tuple of
# (test name, kwargs passed to KatibClient.get_trial_metrics, expected output).
# The expected output is either the list of MetricLog entries the call should
# return or the exception type the call should raise.
test_get_trial_metrics_data = [
(
"valid trial name",
{"name": "example", "namespace": "valid", "timeout": constants.DEFAULT_TIMEOUT},
[
katib_api_pb2.MetricLog(
time_stamp="2024-07-29T15:09:08Z",
metric=katib_api_pb2.Metric(name="result", value="0.99"),
)
],
),
(
"invalid trial name",
{
"name": "invalid",
"namespace": "invalid",
"timeout": constants.DEFAULT_TIMEOUT,
},
RuntimeError,
),
(
"GetObservationLog timeout error",
{"name": "example", "namespace": "valid", "timeout": 0},
RuntimeError,
),
]


@pytest.fixture
def katib_client():
with patch(
Expand All @@ -232,7 +279,12 @@ def katib_client():
side_effect=create_namespaced_custom_object_response
)
),
), patch("kubernetes.config.load_kube_config", return_value=Mock()):
), patch("kubernetes.config.load_kube_config", return_value=Mock()), patch(
"kubeflow.katib.katib_api_pb2_grpc.DBManagerStub",
return_value=Mock(
GetObservationLog=Mock(side_effect=get_observation_log_response)
),
):
client = KatibClient()
yield client

Expand All @@ -251,3 +303,20 @@ def test_create_experiment(katib_client, test_name, kwargs, expected_output):
except Exception as e:
assert type(e) is expected_output
print("test execution complete")


@pytest.mark.parametrize(
    "test_name,kwargs,expected_output", test_get_trial_metrics_data
)
def test_get_trial_metrics(katib_client, test_name, kwargs, expected_output):
    """Check ``KatibClient.get_trial_metrics`` against each parametrized case.

    Args:
        katib_client: fixture providing a ``KatibClient`` with patched backends.
        test_name: human-readable case name, printed for log readability.
        kwargs: keyword arguments forwarded to ``get_trial_metrics``.
        expected_output: either the list of metric logs the call must return,
            or the exception type the call must raise.
    """
    print("\n\nExecuting test:", test_name)
    expects_error = isinstance(expected_output, type) and issubclass(
        expected_output, Exception
    )
    if expects_error:
        # pytest.raises fails the test if nothing is raised, so a call that
        # unexpectedly succeeds (even with an empty result) cannot pass.
        with pytest.raises(expected_output) as exc_info:
            katib_client.get_trial_metrics(**kwargs)
        # Keep the exact-type check; pytest.raises alone accepts subclasses.
        assert type(exc_info.value) is expected_output
    else:
        metrics = katib_client.get_trial_metrics(**kwargs)
        # Compare whole lists so a missing or extra entry is detected, not
        # only the elements that happen to exist in the returned sequence.
        assert list(metrics) == expected_output
    print("test execution complete")
54 changes: 25 additions & 29 deletions sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc
from kubeflow.katib.constants import constants
from kubeflow.katib.utils import utils

Expand All @@ -38,9 +39,9 @@ def report_metrics(
timeout: Optional, gRPC API Server timeout in seconds to report metrics.

Raises:
ValueError: The Trial name is not passed to environment variables.
RuntimeError: Unable to push Trial metrics to Katib DB or
ValueError: The Trial name is not passed to environment variables or
metrics value has incorrect format (cannot be converted to type `float`).
RuntimeError: Unable to push Trial metrics to Katib DB.
"""

# Get Trial's namespace and name
Expand All @@ -50,37 +51,32 @@ def report_metrics(
raise ValueError("The Trial name is not passed to environment variables")

# Get channel for grpc call to db manager
db_manager_address = db_manager_address.split(":")
channel = grpc.beta.implementations.insecure_channel(
db_manager_address[0], int(db_manager_address[1])
)
channel = grpc.insecure_channel(db_manager_address)

# Validate metrics value in dict
for value in metrics.values():
utils.validate_metrics_value(value)

# Dial katib db manager to report metrics
with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
try:
timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT)
client.ReportObservationLog(
request=katib_api_pb2.ReportObservationLogRequest(
trial_name=name,
observation_logs=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp=timestamp,
metric=katib_api_pb2.Metric(
name=name, value=str(value)
),
)
for name, value in metrics.items()
]
),
client = katib_api_pb2_grpc.DBManagerStub(channel)
try:
timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT)
client.ReportObservationLog(
request=katib_api_pb2.ReportObservationLogRequest(
trial_name=name,
observation_log=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp=timestamp,
metric=katib_api_pb2.Metric(name=name, value=str(value)),
)
for name, value in metrics.items()
]
),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}"
)
),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}"
)
104 changes: 104 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from unittest.mock import patch

import pytest
from kubeflow.katib import report_metrics
from kubeflow.katib.constants import constants

# Expected-output marker for cases in which report_metrics must not raise.
TEST_RESULT_SUCCESS = "success"
# Parameters for the mock_getenv fixture: the "empty" value makes the patched
# os.getenv raise, the "not empty" value makes it return a fixed name.
ENV_VARIABLE_EMPTY = True
ENV_VARIABLE_NOT_EMPTY = False


def report_observation_log_response(*args, **kwargs):
    """Stand-in for ``DBManagerStub.ReportObservationLog`` used as a mock side effect.

    Raises:
        TimeoutError: when the ``timeout`` keyword argument equals 0.

    Returns:
        None in every other case.
    """
    requested_timeout = kwargs.get("timeout")
    if requested_timeout == 0:
        raise TimeoutError
    return None


# Parametrized cases for test_report_metrics. Each entry is a tuple of
# (test name, kwargs passed to report_metrics, expected output, parameter for
# the mock_getenv fixture). The expected output is either TEST_RESULT_SUCCESS
# or the exception type the call should raise.
test_report_metrics_data = [
(
"valid metrics with float type",
{"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT},
TEST_RESULT_SUCCESS,
ENV_VARIABLE_NOT_EMPTY,
),
(
"valid metrics with string type",
{"metrics": {"result": "0.99"}, "timeout": constants.DEFAULT_TIMEOUT},
TEST_RESULT_SUCCESS,
ENV_VARIABLE_NOT_EMPTY,
),
(
"valid metrics with int type",
{"metrics": {"result": 1}, "timeout": constants.DEFAULT_TIMEOUT},
TEST_RESULT_SUCCESS,
ENV_VARIABLE_NOT_EMPTY,
),
(
"ReportObservationLog timeout error",
{"metrics": {"result": 0.99}, "timeout": 0},
RuntimeError,
ENV_VARIABLE_NOT_EMPTY,
),
(
"invalid metrics with type string",
{"metrics": {"result": "abc"}, "timeout": constants.DEFAULT_TIMEOUT},
ValueError,
ENV_VARIABLE_NOT_EMPTY,
),
(
"Trial name is not passed to env variables",
{"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT},
ValueError,
ENV_VARIABLE_EMPTY,
),
]


@pytest.fixture
def mock_getenv(request):
    """Patch ``os.getenv`` for the duration of a test and yield the mock.

    The behaviour is selected by the indirect parameter ``request.param``:
    ``ENV_VARIABLE_EMPTY`` makes every ``os.getenv`` call raise ``ValueError``;
    any other value makes it return the string ``"example"``.
    """
    # NOTE(review): in the "empty" case the ValueError comes from the mock
    # itself rather than from os.getenv returning nothing - confirm this is
    # the intended way to simulate a missing environment variable.
    with patch("os.getenv") as getenv_mock:
        env_is_empty = request.param is ENV_VARIABLE_EMPTY
        if not env_is_empty:
            getenv_mock.return_value = "example"
        else:
            getenv_mock.side_effect = ValueError
        yield getenv_mock


@pytest.fixture
def mock_get_current_k8s_namespace():
    """Patch ``utils.get_current_k8s_namespace`` to return ``"test"``.

    Yields the mock that replaces the function while the test runs.
    """
    target = "kubeflow.katib.utils.utils.get_current_k8s_namespace"
    with patch(target, return_value="test") as namespace_mock:
        yield namespace_mock


@pytest.fixture
def mock_report_observation_log():
    """Patch ``DBManagerStub`` so no real gRPC call is made.

    The stub instance's ``ReportObservationLog`` is routed to
    ``report_observation_log_response``. Yields that stub instance (the
    patched class's ``return_value``), not the patched class itself.
    """
    with patch("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub") as stub_cls:
        stub = stub_cls.return_value
        stub.ReportObservationLog.side_effect = report_observation_log_response
        yield stub


@pytest.mark.parametrize(
    "test_name,kwargs,expected_output,mock_getenv",
    test_report_metrics_data,
    indirect=["mock_getenv"],
)
def test_report_metrics(
    test_name,
    kwargs,
    expected_output,
    mock_getenv,
    mock_get_current_k8s_namespace,
    mock_report_observation_log,
):
    """Check ``report_metrics`` against each parametrized case.

    Args:
        test_name: human-readable case name, printed for log readability.
        kwargs: keyword arguments forwarded to ``report_metrics``.
        expected_output: ``TEST_RESULT_SUCCESS`` when the call must complete
            without raising, otherwise the exception type it must raise.
        mock_getenv: fixture patching ``os.getenv`` (parametrized indirectly).
        mock_get_current_k8s_namespace: fixture patching the namespace lookup.
        mock_report_observation_log: fixture patching the DB manager stub.
    """
    print("\n\nExecuting test:", test_name)
    if expected_output == TEST_RESULT_SUCCESS:
        # Any exception propagates with its own traceback instead of being
        # turned into an opaque type-mismatch assertion.
        report_metrics(**kwargs)
    else:
        # pytest.raises fails the test if nothing is raised, without relying
        # on an AssertionError being caught by a broad except clause.
        with pytest.raises(expected_output) as exc_info:
            report_metrics(**kwargs)
        # Keep the exact-type check; pytest.raises alone accepts subclasses.
        assert type(exc_info.value) is expected_output
    print("test execution complete")
Loading
Loading