forked from kubeflow/katib
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added test for create_experiment in katib_client
Signed-off-by: tariq-hasan <[email protected]>
- Loading branch information
1 parent
7c03cb4
commit d5c1851
Showing
3 changed files
with
296 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
290 changes: 290 additions & 0 deletions
290
sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,290 @@ | ||
import multiprocessing | ||
from typing import List, Optional | ||
from unittest.mock import patch, Mock | ||
|
||
import pytest | ||
from kubernetes.client import V1ObjectMeta | ||
|
||
from kubeflow.katib import KatibClient | ||
from kubeflow.katib import V1beta1AlgorithmSpec | ||
from kubeflow.katib import V1beta1Experiment | ||
from kubeflow.katib import V1beta1ExperimentSpec | ||
from kubeflow.katib import V1beta1FeasibleSpace | ||
from kubeflow.katib import V1beta1ObjectiveSpec | ||
from kubeflow.katib import V1beta1ParameterSpec | ||
from kubeflow.katib import V1beta1TrialParameterSpec | ||
from kubeflow.katib import V1beta1TrialTemplate | ||
from kubeflow.katib.constants import constants | ||
|
||
|
||
class ConflictException(Exception): | ||
def __init__(self): | ||
self.status = 409 | ||
|
||
|
||
def create_namespaced_custom_object_response(*args, **kwargs): | ||
if args[2] == "timeout": | ||
raise multiprocessing.TimeoutError() | ||
elif args[2] == "conflict": | ||
raise ConflictException() | ||
elif args[2] == "runtime": | ||
raise Exception() | ||
elif args[2] in ("test", "test-name"): | ||
return {"metadata": {"name": "experiment-mnist-ci-test"}} | ||
elif args[2] == "test-generate-name": | ||
return {"metadata": {"name": "12345-experiment-mnist-ci-test"}} | ||
|
||
|
||
def generate_trial_template() -> V1beta1TrialTemplate: | ||
trial_spec={ | ||
"apiVersion": "batch/v1", | ||
"kind": "Job", | ||
"spec": { | ||
"template": { | ||
"metadata": { | ||
"annotations": { | ||
"sidecar.istio.io/inject": "false" | ||
} | ||
}, | ||
"spec": { | ||
"containers": [ | ||
{ | ||
"name": "training-container", | ||
"image": "docker.io/kubeflowkatib/pytorch-mnist-cpu:v0.14.0", | ||
"command": [ | ||
"python3", | ||
"/opt/pytorch-mnist/mnist.py", | ||
"--epochs=1", | ||
"--batch-size=64", | ||
"--lr=${trialParameters.learningRate}", | ||
"--momentum=${trialParameters.momentum}", | ||
] | ||
} | ||
], | ||
"restartPolicy": "Never" | ||
} | ||
} | ||
} | ||
} | ||
|
||
return V1beta1TrialTemplate( | ||
primary_container_name="training-container", | ||
trial_parameters=[ | ||
V1beta1TrialParameterSpec( | ||
name="learningRate", | ||
description="Learning rate for the training model", | ||
reference="lr" | ||
), | ||
V1beta1TrialParameterSpec( | ||
name="momentum", | ||
description="Momentum for the training model", | ||
reference="momentum" | ||
), | ||
], | ||
trial_spec=trial_spec | ||
) | ||
|
||
|
||
def generate_experiment( | ||
metadata: V1ObjectMeta, | ||
algorithm_spec: V1beta1AlgorithmSpec, | ||
objective_spec: V1beta1ObjectiveSpec, | ||
parameters: List[V1beta1ParameterSpec], | ||
trial_template: V1beta1TrialTemplate, | ||
) -> V1beta1Experiment: | ||
return V1beta1Experiment( | ||
api_version=constants.API_VERSION, | ||
kind=constants.EXPERIMENT_KIND, | ||
metadata=metadata, | ||
spec=V1beta1ExperimentSpec( | ||
max_trial_count=3, | ||
parallel_trial_count=2, | ||
max_failed_trial_count=1, | ||
algorithm=algorithm_spec, | ||
objective=objective_spec, | ||
parameters=parameters, | ||
trial_template=trial_template, | ||
) | ||
) | ||
|
||
|
||
def create_experiment( | ||
name: Optional[str] = None, | ||
generate_name: Optional[str] = None | ||
) -> V1beta1Experiment: | ||
experiment_namespace = "test" | ||
|
||
if name is not None: | ||
metadata = V1ObjectMeta(name=name, namespace=experiment_namespace) | ||
elif generate_name is not None: | ||
metadata = V1ObjectMeta(generate_name=generate_name, namespace=experiment_namespace) | ||
else: | ||
metadata = V1ObjectMeta(namespace=experiment_namespace) | ||
|
||
algorithm_spec=V1beta1AlgorithmSpec( | ||
algorithm_name="random" | ||
) | ||
|
||
objective_spec=V1beta1ObjectiveSpec( | ||
type="minimize", | ||
goal= 0.001, | ||
objective_metric_name="loss", | ||
) | ||
|
||
parameters=[ | ||
V1beta1ParameterSpec( | ||
name="lr", | ||
parameter_type="double", | ||
feasible_space=V1beta1FeasibleSpace( | ||
min="0.01", | ||
max="0.06" | ||
), | ||
), | ||
V1beta1ParameterSpec( | ||
name="momentum", | ||
parameter_type="double", | ||
feasible_space=V1beta1FeasibleSpace( | ||
min="0.5", | ||
max="0.9" | ||
), | ||
), | ||
] | ||
|
||
trial_template = generate_trial_template() | ||
|
||
experiment = generate_experiment( | ||
metadata, | ||
algorithm_spec, | ||
objective_spec, | ||
parameters, | ||
trial_template | ||
) | ||
return experiment | ||
|
||
|
||
test_create_experiment_data = [ | ||
( | ||
"experiment name and generate_name missing", | ||
{"experiment": create_experiment()}, | ||
ValueError, | ||
), | ||
( | ||
"create_namespaced_custom_object timeout error", | ||
{ | ||
"experiment": create_experiment(name="experiment-mnist-ci-test"), | ||
"namespace": "timeout", | ||
}, | ||
TimeoutError, | ||
), | ||
( | ||
"create_namespaced_custom_object conflict error", | ||
{ | ||
"experiment": create_experiment(name="experiment-mnist-ci-test"), | ||
"namespace": "conflict", | ||
}, | ||
Exception, | ||
), | ||
( | ||
"create_namespaced_custom_object runtime error", | ||
{ | ||
"experiment": create_experiment(name="experiment-mnist-ci-test"), | ||
"namespace": "runtime", | ||
}, | ||
RuntimeError, | ||
), | ||
( | ||
"valid flow with experiment type V1beta1Experiment and name", | ||
{ | ||
"experiment": create_experiment(name="experiment-mnist-ci-test"), | ||
"namespace": "test-name", | ||
}, | ||
constants.TEST_RESULT_SUCCESS, | ||
), | ||
( | ||
"valid flow with experiment type V1beta1Experiment and generate_name", | ||
{ | ||
"experiment": create_experiment(generate_name="experiment-mnist-ci-test"), | ||
"namespace": "test-generate-name", | ||
}, | ||
constants.TEST_RESULT_SUCCESS, | ||
), | ||
( | ||
"valid flow with experiment type V1beta1Experiment and name and generate_name", | ||
{ | ||
"experiment": create_experiment( | ||
name="experiment-mnist-ci-test", | ||
generate_name="experiment-mnist-ci-test", | ||
), | ||
"namespace": "test", | ||
}, | ||
constants.TEST_RESULT_SUCCESS, | ||
), | ||
( | ||
"valid flow with experiment JSON and name", | ||
{ | ||
"experiment": { | ||
"metadata": { | ||
"name": "experiment-mnist-ci-test", | ||
} | ||
}, | ||
"namespace": "test-name", | ||
}, | ||
constants.TEST_RESULT_SUCCESS, | ||
), | ||
( | ||
"valid flow with experiment JSON and generate_name", | ||
{ | ||
"experiment": { | ||
"metadata": { | ||
"generate_name": "experiment-mnist-ci-test", | ||
} | ||
}, | ||
"namespace": "test-generate-name", | ||
}, | ||
constants.TEST_RESULT_SUCCESS, | ||
), | ||
( | ||
"valid flow with experiment JSON and name and generate_name", | ||
{ | ||
"experiment": { | ||
"metadata": { | ||
"name": "experiment-mnist-ci-test", | ||
"generate_name": "experiment-mnist-ci-test", | ||
} | ||
}, | ||
"namespace": "test", | ||
}, | ||
constants.TEST_RESULT_SUCCESS, | ||
), | ||
] | ||
|
||
|
||
@pytest.fixture | ||
def katib_client(): | ||
with patch( | ||
"kubernetes.client.CustomObjectsApi", | ||
return_value=Mock( | ||
create_namespaced_custom_object=Mock( | ||
side_effect=create_namespaced_custom_object_response | ||
) | ||
), | ||
), patch( | ||
"kubernetes.config.load_kube_config", | ||
return_value=Mock() | ||
): | ||
client = KatibClient() | ||
yield client | ||
|
||
|
||
@pytest.mark.parametrize("test_name,kwargs,expected_output", test_create_experiment_data) | ||
def test_create_experiment(katib_client, test_name, kwargs, expected_output): | ||
""" | ||
test create_experiment function of katib client | ||
""" | ||
print("\n\nExecuting test:", test_name) | ||
try: | ||
katib_client.create_experiment(**kwargs) | ||
assert expected_output == constants.TEST_RESULT_SUCCESS | ||
except Exception as e: | ||
assert type(e) is expected_output | ||
print("test execution complete") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters