Skip to content

Commit

Permalink
[SDK] Enable resource specification for trial containers
Browse files Browse the repository at this point in the history
Co-authored-by: shipengcheng1230 <[email protected]>
  • Loading branch information
droctothorpe and shipengcheng1230 committed Aug 3, 2023
1 parent c749d27 commit abd614d
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import multiprocessing
import textwrap
import time
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

import grpc
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
Expand Down Expand Up @@ -147,6 +147,7 @@ def tune(
retain_trials: bool = False,
packages_to_install: List[str] = None,
pip_index_url: str = "https://pypi.org/simple",
resources_per_trial: Union[dict, client.V1ResourceRequirements, None] = None,
):
"""Create HyperParameter Tuning Katib Experiment from the objective function.
Expand Down Expand Up @@ -182,6 +183,20 @@ def tune(
to the base image packages. These packages are installed before
executing the objective function.
pip_index_url: The PyPI url from which to install Python packages.
resources_per trial: A parameter that lets you specify how much
resources each trial container should have. You can either specify a
kubernetes.client.V1ResourceRequirements object (documented here:
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ResourceRequirements.md)
or a dictionary that includes one or more of the following keys:
`cpu`, `memory`, or `gpu` (other keys will be ignored). Appropriate
values for these keys are documented here:
https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/.
For example:
{
"cpu": "1",
"memory": "2Gi",
}
This parameter is optional and defaults to None.
Raises:
ValueError: Objective function has invalid arguments.
Expand Down Expand Up @@ -280,6 +295,23 @@ def tune(
+ exec_script
)

resources = client.V1ResourceRequirements()
if isinstance(resources_per_trial, dict):
requests = {
"cpu": "200m",
"memory": "256Mi",
}
if "gpu" in resources_per_trial:
resources_per_trial["nvidia.com/gpu"] = resources_per_trial.pop("gpu")
requests.update(resources_per_trial)

resources = client.V1ResourceRequirements(
requests=requests,
limits=requests,
)
else:
resources = resources_per_trial

# Create Trial specification.
trial_spec = client.V1Job(
api_version="batch/v1",
Expand All @@ -297,6 +329,7 @@ def tune(
image=base_image,
command=["bash", "-c"],
args=[exec_script],
resources=resources,
)
],
),
Expand Down

0 comments on commit abd614d

Please sign in to comment.