Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK] fix grpc related bugs in Python SDK #2398

Merged
merged 16 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ pytest: prepare-pytest prepare-pytest-testdata
pytest ./test/unit/v1beta1/earlystopping
pytest ./test/unit/v1beta1/metricscollector
cp ./pkg/apis/manager/v1beta1/python/api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
cp ./pkg/apis/manager/v1beta1/python/api_pb2_grpc.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py
sed -i "s/api_pb2/kubeflow\.katib\.katib_api_pb2/g" ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py
pytest ./sdk/python/v1beta1/kubeflow/katib
rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py
rm ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2.py ./sdk/python/v1beta1/kubeflow/katib/katib_api_pb2_grpc.py

# The skopt service doesn't work appropriately with Python 3.11.
# So, we need to run the test with Python 3.9.
Expand Down
4 changes: 2 additions & 2 deletions hack/gen-python-sdk/post_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def _rewrite_helper(input_file, output_file, rewrite_rules):
if (output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py"):
lines.append("# Import Katib API client.\n")
lines.append("from kubeflow.katib.api.katib_client import KatibClient\n")
lines.append("# Import Katib report metrics functions")
lines.append("from kubeflow.katib.api.report_metrics import report_metrics")
lines.append("# Import Katib report metrics functions\n")
lines.append("from kubeflow.katib.api.report_metrics import report_metrics\n")
lines.append("# Import Katib helper functions.\n")
lines.append("import kubeflow.katib.api.search as search\n")
lines.append("# Import Katib helper constants.\n")
Expand Down
1 change: 1 addition & 0 deletions sdk/python/v1beta1/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ dist/

# Katib gRPC APIs
kubeflow/katib/katib_api_pb2.py
kubeflow/katib/katib_api_pb2_grpc.py
4 changes: 3 additions & 1 deletion sdk/python/v1beta1/kubeflow/katib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@

# Import Katib API client.
from kubeflow.katib.api.katib_client import KatibClient
# Import Katib report metrics functionsfrom kubeflow.katib.api.report_metrics import report_metrics# Import Katib helper functions.
# Import Katib report metrics functions
from kubeflow.katib.api.report_metrics import report_metrics
# Import Katib helper functions.
import kubeflow.katib.api.search as search
# Import Katib helper constants.
from kubeflow.katib.constants.constants import BASE_IMAGE_TENSORFLOW
Expand Down
32 changes: 15 additions & 17 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from kubeflow.katib.api_client import ApiClient
from kubeflow.katib.constants import constants
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc
from kubeflow.katib.utils import utils
from kubernetes import client
from kubernetes import config
Expand Down Expand Up @@ -386,7 +387,7 @@ def tune(

# Add metrics collector to the Katib Experiment.
# Up to now, We only support parameter `kind`, of which default value is `StdOut`, to specify the kind of metrics collector.
experiment.spec.metrics_collector = models.V1beta1MetricsCollectorSpec(
experiment.spec.metrics_collector_spec = models.V1beta1MetricsCollectorSpec(
collector=models.V1beta1CollectorSpec(kind=metrics_collector_config["kind"])
)

Expand Down Expand Up @@ -1280,21 +1281,18 @@ def get_trial_metrics(

namespace = namespace or self.namespace

db_manager_address = db_manager_address.split(":")
channel = grpc.beta.implementations.insecure_channel(
db_manager_address[0], int(db_manager_address[1])
)
channel = grpc.insecure_channel(db_manager_address)

with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
try:
# When metric name is empty, we select all logs from the Katib DB.
observation_logs = client.GetObservationLog(
katib_api_pb2.GetObservationLogRequest(trial_name=name),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}"
)
client = katib_api_pb2_grpc.DBManagerStub(channel)
try:
# When metric name is empty, we select all logs from the Katib DB.
observation_logs = client.GetObservationLog(
katib_api_pb2.GetObservationLogRequest(trial_name=name),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to get metrics for Trial {namespace}/{name}. Exception: {e}"
)

return observation_logs.observation_log.metric_logs
return observation_logs.observation_log.metric_logs
96 changes: 93 additions & 3 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kubeflow.katib import V1beta1TrialParameterSpec
from kubeflow.katib import V1beta1TrialTemplate
from kubeflow.katib.constants import constants
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
from kubernetes.client import V1ObjectMeta
import pytest

Expand Down Expand Up @@ -238,7 +239,7 @@ def create_experiment(


@pytest.fixture
def katib_client():
def katib_client_create_experiment():
Electronic-Waste marked this conversation as resolved.
Show resolved Hide resolved
with patch(
"kubernetes.client.CustomObjectsApi",
return_value=Mock(
Expand All @@ -255,14 +256,103 @@ def katib_client():


@pytest.mark.parametrize("test_name,kwargs,expected_output", test_create_experiment_data)
def test_create_experiment(katib_client, test_name, kwargs, expected_output):
def test_create_experiment(katib_client_create_experiment, test_name, kwargs, expected_output):
"""
test create_experiment function of katib client
"""
print("\n\nExecuting test:", test_name)
try:
katib_client.create_experiment(**kwargs)
katib_client_create_experiment.create_experiment(**kwargs)
assert expected_output == TEST_RESULT_SUCCESS
except Exception as e:
assert type(e) is expected_output
print("test execution complete")


def get_observation_log_response(*args, **kwargs):
if kwargs.get("timeout") == 0:
raise TimeoutError
elif args[0].trial_name == "invalid":
raise RuntimeError
else:
return katib_api_pb2.GetObservationLogReply(
observation_log=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp="2024-07-29T15:09:08Z",
metric=katib_api_pb2.Metric(name="result",value="0.99")
)
]
)
)

test_get_trial_metrics_data = [
(
"valid trial name",
{
"name": "example",
"namespace": "valid",
"timeout": constants.DEFAULT_TIMEOUT
},
[
katib_api_pb2.MetricLog(
time_stamp="2024-07-29T15:09:08Z",
metric=katib_api_pb2.Metric(name="result",value="0.99")
)
]
),
(
"invalid trial name",
{
"name": "invalid",
"namespace": "invalid",
"timeout": constants.DEFAULT_TIMEOUT
},
RuntimeError
),
(
"GetObservationLog timeout error",
{
"name": "example",
"namespace": "valid",
"timeout": 0
},
RuntimeError
)
]


@pytest.fixture
def katib_client_get_trial_metrics():
with patch(
"kubernetes.client.CustomObjectsApi",
return_value=Mock(),
), patch(
"kubernetes.config.load_kube_config",
return_value=Mock()
):
client = KatibClient()
yield client


@pytest.fixture
def mock_get_observation_log():
with patch("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub") as mock:
mock_instance = mock.return_value
mock_instance.GetObservationLog.side_effect = get_observation_log_response
yield mock_instance


@pytest.mark.parametrize("test_name,kwargs,expected_output", test_get_trial_metrics_data)
def test_get_trial_metrics(test_name, kwargs, expected_output, katib_client_get_trial_metrics, mock_get_observation_log):
"""
test get_trial_metrics function of katib client
"""
print("\n\nExecuting test:", test_name)
try:
metrics = katib_client_get_trial_metrics.get_trial_metrics(**kwargs)
for i in range(len(metrics)):
assert metrics[i] == expected_output[i]
except Exception as e:
assert type(e) is expected_output
print("test execution complete")
54 changes: 26 additions & 28 deletions sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import grpc
from kubeflow.katib.constants import constants
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc
from kubeflow.katib.utils import utils


Expand All @@ -39,9 +40,9 @@ def report_metrics(
timeout: Optional, gRPC API Server timeout in seconds to report metrics.

Raises:
ValueError: The Trial name is not passed to environment variables.
RuntimeError: Unable to push Trial metrics to Katib DB or
ValueError: The Trial name is not passed to environment variables or
metrics value has incorrect format (cannot be converted to type `float`).
RuntimeError: Unable to push Trial metrics to Katib DB.
"""

# Get Trial's namespace and name
Expand All @@ -53,35 +54,32 @@ def report_metrics(
)

# Get channel for grpc call to db manager
db_manager_address = db_manager_address.split(":")
channel = grpc.beta.implementations.insecure_channel(
db_manager_address[0], int(db_manager_address[1])
)
channel = grpc.insecure_channel(db_manager_address)

# Validate metrics value in dict
for value in metrics.values():
utils.validate_metrics_value(value)

# Dial katib db manager to report metrics
with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
try:
timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT)
client.ReportObservationLog(
request=katib_api_pb2.ReportObservationLogRequest(
trial_name=name,
observation_logs=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp=timestamp,
metric=katib_api_pb2.Metric(name=name,value=str(value))
)
for name, value in metrics.items()
]
)
),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}"
)
client = katib_api_pb2_grpc.DBManagerStub(channel)
try:
timestamp = datetime.now(timezone.utc).strftime(constants.RFC3339_FORMAT)
client.ReportObservationLog(
request=katib_api_pb2.ReportObservationLogRequest(
trial_name=name,
observation_log=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp=timestamp,
metric=katib_api_pb2.Metric(name=name,value=str(value))
)
for name, value in metrics.items()
]
)
),
timeout=timeout,
)
except Exception as e:
raise RuntimeError(
f"Unable to push metrics to Katib DB for Trial {namespace}/{name}. Exception: {e}"
)
Loading