diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000000..767f6146978
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 100
+# E203 (whitespace before ':') is ignored because the check is not PEP 8 compliant and conflicts with Black's slice formatting
+extend-ignore = W503, E203
diff --git a/.github/workflows/e2e-test-tune-api.yaml b/.github/workflows/e2e-test-tune-api.yaml
new file mode 100644
index 00000000000..e1f37a3701b
--- /dev/null
+++ b/.github/workflows/e2e-test-tune-api.yaml
@@ -0,0 +1,34 @@
+name: E2E Test with tune API
+
+on:
+ pull_request:
+ paths-ignore:
+ - "pkg/ui/v1beta1/frontend/**"  # pull requests that only touch the UI frontend do not trigger this workflow
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}  # one concurrency group per workflow and ref
+ cancel-in-progress: true  # a newer run in the same group cancels the run still in progress
+
+jobs:
+ e2e:
+ runs-on: ubuntu-22.04
+ timeout-minutes: 120
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Setup Test Env
+ uses: ./.github/workflows/template-setup-e2e-test
+ with:
+ kubernetes-version: ${{ matrix.kubernetes-version }}  # supplied by the strategy matrix below
+
+ - name: Run e2e test with tune API
+ uses: ./.github/workflows/template-e2e-test
+ with:
+ tune-api: true  # makes the composite action run run-e2e-tune-api.sh instead of run-e2e-experiment.sh
+
+ strategy:
+ fail-fast: false  # keep the remaining matrix jobs running when one Kubernetes version fails
+ matrix:
+ # Kubernetes versions to test; see https://hub.docker.com/r/kindest/node for details
+ kubernetes-version: ["v1.27.11", "v1.28.7", "v1.29.2"]
diff --git a/.github/workflows/template-e2e-test/action.yaml b/.github/workflows/template-e2e-test/action.yaml
index ef1ca26064d..7c9598df04b 100644
--- a/.github/workflows/template-e2e-test/action.yaml
+++ b/.github/workflows/template-e2e-test/action.yaml
@@ -4,15 +4,17 @@ description: Run e2e test using the minikube cluster
inputs:
experiments:
- required: true
+ required: false
description: comma delimited experiment name
+ default: ""
training-operator:
required: false
description: whether to deploy training-operator or not
default: false
trial-images:
- required: true
+ required: false
description: comma delimited trial image name
+ default: ""
katib-ui:
required: true
description: whether to deploy katib-ui or not
@@ -21,13 +23,17 @@ inputs:
required: false
description: mysql or postgres
default: mysql
+ tune-api:
+ required: false  # optional input: callers that omit it fall back to the default below
+ description: whether to execute tune-api test or not
+ default: false
runs:
using: composite
steps:
- name: Setup Minikube Cluster
shell: bash
- run: ./test/e2e/v1beta1/scripts/gh-actions/setup-minikube.sh ${{ inputs.katib-ui }} ${{ inputs.trial-images }} ${{ inputs.experiments }}
+ run: ./test/e2e/v1beta1/scripts/gh-actions/setup-minikube.sh ${{ inputs.katib-ui }} ${{ inputs.tune-api }} ${{ inputs.trial-images }} ${{ inputs.experiments }}
- name: Setup Katib
shell: bash
@@ -35,4 +41,9 @@ runs:
- name: Run E2E Experiment
shell: bash
- run: ./test/e2e/v1beta1/scripts/gh-actions/run-e2e-experiment.sh ${{ inputs.experiments }}
+ run: |
+ if [[ "${{ inputs.tune-api }}" == "true" ]]; then  # compare as a string; never execute the input value as a command
+ ./test/e2e/v1beta1/scripts/gh-actions/run-e2e-tune-api.sh
+ else
+ ./test/e2e/v1beta1/scripts/gh-actions/run-e2e-experiment.sh ${{ inputs.experiments }}
+ fi
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 23ed7eeb30f..f191e042b50 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -10,7 +10,17 @@ repos:
hooks:
- id: isort
name: isort
- entry: isort --profile google
+ entry: isort --profile black
+ - repo: https://github.com/psf/black
+ rev: 24.2.0
+ hooks:
+ - id: black
+ files: (sdk|examples|pkg)/.*
+ - repo: https://github.com/pycqa/flake8
+ rev: 7.1.1
+ hooks:
+ - id: flake8
+ files: (sdk|examples|pkg)/.*
exclude: |
(?x)^(
.*zz_generated.deepcopy.*|
diff --git a/cmd/earlystopping/medianstop/v1beta1/main.py b/cmd/earlystopping/medianstop/v1beta1/main.py
index 240517a76c3..132564d12cb 100644
--- a/cmd/earlystopping/medianstop/v1beta1/main.py
+++ b/cmd/earlystopping/medianstop/v1beta1/main.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from concurrent import futures
import logging
import time
+from concurrent import futures
import grpc
diff --git a/cmd/metricscollector/v1beta1/tfevent-metricscollector/main.py b/cmd/metricscollector/v1beta1/tfevent-metricscollector/main.py
index 21ab7e20bd2..274ed59ba48 100644
--- a/cmd/metricscollector/v1beta1/tfevent-metricscollector/main.py
+++ b/cmd/metricscollector/v1beta1/tfevent-metricscollector/main.py
@@ -13,9 +13,7 @@
# limitations under the License.
import argparse
-from logging import getLogger
-from logging import INFO
-from logging import StreamHandler
+from logging import INFO, StreamHandler, getLogger
import api_pb2
import api_pb2_grpc
diff --git a/cmd/suggestion/hyperband/v1beta1/main.py b/cmd/suggestion/hyperband/v1beta1/main.py
index 21dd46e4d9e..d2a3c2dce1f 100644
--- a/cmd/suggestion/hyperband/v1beta1/main.py
+++ b/cmd/suggestion/hyperband/v1beta1/main.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from concurrent import futures
import time
+from concurrent import futures
import grpc
diff --git a/cmd/suggestion/hyperopt/v1beta1/main.py b/cmd/suggestion/hyperopt/v1beta1/main.py
index c459d5b532c..10d4497c20c 100644
--- a/cmd/suggestion/hyperopt/v1beta1/main.py
+++ b/cmd/suggestion/hyperopt/v1beta1/main.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from concurrent import futures
import time
+from concurrent import futures
import grpc
diff --git a/cmd/suggestion/nas/darts/v1beta1/main.py b/cmd/suggestion/nas/darts/v1beta1/main.py
index a1926ad8326..f0b8f6a1f97 100644
--- a/cmd/suggestion/nas/darts/v1beta1/main.py
+++ b/cmd/suggestion/nas/darts/v1beta1/main.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from concurrent import futures
import time
+from concurrent import futures
import grpc
diff --git a/cmd/suggestion/nas/enas/v1beta1/main.py b/cmd/suggestion/nas/enas/v1beta1/main.py
index 62dda9c810c..399ed275a47 100644
--- a/cmd/suggestion/nas/enas/v1beta1/main.py
+++ b/cmd/suggestion/nas/enas/v1beta1/main.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from concurrent import futures
import time
+from concurrent import futures
import grpc
diff --git a/cmd/suggestion/optuna/v1beta1/main.py b/cmd/suggestion/optuna/v1beta1/main.py
index 435933f4858..cadd393d704 100644
--- a/cmd/suggestion/optuna/v1beta1/main.py
+++ b/cmd/suggestion/optuna/v1beta1/main.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from concurrent import futures
import time
+from concurrent import futures
import grpc
diff --git a/cmd/suggestion/pbt/v1beta1/main.py b/cmd/suggestion/pbt/v1beta1/main.py
index 9e5efb133a6..7f16ffad432 100644
--- a/cmd/suggestion/pbt/v1beta1/main.py
+++ b/cmd/suggestion/pbt/v1beta1/main.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from concurrent import futures
import time
+from concurrent import futures
import grpc
diff --git a/cmd/suggestion/skopt/v1beta1/main.py b/cmd/suggestion/skopt/v1beta1/main.py
index 55d6215529b..d2541042855 100644
--- a/cmd/suggestion/skopt/v1beta1/main.py
+++ b/cmd/suggestion/skopt/v1beta1/main.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from concurrent import futures
import time
+from concurrent import futures
import grpc
diff --git a/docs/proposals/parameter-distribution.md b/docs/proposals/parameter-distribution.md
new file mode 100644
index 00000000000..ebe062c02d0
--- /dev/null
+++ b/docs/proposals/parameter-distribution.md
@@ -0,0 +1,169 @@
+# Proposal for Supporting various parameter distributions in Katib
+
+## Summary
+The goal of this project is to enhance the existing Katib Experiment APIs to support various parameter distributions such as uniform, log-uniform, and qlog-uniform. Then extend the suggestion services to be able to configure distributions for search space using libraries provided in each framework.
+
+## Motivation
+Currently, [Katib](https://github.com/kubeflow/katib) is limited to supporting only uniform distribution for integer, float, and categorical hyperparameters. By introducing additional distributions, Katib will become more flexible and powerful in conducting hyperparameter optimization tasks.
+
+A Data Scientist requires Katib to support multiple hyperparameter distributions, such as log-uniform, normal, and log-normal, in addition to the existing uniform distribution. This enhancement is crucial for more flexible and precise hyperparameter optimization. For instance, learning rates often benefit from a log-uniform distribution because small values can significantly impact performance. Similarly, normal distributions are useful for parameters that are expected to vary around a central value.
+
+### Goals
+- Add `Distribution` field to `FeasibleSpace` alongside `ParameterType`.
+- Support for the log-uniform, normal, and log-normal Distributions.
+- Update the Experiment and gRPC API to support `Distribution`.
+- Update logic to handle the new parameter distributions for each suggestion service (e.g., Optuna, Hyperopt).
+- Extend the Python SDK to support the new `Distribution` field.
+### Non-Goals
+- This proposal does not aim to create a new version of the CRD APIs.
+- This proposal does not aim to make the necessary Katib UI changes.
+- No changes will be made to the core optimization algorithms beyond supporting new distributions.
+
+## Proposal
+
+### Parameter Distribution Comparison Table
+
+| Distribution Type | Hyperopt | Optuna | Ray Tune | Nevergrad |
+|-------------------------------|-----------------------|-------------------------------------------------|-----------------------|---------------------------------------------|
+| **Uniform Continuous** | `hp.uniform` | `FloatDistribution` | `tune.uniform` | `p.Scalar` with uniform transformation |
+| **Quantized Uniform** | `hp.quniform` | `DiscreteUniformDistribution` (deprecated) | `tune.quniform` | `p.Scalar` with uniform and step specified |
+| **Log Uniform** | `hp.loguniform` | `LogUniformDistribution` (deprecated) | `tune.loguniform` | `p.Log` with uniform transformation |
+| **Uniform Integer** | `hp.randint` or quantized distributions with step size `q` set to 1 | `IntDistribution` | `tune.randint` | `p.Scalar` with integer transformation |
+| **Categorical** | `hp.choice` | `CategoricalDistribution` | `tune.choice` | `p.Choice` |
+| **Quantized Log Uniform** | `hp.qloguniform` | Custom Implementation | `tune.qloguniform` | `p.Log` with uniform and step specified |
+| **Normal** | `hp.normal` | (Not directly supported) | `tune.randn` | (Not directly supported) |
+| **Quantized Normal** | `hp.qnormal` | (Not directly supported) | `tune.qrandn` | (Not directly supported) |
+| **Log Normal** | `hp.lognormal` | (Not directly supported) | (Use custom transformation in `tune.randn`) | (Not directly supported) |
+| **Quantized Log Normal** | `hp.qlognormal` | (Not directly supported) | (Use custom transformation in `tune.qrandn`) | (Not directly supported) |
+| **Quantized Integer** | `hp.quniformint` | `IntUniformDistribution` (deprecated) | | `p.Scalar` with integer and step specified |
+| **Log Integer** | | `IntLogUniformDistribution` (deprecated) | `tune.lograndint` | `p.Scalar` with log-integer transformation |
+
+
+- Note:
+In `Nevergrad`, parameter types like `p.Scalar`, `p.Log`, and `p.Choice` are mapped to corresponding `Hyperopt` search space definitions like `hp.uniform`, `hp.loguniform`, and `hp.choice` using internal functions to convert parameter bounds and distributions.
+
+## API Design
+### FeasibleSpace
+Feasible space for optimization.
+Int and Double type use Max/Min.
+Discrete and Categorical type use List.
+
+
+| Field | Type | Label | Description |
+| ----- | ---- | ----- | ----------- |
+| max | [string](#string) | | Max Value |
+| min | [string](#string) | | Minimum Value |
+| list | [string](#string) | repeated | List of Values. |
+| step | [string](#string) | | Step for double or int parameter or q for quantization|
+| distribution | [Distribution](#api-v1-beta1-Distribution) | | Type of the Distribution. |
+
+
+
+
+### Distribution
+- Types of value for HyperParameter Distributions.
+- We add the `distribution` field to `FeasibleSpace` to describe the hyperparameter search space, rather than extending [`ParameterType`](https://github.com/kubeflow/katib/blob/2c575227586ff1c03cf6b5190d066e2f3061a404/pkg/apis/controller/experiments/v1beta1/experiment_types.go#L199-L207).
+- The `distribution` allows users to configure more granular search space customizations.
+- In this enhancement, we would propose the following 4 distributions:
+
+| Name | Number | Description |
+| ---- | ------ | ----------- |
+| UNIFORM | 0 | Continuous uniform distribution. Samples values evenly between a minimum and maximum value. Use "Max/Min". Use "Step" for `q`. |
+| LOG_UNIFORM | 1 | Samples values such that their logarithm is uniformly distributed. Use "Max/Min". Use "Step" for `q`. |
+| NORMAL | 2 | Normal (Gaussian) distribution type. Samples values according to a normal distribution characterized by a mean and standard deviation. Use "Max/Min". Use "Step" for `q`. |
+| LOG_NORMAL | 3 | Log-normal distribution type. Samples values such that their logarithm is normally distributed. Use "Max/Min". Use "Step" for `q`. |
+
+
+## Experiment API changes
+Scope: `pkg/apis/controller/experiments/v1beta1/experiment_types.go`
+
+```go
+type ParameterSpec struct {
+ Name string `json:"name,omitempty"`
+ ParameterType ParameterType `json:"parameterType,omitempty"`
+ FeasibleSpace FeasibleSpace `json:"feasibleSpace,omitempty"`
+}
+```
+- Adding new field `Distribution` to `FeasibleSpace`
+
+- The `Step` field can be used to define quantization steps for uniform or log-uniform distributions, effectively covering q-quantization requirements.
+
+Updated `FeasibleSpace` struct
+```diff
+type FeasibleSpace struct {
+ Max string `json:"max,omitempty"`
+ Min string `json:"min,omitempty"`
+ List []string `json:"list,omitempty"`
+ Step string `json:"step,omitempty"` // Step can be used to define q-quantization
++ Distribution Distribution `json:"distribution,omitempty"` // Added Distribution field
+}
+```
+ - New Field Description: `Distribution`
+ - Type: `Distribution`
+ - Description: The Distribution field specifies the type of statistical distribution to be applied to the parameter. This allows the definition of various distributions, such as uniform, log-uniform, or other supported types.
+
+- Defining `Distribution` type
+```go
+type Distribution string
+
+const (
+ DistributionUniform Distribution = "uniform"
+ DistributionLogUniform Distribution = "logUniform"
+ DistributionNormal Distribution = "normal"
+ DistributionLogNormal Distribution = "logNormal"
+)
+```
+
+## gRPC API changes
+Scope: `pkg/apis/manager/v1beta1/api.proto`
+- Add the `Distribution` field to the `FeasibleSpace` message
+```diff
+/**
+ * Feasible space for optimization.
+ * Int and Double type use Max/Min.
+ * Discrete and Categorical type use List.
+ */
+message FeasibleSpace {
+ string max = 1; /// Max Value
+ string min = 2; /// Minimum Value
+ repeated string list = 3; /// List of Values.
+ string step = 4; /// Step for double or int parameter
++ Distribution distribution = 5; // Distribution of the parameter.
+}
+```
+- Define the `Distribution` enum
+```
+/**
+ * Distribution types for HyperParameter.
+ */
+enum Distribution {
+ UNIFORM = 0;
+ LOG_UNIFORM = 1;
+ NORMAL = 2;
+ LOG_NORMAL = 3;
+}
+```
+
+## Suggestion Service Logic
+- For each suggestion service (e.g., Optuna, Hyperopt), the logic will be updated to handle the new parameter distributions.
+- This involves modifying the conversion functions to map Katib distributions to the corresponding framework-specific distributions.
+
+#### Optuna
+ref: https://optuna.readthedocs.io/en/stable/reference/distributions.html
+
+For example:
+- Update the `_get_optuna_search_space` for new Distributions.
+scope: `pkg/suggestion/v1beta1/optuna/base_service.py`
+
+#### Goptuna
+ref: https://github.com/c-bata/goptuna/blob/2245ddd9e8d1edba750839893c8a618f852bc1cf/distribution.go
+
+#### Hyperopt
+ref: http://hyperopt.github.io/hyperopt/getting-started/search_spaces/#parameter-expressions
+
+#### Ray-tune
+ref: https://docs.ray.io/en/latest/tune/api/search_space.html
+
+## Python SDK
+Extend the Python SDK to support the new `Distribution` field.
+
diff --git a/examples/v1beta1/kubeflow-pipelines/mpi-job-horovod.py b/examples/v1beta1/kubeflow-pipelines/mpi-job-horovod.py
index d5867fd6c4e..800012a2650 100644
--- a/examples/v1beta1/kubeflow-pipelines/mpi-job-horovod.py
+++ b/examples/v1beta1/kubeflow-pipelines/mpi-job-horovod.py
@@ -21,34 +21,37 @@
# This Experiment is similar to this:
# https://github.com/kubeflow/katib/blob/master/examples/v1beta1/kubeflow-training-operator/mpijob-horovod.yaml
-# Check the training container source code here: https://github.com/kubeflow/mpi-operator/tree/master/examples/horovod.
+# Check the training container source code here:
+# https://github.com/kubeflow/mpi-operator/tree/master/examples/horovod.
# Note: To run this example, your Kubernetes cluster should run MPIJob operator.
-# Follow this guide to install MPIJob on your cluster: https://www.kubeflow.org/docs/components/training/mpi/
+# Follow this guide to install MPIJob on your cluster:
+# https://www.kubeflow.org/docs/components/training/mpi/
import kfp
-from kfp import components
import kfp.dsl as dsl
-from kubeflow.katib import ApiClient
-from kubeflow.katib import V1beta1AlgorithmSetting
-from kubeflow.katib import V1beta1AlgorithmSpec
-from kubeflow.katib import V1beta1ExperimentSpec
-from kubeflow.katib import V1beta1FeasibleSpace
-from kubeflow.katib import V1beta1ObjectiveSpec
-from kubeflow.katib import V1beta1ParameterSpec
-from kubeflow.katib import V1beta1TrialParameterSpec
-from kubeflow.katib import V1beta1TrialTemplate
+from kfp import components
+from kubeflow.katib import (
+ ApiClient,
+ V1beta1AlgorithmSetting,
+ V1beta1AlgorithmSpec,
+ V1beta1ExperimentSpec,
+ V1beta1FeasibleSpace,
+ V1beta1ObjectiveSpec,
+ V1beta1ParameterSpec,
+ V1beta1TrialParameterSpec,
+ V1beta1TrialTemplate,
+)
@dsl.pipeline(
name="Launch Katib MPIJob Experiment",
- description="An example to launch Katib Experiment with MPIJob"
+ description="An example to launch Katib Experiment with MPIJob",
)
def horovod_mnist_hpo(
experiment_name: str = "mpi-horovod-mnist",
experiment_namespace: str = "kubeflow-user-example-com",
):
-
# Trial count specification.
max_trial_count = 6
max_failed_trial_count = 3
@@ -64,12 +67,7 @@ def horovod_mnist_hpo(
# Algorithm specification.
algorithm = V1beta1AlgorithmSpec(
algorithm_name="bayesianoptimization",
- algorithm_settings=[
- V1beta1AlgorithmSetting(
- name="random_state",
- value="10"
- )
- ]
+ algorithm_settings=[V1beta1AlgorithmSetting(name="random_state", value="10")],
)
# Experiment search space.
@@ -78,19 +76,12 @@ def horovod_mnist_hpo(
V1beta1ParameterSpec(
name="lr",
parameter_type="double",
- feasible_space=V1beta1FeasibleSpace(
- min="0.001",
- max="0.003"
- ),
+ feasible_space=V1beta1FeasibleSpace(min="0.001", max="0.003"),
),
V1beta1ParameterSpec(
name="num-steps",
parameter_type="int",
- feasible_space=V1beta1FeasibleSpace(
- min="50",
- max="150",
- step="10"
- ),
+ feasible_space=V1beta1FeasibleSpace(min="50", max="150", step="10"),
),
]
@@ -106,18 +97,14 @@ def horovod_mnist_hpo(
"replicas": 1,
"template": {
"metadata": {
- "annotations": {
- "sidecar.istio.io/inject": "false"
- }
+ "annotations": {"sidecar.istio.io/inject": "false"}
},
"spec": {
"containers": [
{
"image": "docker.io/kubeflow/mpi-horovod-mnist",
"name": "mpi-launcher",
- "command": [
- "mpirun"
- ],
+ "command": ["mpirun"],
"args": [
"-np",
"2",
@@ -141,26 +128,21 @@ def horovod_mnist_hpo(
"--lr",
"${trialParameters.learningRate}",
"--num-steps",
- "${trialParameters.numberSteps}"
+ "${trialParameters.numberSteps}",
],
"resources": {
- "limits": {
- "cpu": "500m",
- "memory": "2Gi"
- }
- }
+ "limits": {"cpu": "500m", "memory": "2Gi"}
+ },
}
]
- }
- }
+ },
+ },
},
"Worker": {
"replicas": 2,
"template": {
"metadata": {
- "annotations": {
- "sidecar.istio.io/inject": "false"
- }
+ "annotations": {"sidecar.istio.io/inject": "false"}
},
"spec": {
"containers": [
@@ -168,25 +150,20 @@ def horovod_mnist_hpo(
"image": "docker.io/kubeflow/mpi-horovod-mnist",
"name": "mpi-worker",
"resources": {
- "limits": {
- "cpu": "500m",
- "memory": "4Gi"
- }
- }
+ "limits": {"cpu": "500m", "memory": "4Gi"}
+ },
}
]
- }
- }
- }
- }
- }
+ },
+ },
+ },
+ },
+ },
}
# Configure parameters for the Trial template.
trial_template = V1beta1TrialTemplate(
- primary_pod_labels={
- "mpi-job-role": "launcher"
- },
+ primary_pod_labels={"mpi-job-role": "launcher"},
primary_container_name="mpi-launcher",
success_condition='status.conditions.#(type=="Succeeded")#|#(status=="True")#',
failure_condition='status.conditions.#(type=="Failed")#|#(status=="True")#',
@@ -194,15 +171,15 @@ def horovod_mnist_hpo(
V1beta1TrialParameterSpec(
name="learningRate",
description="Learning rate for the training model",
- reference="lr"
+ reference="lr",
),
V1beta1TrialParameterSpec(
name="numberSteps",
description="Number of training steps",
- reference="num-steps"
+ reference="num-steps",
),
],
- trial_spec=trial_spec
+ trial_spec=trial_spec,
)
# Create Experiment specification.
@@ -213,13 +190,15 @@ def horovod_mnist_hpo(
objective=objective,
algorithm=algorithm,
parameters=parameters,
- trial_template=trial_template
+ trial_template=trial_template,
)
# Get the Katib launcher.
# Load component from the URL or from the file.
katib_experiment_launcher_op = components.load_component_from_url(
- "https://raw.githubusercontent.com/kubeflow/pipelines/master/components/kubeflow/katib-launcher/component.yaml")
+ "https://raw.githubusercontent.com/kubeflow/pipelines/master/"
+ "components/kubeflow/katib-launcher/component.yaml"
+ )
# katib_experiment_launcher_op = components.load_component_from_file(
# "../../../components/kubeflow/katib-launcher/component.yaml"
# )
@@ -231,7 +210,8 @@ def horovod_mnist_hpo(
experiment_name=experiment_name,
experiment_namespace=experiment_namespace,
experiment_spec=ApiClient().sanitize_for_serialization(experiment_spec),
- experiment_timeout_minutes=60)
+ experiment_timeout_minutes=60,
+ )
# Output container to print the results.
dsl.ContainerOp(
diff --git a/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py b/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py
index af54729f284..5f532e715da 100644
--- a/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py
+++ b/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py
@@ -17,9 +17,8 @@
import torch
-class Architect():
- """" Architect controls architecture of cell by computing gradients of alphas
- """
+class Architect:
+ """Architect controls architecture of cell by computing gradients of alphas"""
def __init__(self, model, w_momentum, w_weight_decay, device):
self.model = model
@@ -48,25 +47,32 @@ def virtual_step(self, train_x, train_y, xi, w_optim):
# Compute gradient
gradients = torch.autograd.grad(loss, self.model.getWeights())
-
+
# Do virtual step (Update gradient)
# Below operations do not need gradient tracking
with torch.no_grad():
# dict key is not the value, but the pointer. So original network weight have to
# be iterated also.
- for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
- m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
- if(self.device == 'cuda'):
- vw.copy_(w - torch.cuda.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
- elif(self.device == 'cpu'):
- vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
+ for w, vw, g in zip(
+ self.model.getWeights(), self.v_model.getWeights(), gradients
+ ):
+ m = w_optim.state[w].get("momentum_buffer", 0.0) * self.w_momentum
+ if self.device == "cuda":
+ vw.copy_(
+ w
+ - torch.cuda.FloatTensor(xi) * (m + g + self.w_weight_decay * w)
+ )
+ elif self.device == "cpu":
+ vw.copy_(
+ w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w)
+ )
# Sync alphas
for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
va.copy_(a)
def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
- """ Compute unrolled loss and backward its gradients
+ """Compute unrolled loss and backward its gradients
Args:
xi: learning rate for virtual gradient step (same as model lr)
w_optim: weights optimizer - for virtual step
@@ -77,23 +83,23 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
# Calculate unrolled loss
# Loss for validation with w'. L_valid(w')
loss = self.v_model.loss(valid_x, valid_y)
-
+
# Calculate gradient
v_alphas = tuple(self.v_model.getAlphas())
v_weights = tuple(self.v_model.getWeights())
v_grads = torch.autograd.grad(loss, v_alphas + v_weights)
- dalpha = v_grads[:len(v_alphas)]
- dws = v_grads[len(v_alphas):]
+ dalpha = v_grads[: len(v_alphas)]
+ dws = v_grads[len(v_alphas) :]
hessian = self.compute_hessian(dws, train_x, train_y)
# Update final gradient = dalpha - xi * hessian
with torch.no_grad():
for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
- if(self.device == 'cuda'):
+ if self.device == "cuda":
alpha.grad = da - torch.cuda.FloatTensor(xi) * h
- elif(self.device == 'cpu'):
+ elif self.device == "cpu":
alpha.grad = da - torch.cpu.FloatTensor(xi) * h
def compute_hessian(self, dws, train_x, train_y):
@@ -121,7 +127,7 @@ def compute_hessian(self, dws, train_x, train_y):
with torch.no_grad():
for p, dw in zip(self.model.getWeights(), dws):
# TODO (andreyvelich): Do we need this * 2.0 ?
- p -= 2. * eps * dw
+ p -= 2.0 * eps * dw
loss = self.model.loss(train_x, train_y)
# dalpha { L_train(w-, alpha) }
@@ -132,5 +138,7 @@ def compute_hessian(self, dws, train_x, train_y):
for p, dw in zip(self.model.getWeights(), dws):
p += eps * dw
- hessian = [(p-n) / (2. * eps) for p, n in zip(dalpha_positive, dalpha_negative)]
+ hessian = [
+ (p - n) / (2.0 * eps) for p, n in zip(dalpha_positive, dalpha_negative)
+ ]
return hessian
diff --git a/examples/v1beta1/trial-images/darts-cnn-cifar10/model.py b/examples/v1beta1/trial-images/darts-cnn-cifar10/model.py
index 195c0f8f865..df61fe24014 100644
--- a/examples/v1beta1/trial-images/darts-cnn-cifar10/model.py
+++ b/examples/v1beta1/trial-images/darts-cnn-cifar10/model.py
@@ -12,20 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from operations import FactorizedReduce
-from operations import MixedOp
-from operations import StdConv
import torch
import torch.nn as nn
import torch.nn.functional as F
+from operations import FactorizedReduce, MixedOp, StdConv
class Cell(nn.Module):
- """ Cell for search
+ """Cell for search
Each edge is mixed and continuous relaxed.
"""
- def __init__(self, num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space):
+ def __init__(
+ self,
+ num_nodes,
+ c_prev_prev,
+ c_prev,
+ c_cur,
+ reduction_prev,
+ reduction_cur,
+ search_space,
+ ):
"""
Args:
num_nodes: Number of intermediate cell nodes
@@ -45,7 +52,9 @@ def __init__(self, num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduct
if reduction_prev:
self.preprocess0 = FactorizedReduce(c_prev_prev, c_cur)
else:
- self.preprocess0 = StdConv(c_prev_prev, c_cur, kernel_size=1, stride=1, padding=0)
+ self.preprocess0 = StdConv(
+ c_prev_prev, c_cur, kernel_size=1, stride=1, padding=0
+ )
self.preprocess1 = StdConv(c_prev, c_cur, kernel_size=1, stride=1, padding=0)
# Generate dag from mixed operations
@@ -54,7 +63,7 @@ def __init__(self, num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduct
for i in range(self.num_nodes):
self.dag_ops.append(nn.ModuleList())
# Include 2 input nodes
- for j in range(2+i):
+ for j in range(2 + i):
# Reduction with stride = 2 must be only for the input node
stride = 2 if reduction_cur and j < 2 else 1
op = MixedOp(c_cur, stride, search_space)
@@ -66,7 +75,9 @@ def forward(self, s0, s1, w_dag):
states = [s0, s1]
for edges, w_list in zip(self.dag_ops, w_dag):
- state_cur = sum(edges[i](s, w) for i, (s, w) in enumerate((zip(states, w_list))))
+ state_cur = sum(
+ edges[i](s, w) for i, (s, w) in enumerate((zip(states, w_list)))
+ )
states.append(state_cur)
state_out = torch.cat(states[2:], dim=1)
@@ -75,8 +86,17 @@ def forward(self, s0, s1, w_dag):
class NetworkCNN(nn.Module):
- def __init__(self, init_channels, input_channels, num_classes,
- num_layers, criterion, search_space, num_nodes, stem_multiplier):
+ def __init__(
+ self,
+ init_channels,
+ input_channels,
+ num_classes,
+ num_layers,
+ criterion,
+ search_space,
+ num_nodes,
+ stem_multiplier,
+ ):
super(NetworkCNN, self).__init__()
self.init_channels = init_channels
@@ -87,11 +107,11 @@ def __init__(self, init_channels, input_channels, num_classes,
self.num_nodes = num_nodes
self.stem_multiplier = stem_multiplier
- c_cur = self.stem_multiplier*self.init_channels
+ c_cur = self.stem_multiplier * self.init_channels
self.stem = nn.Sequential(
nn.Conv2d(input_channels, c_cur, 3, padding=1, bias=False),
- nn.BatchNorm2d(c_cur)
+ nn.BatchNorm2d(c_cur),
)
# In first Cell stem is used for s0 and s1
@@ -110,14 +130,24 @@ def __init__(self, init_channels, input_channels, num_classes,
# For Network with two layers: First layer - Normal, Second - Reduction
# For Other Networks: [1/3, 2/3] Layers - Reduction cell with double channels
# Others - Normal cell
- if ((self.num_layers == 2 and i == 1) or
- (self.num_layers > 2 and i in [self.num_layers//3, 2*self.num_layers//3])):
+ if (self.num_layers == 2 and i == 1) or (
+ self.num_layers > 2
+ and i in [self.num_layers // 3, 2 * self.num_layers // 3]
+ ):
c_cur *= 2
reduction_cur = True
else:
reduction_cur = False
- cell = Cell(self.num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space)
+ cell = Cell(
+ self.num_nodes,
+ c_prev_prev,
+ c_prev,
+ c_cur,
+ reduction_prev,
+ reduction_cur,
+ search_space,
+ )
reduction_prev = reduction_cur
self.cells.append(cell)
@@ -134,9 +164,11 @@ def __init__(self, init_channels, input_channels, num_classes,
self.alpha_reduce = nn.ParameterList()
for i in range(self.num_nodes):
- self.alpha_normal.append(nn.Parameter(1e-3*torch.randn(i+2, num_ops)))
+ self.alpha_normal.append(nn.Parameter(1e-3 * torch.randn(i + 2, num_ops)))
if self.num_layers > 1:
- self.alpha_reduce.append(nn.Parameter(1e-3*torch.randn(i+2, num_ops)))
+ self.alpha_reduce.append(
+ nn.Parameter(1e-3 * torch.randn(i + 2, num_ops))
+ )
# Setup alphas list
self.alphas = []
@@ -192,5 +224,9 @@ def genotype(self, search_space):
# concat all intermediate nodes
concat = range(2, 2 + self.num_nodes)
- return search_space.genotype(normal=gene_normal, normal_concat=concat,
- reduce=gene_reduce, reduce_concat=concat)
+ return search_space.genotype(
+ normal=gene_normal,
+ normal_concat=concat,
+ reduce=gene_reduce,
+ reduce_concat=concat,
+ )
diff --git a/examples/v1beta1/trial-images/darts-cnn-cifar10/operations.py b/examples/v1beta1/trial-images/darts-cnn-cifar10/operations.py
index c701742fa0f..a4cdebd5b7b 100644
--- a/examples/v1beta1/trial-images/darts-cnn-cifar10/operations.py
+++ b/examples/v1beta1/trial-images/darts-cnn-cifar10/operations.py
@@ -16,18 +16,30 @@
import torch.nn as nn
OPS = {
- 'none': lambda channels, stride: Zero(stride),
- 'avg_pooling_3x3': lambda channels, stride: PoolBN('avg', channels, kernel_size=3, stride=stride, padding=1),
- 'max_pooling_3x3': lambda channels, stride: PoolBN('max', channels, kernel_size=3, stride=stride, padding=1),
- 'skip_connection': lambda channels, stride: Identity() if stride == 1 else FactorizedReduce(channels, channels),
- 'separable_convolution_3x3': lambda channels, stride: SepConv(channels, kernel_size=3, stride=stride, padding=1),
- 'separable_convolution_5x5': lambda channels, stride: SepConv(channels, kernel_size=5, stride=stride, padding=2),
+ "none": lambda channels, stride: Zero(stride),
+ "avg_pooling_3x3": lambda channels, stride: PoolBN(
+ "avg", channels, kernel_size=3, stride=stride, padding=1
+ ),
+ "max_pooling_3x3": lambda channels, stride: PoolBN(
+ "max", channels, kernel_size=3, stride=stride, padding=1
+ ),
+ "skip_connection": lambda channels, stride: (
+ Identity() if stride == 1 else FactorizedReduce(channels, channels)
+ ),
+ "separable_convolution_3x3": lambda channels, stride: SepConv(
+ channels, kernel_size=3, stride=stride, padding=1
+ ),
+ "separable_convolution_5x5": lambda channels, stride: SepConv(
+ channels, kernel_size=5, stride=stride, padding=2
+ ),
# 3x3 -> 5x5
- 'dilated_convolution_3x3': lambda channels, stride: DilConv(channels,
- kernel_size=3, stride=stride, padding=2, dilation=2),
+ "dilated_convolution_3x3": lambda channels, stride: DilConv(
+ channels, kernel_size=3, stride=stride, padding=2, dilation=2
+ ),
# 5x5 -> 9x9
- 'dilated_convolution_5x5': lambda channels, stride: DilConv(channels,
- kernel_size=5, stride=stride, padding=4, dilation=2),
+ "dilated_convolution_5x5": lambda channels, stride: DilConv(
+ channels, kernel_size=5, stride=stride, padding=4, dilation=2
+ ),
}
@@ -42,9 +54,9 @@ def __init__(self, stride):
def forward(self, x):
if self.stride == 1:
- return x * 0.
+ return x * 0.0
# Resize by stride
- return x[:, :, ::self.stride, ::self.stride] * 0.
+ return x[:, :, :: self.stride, :: self.stride] * 0.0
class PoolBN(nn.Module):
@@ -55,15 +67,14 @@ class PoolBN(nn.Module):
def __init__(self, pool_type, channels, kernel_size, stride, padding):
super(PoolBN, self).__init__()
if pool_type == "avg":
- self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
+ self.pool = nn.AvgPool2d(
+ kernel_size, stride, padding, count_include_pad=False
+ )
elif pool_type == "max":
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
self.bn = nn.BatchNorm2d(channels, affine=False)
- self.net = nn.Sequential(
- self.pool,
- self.bn
- )
+ self.net = nn.Sequential(self.pool, self.bn)
def forward(self, x):
# out = self.pool(x),
@@ -91,8 +102,12 @@ class FactorizedReduce(nn.Module):
def __init__(self, c_in, c_out):
super(FactorizedReduce, self).__init__()
self.relu = nn.ReLU()
- self.conv1 = nn.Conv2d(c_in, c_out // 2, kernel_size=1, stride=2, padding=0, bias=False)
- self.conv2 = nn.Conv2d(c_in, c_out // 2, kernel_size=1, stride=2, padding=0, bias=False)
+ self.conv1 = nn.Conv2d(
+ c_in, c_out // 2, kernel_size=1, stride=2, padding=0, bias=False
+ )
+ self.conv2 = nn.Conv2d(
+ c_in, c_out // 2, kernel_size=1, stride=2, padding=0, bias=False
+ )
self.bn = nn.BatchNorm2d(c_out, affine=False)
def forward(self, x):
@@ -105,7 +120,7 @@ def forward(self, x):
class StdConv(nn.Module):
- """ Standard convolition
+ """Standard convolution
ReLU - Conv - BN
"""
@@ -113,8 +128,15 @@ def __init__(self, c_in, c_out, kernel_size, stride, padding):
super(StdConv, self).__init__()
self.net = nn.Sequential(
nn.ReLU(),
- nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
- nn.BatchNorm2d(c_out, affine=False)
+ nn.Conv2d(
+ c_in,
+ c_out,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ ),
+ nn.BatchNorm2d(c_out, affine=False),
)
def forward(self, x):
@@ -122,7 +144,7 @@ def forward(self, x):
class DilConv(nn.Module):
- """ (Dilated) depthwise separable conv
+ """(Dilated) depthwise separable conv
ReLU - (Dilated) depthwise separable - Pointwise - BN
If dilation == 2, 3x3 conv => 5x5 receptive field
@@ -134,9 +156,20 @@ def __init__(self, channels, kernel_size, stride, padding, dilation):
self.net = nn.Sequential(
nn.ReLU(),
- nn.Conv2d(channels, channels, kernel_size, stride, padding, dilation=dilation, groups=channels, bias=False),
- nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0, bias=False),
- nn.BatchNorm2d(channels, affine=False)
+ nn.Conv2d(
+ channels,
+ channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation=dilation,
+ groups=channels,
+ bias=False,
+ ),
+ nn.Conv2d(
+ channels, channels, kernel_size=1, stride=1, padding=0, bias=False
+ ),
+ nn.BatchNorm2d(channels, affine=False),
)
def forward(self, x):
@@ -144,7 +177,7 @@ def forward(self, x):
class SepConv(nn.Module):
- """ Depthwise separable conv
+ """Depthwise separable conv
DilConv (dilation=1) * 2
"""
@@ -152,7 +185,7 @@ def __init__(self, channels, kernel_size, stride, padding):
super(SepConv, self).__init__()
self.net = nn.Sequential(
DilConv(channels, kernel_size, stride=stride, padding=padding, dilation=1),
- DilConv(channels, kernel_size, stride=1, padding=padding, dilation=1)
+ DilConv(channels, kernel_size, stride=1, padding=padding, dilation=1),
)
def forward(self, x):
@@ -160,8 +193,7 @@ def forward(self, x):
class MixedOp(nn.Module):
- """ Mixed operation
- """
+ """Mixed operation"""
def __init__(self, channels, stride, search_space):
super(MixedOp, self).__init__()
diff --git a/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py b/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py
index 8afc37e27d7..0a1f145bd6d 100644
--- a/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py
+++ b/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py
@@ -16,26 +16,38 @@
import argparse
import json
-from architect import Architect
-from model import NetworkCNN
import numpy as np
-from search_space import SearchSpace
import torch
import torch.nn as nn
import utils
+from architect import Architect
+from model import NetworkCNN
+from search_space import SearchSpace
def main():
- parser = argparse.ArgumentParser(description='TrainingContainer')
- parser.add_argument('--algorithm-settings', type=str, default="", help="algorithm settings")
- parser.add_argument('--search-space', type=str, default="", help="search space for the neural architecture search")
- parser.add_argument('--num-layers', type=str, default="", help="number of layers of the neural network")
+ parser = argparse.ArgumentParser(description="TrainingContainer")
+ parser.add_argument(
+ "--algorithm-settings", type=str, default="", help="algorithm settings"
+ )
+ parser.add_argument(
+ "--search-space",
+ type=str,
+ default="",
+ help="search space for the neural architecture search",
+ )
+ parser.add_argument(
+ "--num-layers",
+ type=str,
+ default="",
+ help="number of layers of the neural network",
+ )
args = parser.parse_args()
# Get Algorithm Settings
- algorithm_settings = args.algorithm_settings.replace("\'", "\"")
+ algorithm_settings = args.algorithm_settings.replace("'", '"')
algorithm_settings = json.loads(algorithm_settings)
print(">>> Algorithm settings")
for key, value in algorithm_settings.items():
@@ -69,7 +81,7 @@ def main():
stem_multiplier = int(algorithm_settings["stem_multiplier"])
# Get Search Space
- search_space = args.search_space.replace("\'", "\"")
+ search_space = args.search_space.replace("'", '"')
search_space = json.loads(search_space)
search_space = SearchSpace(search_space)
@@ -103,16 +115,28 @@ def main():
criterion = nn.CrossEntropyLoss().to(device)
- model = NetworkCNN(init_channels, input_channels, num_classes, num_layers,
- criterion, search_space, num_nodes, stem_multiplier)
+ model = NetworkCNN(
+ init_channels,
+ input_channels,
+ num_classes,
+ num_layers,
+ criterion,
+ search_space,
+ num_nodes,
+ stem_multiplier,
+ )
model = model.to(device)
# Weights optimizer
- w_optim = torch.optim.SGD(model.getWeights(), w_lr, momentum=w_momentum, weight_decay=w_weight_decay)
+ w_optim = torch.optim.SGD(
+ model.getWeights(), w_lr, momentum=w_momentum, weight_decay=w_weight_decay
+ )
# Alphas optimizer
- alpha_optim = torch.optim.Adam(model.getAlphas(), alpha_lr, betas=(0.5, 0.999), weight_decay=alpha_weight_decay)
+ alpha_optim = torch.optim.Adam(
+ model.getAlphas(), alpha_lr, betas=(0.5, 0.999), weight_decay=alpha_weight_decay
+ )
# Split data to train/validation
num_train = len(train_data)
@@ -122,27 +146,30 @@ def main():
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
- train_loader = torch.utils.data.DataLoader(train_data,
- batch_size=batch_size,
- sampler=train_sampler,
- num_workers=num_workers,
- pin_memory=True)
-
- valid_loader = torch.utils.data.DataLoader(train_data,
- batch_size=batch_size,
- sampler=valid_sampler,
- num_workers=num_workers,
- pin_memory=True)
+ train_loader = torch.utils.data.DataLoader(
+ train_data,
+ batch_size=batch_size,
+ sampler=train_sampler,
+ num_workers=num_workers,
+ pin_memory=True,
+ )
+
+ valid_loader = torch.utils.data.DataLoader(
+ train_data,
+ batch_size=batch_size,
+ sampler=valid_sampler,
+ num_workers=num_workers,
+ pin_memory=True,
+ )
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
- w_optim,
- num_epochs,
- eta_min=w_lr_min)
+ w_optim, num_epochs, eta_min=w_lr_min
+ )
architect = Architect(model, w_momentum, w_weight_decay, device)
# Start training
- best_top1 = 0.
+ best_top1 = 0.0
for epoch in range(num_epochs):
lr = lr_scheduler.get_last_lr()
@@ -151,14 +178,28 @@ def main():
# Training
print(">>> Training")
- train(train_loader, valid_loader, model, architect, w_optim, alpha_optim,
- lr, epoch, num_epochs, device, w_grad_clip, print_step)
+ train(
+ train_loader,
+ valid_loader,
+ model,
+ architect,
+ w_optim,
+ alpha_optim,
+ lr,
+ epoch,
+ num_epochs,
+ device,
+ w_grad_clip,
+ print_step,
+ )
lr_scheduler.step()
# Validation
print("\n>>> Validation")
cur_step = (epoch + 1) * len(train_loader)
- top1 = validate(valid_loader, model, epoch, cur_step, num_epochs, device, print_step)
+ top1 = validate(
+ valid_loader, model, epoch, cur_step, num_epochs, device, print_step
+ )
# Print genotype
genotype = model.genotype(search_space)
@@ -173,18 +214,36 @@ def main():
print("\nBest-Genotype={}".format(str(best_genotype).replace(" ", "")))
-def train(train_loader, valid_loader, model, architect, w_optim, alpha_optim,
- lr, epoch, num_epochs, device, w_grad_clip, print_step):
+def train(
+ train_loader,
+ valid_loader,
+ model,
+ architect,
+ w_optim,
+ alpha_optim,
+ lr,
+ epoch,
+ num_epochs,
+ device,
+ w_grad_clip,
+ print_step,
+):
top1 = utils.AverageMeter()
top5 = utils.AverageMeter()
losses = utils.AverageMeter()
cur_step = epoch * len(train_loader)
model.train()
- for step, ((train_x, train_y), (valid_x, valid_y)) in enumerate(zip(train_loader, valid_loader)):
+ for step, ((train_x, train_y), (valid_x, valid_y)) in enumerate(
+ zip(train_loader, valid_loader)
+ ):
- train_x, train_y = train_x.to(device, non_blocking=True), train_y.to(device, non_blocking=True)
- valid_x, valid_y = valid_x.to(device, non_blocking=True), valid_y.to(device, non_blocking=True)
+ train_x, train_y = train_x.to(device, non_blocking=True), train_y.to(
+ device, non_blocking=True
+ )
+ valid_x, valid_y = valid_x.to(device, non_blocking=True), valid_y.to(
+ device, non_blocking=True
+ )
train_size = train_x.size(0)
@@ -213,12 +272,21 @@ def train(train_loader, valid_loader, model, architect, w_optim, alpha_optim,
print(
"Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
- epoch+1, num_epochs, step, len(train_loader)-1, losses=losses,
- top1=top1, top5=top5))
+ epoch + 1,
+ num_epochs,
+ step,
+ len(train_loader) - 1,
+ losses=losses,
+ top1=top1,
+ top5=top5,
+ )
+ )
cur_step += 1
- print("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch+1, num_epochs, top1.avg))
+ print(
+ "Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, num_epochs, top1.avg)
+ )
def validate(valid_loader, model, epoch, cur_step, num_epochs, device, print_step):
@@ -230,7 +298,9 @@ def validate(valid_loader, model, epoch, cur_step, num_epochs, device, print_ste
with torch.no_grad():
for step, (valid_x, valid_y) in enumerate(valid_loader):
- valid_x, valid_y = valid_x.to(device, non_blocking=True), valid_y.to(device, non_blocking=True)
+ valid_x, valid_y = valid_x.to(device, non_blocking=True), valid_y.to(
+ device, non_blocking=True
+ )
valid_size = valid_x.size(0)
@@ -246,10 +316,19 @@ def validate(valid_loader, model, epoch, cur_step, num_epochs, device, print_ste
print(
"Validation: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
- epoch+1, num_epochs, step, len(valid_loader)-1, losses=losses,
- top1=top1, top5=top5))
-
- print("Valid: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch+1, num_epochs, top1.avg))
+ epoch + 1,
+ num_epochs,
+ step,
+ len(valid_loader) - 1,
+ losses=losses,
+ top1=top1,
+ top5=top5,
+ )
+ )
+
+ print(
+ "Valid: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, num_epochs, top1.avg)
+ )
return top1.avg
diff --git a/examples/v1beta1/trial-images/darts-cnn-cifar10/search_space.py b/examples/v1beta1/trial-images/darts-cnn-cifar10/search_space.py
index 3485f1a1ec5..1d77f20a417 100644
--- a/examples/v1beta1/trial-images/darts-cnn-cifar10/search_space.py
+++ b/examples/v1beta1/trial-images/darts-cnn-cifar10/search_space.py
@@ -17,14 +17,16 @@
import torch
-class SearchSpace():
+class SearchSpace:
def __init__(self, search_space):
self.primitives = search_space
self.primitives.append("none")
print(">>> All Primitives")
print("{}\n".format(self.primitives))
- self.genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
+ self.genotype = namedtuple(
+ "Genotype", "normal normal_concat reduce reduce_concat"
+ )
def parse(self, alpha, k):
"""
@@ -46,7 +48,7 @@ def parse(self, alpha, k):
"""
gene = []
- assert self.primitives[-1] == 'none' # assume last PRIMITIVE is 'none'
+ assert self.primitives[-1] == "none" # assume last PRIMITIVE is 'none'
# 1) Convert the mixed op to discrete edge (single op) by choosing top-1 weight edge
# 2) Choose top-k edges per node by edge score (top-1 weight in edge)
diff --git a/examples/v1beta1/trial-images/darts-cnn-cifar10/utils.py b/examples/v1beta1/trial-images/darts-cnn-cifar10/utils.py
index 6a278ada83f..070be55d366 100644
--- a/examples/v1beta1/trial-images/darts-cnn-cifar10/utils.py
+++ b/examples/v1beta1/trial-images/darts-cnn-cifar10/utils.py
@@ -16,21 +16,21 @@
import torchvision.transforms as transforms
-class AverageMeter():
- """ Computes and stores the average and current value """
+class AverageMeter:
+ """Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
- """ Reset all statistics """
+ """Reset all statistics"""
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
- """ Update statistics """
+ """Update statistics"""
self.val = val
self.sum += val * n
self.count += n
@@ -38,7 +38,7 @@ def update(self, val, n=1):
def accuracy(output, target, topk=(1,)):
- """ Computes the precision@k for the specified values of k """
+ """Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
@@ -67,18 +67,14 @@ def get_dataset():
# Do preprocessing
MEAN = [0.49139968, 0.48215827, 0.44653124]
STD = [0.24703233, 0.24348505, 0.26158768]
- transf = [
- transforms.RandomCrop(32, padding=4),
- transforms.RandomHorizontalFlip()
- ]
+ transf = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
- normalize = [
- transforms.ToTensor(),
- transforms.Normalize(MEAN, STD)
- ]
+ normalize = [transforms.ToTensor(), transforms.Normalize(MEAN, STD)]
train_transform = transforms.Compose(transf + normalize)
- train_data = dataset_cls(root="./data", train=True, download=True, transform=train_transform)
+ train_data = dataset_cls(
+ root="./data", train=True, download=True, transform=train_transform
+ )
return input_channels, num_classes, train_data
diff --git a/examples/v1beta1/trial-images/enas-cnn-cifar10/ModelConstructor.py b/examples/v1beta1/trial-images/enas-cnn-cifar10/ModelConstructor.py
index 5de068b1608..64b517c5b77 100644
--- a/examples/v1beta1/trial-images/enas-cnn-cifar10/ModelConstructor.py
+++ b/examples/v1beta1/trial-images/enas-cnn-cifar10/ModelConstructor.py
@@ -14,41 +14,23 @@
import json
-from keras import backend as K
-from keras.layers import Activation
-from keras.layers import AveragePooling2D
-from keras.layers import BatchNormalization
-from keras.layers import concatenate
-from keras.layers import Conv2D
-from keras.layers import Dense
-from keras.layers import Dropout
-from keras.layers import GlobalAveragePooling2D
-from keras.layers import Input
-from keras.layers import MaxPooling2D
-from keras.layers import ZeroPadding2D
+from keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input
from keras.models import Model
-import numpy as np
-from op_library import concat
-from op_library import conv
-from op_library import dw_conv
-from op_library import reduction
-from op_library import sp_conv
+from op_library import concat, conv, dw_conv, reduction, sp_conv
class ModelConstructor(object):
def __init__(self, arc_json, nn_json):
self.arch = json.loads(arc_json)
nn_config = json.loads(nn_json)
- self.num_layers = nn_config['num_layers']
- self.input_sizes = nn_config['input_sizes']
- self.output_size = nn_config['output_sizes'][-1]
- self.embedding = nn_config['embedding']
+ self.num_layers = nn_config["num_layers"]
+ self.input_sizes = nn_config["input_sizes"]
+ self.output_size = nn_config["output_sizes"][-1]
+ self.embedding = nn_config["embedding"]
def build_model(self):
# a list of the data all layers
all_layers = [0 for _ in range(self.num_layers + 1)]
- # a list of all the dimensions of all layers
- all_dims = [0 for _ in range(self.num_layers + 1)]
# ================= Stacking layers =================
# Input Layer. Layer 0
@@ -56,38 +38,37 @@ def build_model(self):
all_layers[0] = input_layer
# Intermediate Layers. Starting from layer 1.
- for l in range(1, self.num_layers + 1):
+ for l_index in range(1, self.num_layers + 1):
input_layers = list()
- opt = self.arch[l - 1][0]
+ opt = self.arch[l_index - 1][0]
opt_config = self.embedding[str(opt)]
- skip = self.arch[l - 1][1:l+1]
+ skip = self.arch[l_index - 1][1 : l_index + 1]
# set up the connection to the previous layer first
- input_layers.append(all_layers[l - 1])
+ input_layers.append(all_layers[l_index - 1])
# then add skip connections
- for i in range(l - 1):
- if l > 1 and skip[i] == 1:
+ for i in range(l_index - 1):
+ if l_index > 1 and skip[i] == 1:
input_layers.append(all_layers[i])
layer_input = concat(input_layers)
- if opt_config['opt_type'] == 'convolution':
+ if opt_config["opt_type"] == "convolution":
layer_output = conv(layer_input, opt_config)
- if opt_config['opt_type'] == 'separable_convolution':
+ if opt_config["opt_type"] == "separable_convolution":
layer_output = sp_conv(layer_input, opt_config)
- if opt_config['opt_type'] == 'depthwise_convolution':
+ if opt_config["opt_type"] == "depthwise_convolution":
layer_output = dw_conv(layer_input, opt_config)
- elif opt_config['opt_type'] == 'reduction':
+ elif opt_config["opt_type"] == "reduction":
layer_output = reduction(layer_input, opt_config)
- all_layers[l] = layer_output
+ all_layers[l_index] = layer_output
# Final Layer
# Global Average Pooling, then Fully connected with softmax.
avgpooled = GlobalAveragePooling2D()(all_layers[self.num_layers])
dropped = Dropout(0.4)(avgpooled)
- logits = Dense(units=self.output_size,
- activation='softmax')(dropped)
+ logits = Dense(units=self.output_size, activation="softmax")(dropped)
# Encapsulate the model
self.model = Model(inputs=input_layer, outputs=logits)
diff --git a/examples/v1beta1/trial-images/enas-cnn-cifar10/RunTrial.py b/examples/v1beta1/trial-images/enas-cnn-cifar10/RunTrial.py
index f843594674d..f0db82bea79 100644
--- a/examples/v1beta1/trial-images/enas-cnn-cifar10/RunTrial.py
+++ b/examples/v1beta1/trial-images/enas-cnn-cifar10/RunTrial.py
@@ -14,32 +14,50 @@
import argparse
+import tensorflow as tf
from keras.datasets import cifar10
from ModelConstructor import ModelConstructor
from tensorflow import keras
-import tensorflow as tf
-from tensorflow.keras.layers import RandomFlip
-from tensorflow.keras.layers import RandomTranslation
-from tensorflow.keras.layers import Rescaling
+from tensorflow.keras.layers import RandomFlip, RandomTranslation, Rescaling
from tensorflow.keras.utils import to_categorical
if __name__ == "__main__":
- parser = argparse.ArgumentParser(description='TrainingContainer')
- parser.add_argument('--architecture', type=str, default="", metavar='N',
- help='architecture of the neural network')
- parser.add_argument('--nn_config', type=str, default="", metavar='N',
- help='configurations and search space embeddings')
- parser.add_argument('--num_epochs', type=int, default=10, metavar='N',
- help='number of epoches that each child will be trained')
- parser.add_argument('--num_gpus', type=int, default=1, metavar='N',
- help='number of GPU that used for training')
+ parser = argparse.ArgumentParser(description="TrainingContainer")
+ parser.add_argument(
+ "--architecture",
+ type=str,
+ default="",
+ metavar="N",
+ help="architecture of the neural network",
+ )
+ parser.add_argument(
+ "--nn_config",
+ type=str,
+ default="",
+ metavar="N",
+ help="configurations and search space embeddings",
+ )
+ parser.add_argument(
+ "--num_epochs",
+ type=int,
+ default=10,
+ metavar="N",
+ help="number of epoches that each child will be trained",
+ )
+ parser.add_argument(
+ "--num_gpus",
+ type=int,
+ default=1,
+ metavar="N",
+ help="number of GPU that used for training",
+ )
args = parser.parse_args()
- arch = args.architecture.replace("\'", "\"")
+ arch = args.architecture.replace("'", '"')
print(">>> arch received by trial")
print(arch)
- nn_config = args.nn_config.replace("\'", "\"")
+ nn_config = args.nn_config.replace("'", '"')
print(">>> nn_config received by trial")
print(nn_config)
@@ -54,36 +72,40 @@
print("\n>>> Constructing Model...")
constructor = ModelConstructor(arch, nn_config)
- num_physical_gpus = len(tf.config.experimental.list_physical_devices('GPU'))
+ num_physical_gpus = len(tf.config.experimental.list_physical_devices("GPU"))
if 1 <= num_gpus <= num_physical_gpus:
- devices = ["/gpu:"+str(i) for i in range(num_physical_gpus)]
+ devices = ["/gpu:" + str(i) for i in range(num_physical_gpus)]
else:
- num_physical_cpu = len(tf.config.experimental.list_physical_devices('CPU'))
- devices = ["/cpu:"+str(j) for j in range(num_physical_cpu)]
+ num_physical_cpu = len(tf.config.experimental.list_physical_devices("CPU"))
+ devices = ["/cpu:" + str(j) for j in range(num_physical_cpu)]
strategy = tf.distribute.MirroredStrategy(devices)
with strategy.scope():
test_model = constructor.build_model()
test_model.summary()
- test_model.compile(loss=keras.losses.categorical_crossentropy,
- optimizer=keras.optimizers.Adam(learning_rate=1e-3),
- metrics=['accuracy'])
+ test_model.compile(
+ loss=keras.losses.categorical_crossentropy,
+ optimizer=keras.optimizers.Adam(learning_rate=1e-3),
+ metrics=["accuracy"],
+ )
print(">>> Model Constructed Successfully\n")
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
- x_train = x_train.astype('float32')
- x_test = x_test.astype('float32')
+ x_train = x_train.astype("float32")
+ x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
- augmentation = tf.keras.Sequential([
- Rescaling(1./255),
- RandomFlip('horizontal'),
- RandomTranslation(height_factor=0.1, width_factor=0.1),
- ])
+ augmentation = tf.keras.Sequential(
+ [
+ Rescaling(1.0 / 255),
+ RandomFlip("horizontal"),
+ RandomTranslation(height_factor=0.1, width_factor=0.1),
+ ]
+ )
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.map(lambda x, y: (augmentation(x, training=True), y))
@@ -93,11 +115,14 @@
print(">>> Data Loaded. Training starts.")
for e in range(num_epochs):
print("\nTotal Epoch {}/{}".format(e + 1, num_epochs))
- history = test_model.fit(train_dataset,
- steps_per_epoch=int(len(x_train) / 128) + 1,
- epochs=1, verbose=1,
- validation_data=(x_test, y_test))
- print("Training-Accuracy={}".format(history.history['accuracy'][-1]))
- print("Training-Loss={}".format(history.history['loss'][-1]))
- print("Validation-Accuracy={}".format(history.history['val_accuracy'][-1]))
- print("Validation-Loss={}".format(history.history['val_loss'][-1]))
+ history = test_model.fit(
+ train_dataset,
+ steps_per_epoch=int(len(x_train) / 128) + 1,
+ epochs=1,
+ verbose=1,
+ validation_data=(x_test, y_test),
+ )
+ print("Training-Accuracy={}".format(history.history["accuracy"][-1]))
+ print("Training-Loss={}".format(history.history["loss"][-1]))
+ print("Validation-Accuracy={}".format(history.history["val_accuracy"][-1]))
+ print("Validation-Loss={}".format(history.history["val_loss"][-1]))
diff --git a/examples/v1beta1/trial-images/enas-cnn-cifar10/op_library.py b/examples/v1beta1/trial-images/enas-cnn-cifar10/op_library.py
index 7defaf38464..eebcff7db49 100644
--- a/examples/v1beta1/trial-images/enas-cnn-cifar10/op_library.py
+++ b/examples/v1beta1/trial-images/enas-cnn-cifar10/op_library.py
@@ -12,20 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from keras import backend as K
-from keras.layers import Activation
-from keras.layers import AveragePooling2D
-from keras.layers import BatchNormalization
-from keras.layers import concatenate
-from keras.layers import Conv2D
-from keras.layers import Dense
-from keras.layers import DepthwiseConv2D
-from keras.layers import GlobalAveragePooling2D
-from keras.layers import Input
-from keras.layers import MaxPooling2D
-from keras.layers import SeparableConv2D
-from keras.layers import ZeroPadding2D
import numpy as np
+from keras import backend as K
+from keras.layers import (
+ Activation,
+ AveragePooling2D,
+ BatchNormalization,
+ Conv2D,
+ DepthwiseConv2D,
+ MaxPooling2D,
+ SeparableConv2D,
+ ZeroPadding2D,
+ concatenate,
+)
def concat(inputs):
@@ -46,10 +45,13 @@ def concat(inputs):
diff = max_dim - total_dim[i][1]
half_diff = int(diff / 2)
if diff % 2 == 0:
- padded_input[i] = ZeroPadding2D(padding=(half_diff, half_diff))(inputs[i])
+ padded_input[i] = ZeroPadding2D(padding=(half_diff, half_diff))(
+ inputs[i]
+ )
else:
- padded_input[i] = ZeroPadding2D(padding=((half_diff, half_diff + 1),
- (half_diff, half_diff + 1)))(inputs[i])
+ padded_input[i] = ZeroPadding2D(
+ padding=((half_diff, half_diff + 1), (half_diff, half_diff + 1))
+ )(inputs[i])
else:
padded_input[i] = inputs[i]
@@ -59,21 +61,22 @@ def concat(inputs):
def conv(x, config):
parameters = {
- "num_filter": 64,
- "filter_size": 3,
- "stride": 1,
+ "num_filter": 64,
+ "filter_size": 3,
+ "stride": 1,
}
for k in parameters.keys():
if k in config:
parameters[k] = int(config[k])
- activated = Activation('relu')(x)
+ activated = Activation("relu")(x)
conved = Conv2D(
- filters=parameters['num_filter'],
- kernel_size=parameters['filter_size'],
- strides=parameters['stride'],
- padding='same')(activated)
+ filters=parameters["num_filter"],
+ kernel_size=parameters["filter_size"],
+ strides=parameters["stride"],
+ padding="same",
+ )(activated)
result = BatchNormalization()(conved)
@@ -82,9 +85,9 @@ def conv(x, config):
def sp_conv(x, config):
parameters = {
- "num_filter": 64,
- "filter_size": 3,
- "stride": 1,
+ "num_filter": 64,
+ "filter_size": 3,
+ "stride": 1,
"depth_multiplier": 1,
}
@@ -92,36 +95,39 @@ def sp_conv(x, config):
if k in config:
parameters[k] = int(config[k])
- activated = Activation('relu')(x)
+ activated = Activation("relu")(x)
conved = SeparableConv2D(
- filters=parameters['num_filter'],
- kernel_size=parameters['filter_size'],
- strides=parameters['stride'],
- depth_multiplier=parameters['depth_multiplier'],
- padding='same')(activated)
+ filters=parameters["num_filter"],
+ kernel_size=parameters["filter_size"],
+ strides=parameters["stride"],
+ depth_multiplier=parameters["depth_multiplier"],
+ padding="same",
+ )(activated)
result = BatchNormalization()(conved)
return result
+
def dw_conv(x, config):
parameters = {
- "filter_size": 3,
- "stride": 1,
+ "filter_size": 3,
+ "stride": 1,
"depth_multiplier": 1,
}
for k in parameters.keys():
if k in config:
parameters[k] = int(config[k])
- activated = Activation('relu')(x)
+ activated = Activation("relu")(x)
conved = DepthwiseConv2D(
- kernel_size=parameters['filter_size'],
- strides=parameters['stride'],
- depth_multiplier=parameters['depth_multiplier'],
- padding='same')(activated)
+ kernel_size=parameters["filter_size"],
+ strides=parameters["stride"],
+ depth_multiplier=parameters["depth_multiplier"],
+ padding="same",
+ )(activated)
result = BatchNormalization()(conved)
@@ -134,31 +140,32 @@ def reduction(x, config):
# such situation is very likely to appear though
dim = K.int_shape(x)
if dim[1] == 1 or dim[2] == 1:
- print("WARNING: One or more dimensions of the input of the reduction layer is 1. It cannot be further reduced. A identity layer will be used instead.")
+ print(
+ "WARNING: One or more dimensions of the input of the reduction layer is 1. "
+ "It cannot be further reduced. A identity layer will be used instead."
+ )
return x
parameters = {
- 'reduction_type': "max_pooling",
- 'pool_size': 2,
- 'stride': None,
+ "reduction_type": "max_pooling",
+ "pool_size": 2,
+ "stride": None,
}
- if 'reduction_type' in config:
- parameters['reduction_type'] = config['reduction_type']
- if 'pool_size' in config:
- parameters['pool_size'] = int(config['pool_size'])
- if 'stride' in config:
- parameters['stride'] = int(config['stride'])
+ if "reduction_type" in config:
+ parameters["reduction_type"] = config["reduction_type"]
+ if "pool_size" in config:
+ parameters["pool_size"] = int(config["pool_size"])
+ if "stride" in config:
+ parameters["stride"] = int(config["stride"])
- if parameters['reduction_type'] == 'max_pooling':
+ if parameters["reduction_type"] == "max_pooling":
result = MaxPooling2D(
- pool_size=parameters['pool_size'],
- strides=parameters['stride']
+ pool_size=parameters["pool_size"], strides=parameters["stride"]
)(x)
- elif parameters['reduction_type'] == 'avg_pooling':
+ elif parameters["reduction_type"] == "avg_pooling":
result = AveragePooling2D(
- pool_size=parameters['pool_size'],
- strides=parameters['stride']
+ pool_size=parameters["pool_size"], strides=parameters["stride"]
)(x)
return result
diff --git a/examples/v1beta1/trial-images/pytorch-mnist/mnist.py b/examples/v1beta1/trial-images/pytorch-mnist/mnist.py
index 95611f5953d..7ecc911cbb4 100644
--- a/examples/v1beta1/trial-images/pytorch-mnist/mnist.py
+++ b/examples/v1beta1/trial-images/pytorch-mnist/mnist.py
@@ -24,8 +24,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-from torchvision import datasets
-from torchvision import transforms
+from torchvision import datasets, transforms
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
@@ -35,7 +34,7 @@ def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
- self.fc1 = nn.Linear(4*4*50, 500)
+ self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
@@ -43,7 +42,7 @@ def forward(self, x):
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
- x = x.view(-1, 4*4*50)
+ x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
@@ -60,10 +59,14 @@ def train(args, model, device, train_loader, optimizer, epoch):
optimizer.step()
if batch_idx % args.log_interval == 0:
msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
- epoch, batch_idx * len(data), len(train_loader.dataset),
- 100. * batch_idx / len(train_loader), loss.item())
+ epoch,
+ batch_idx * len(data),
+ len(train_loader.dataset),
+ 100.0 * batch_idx / len(train_loader),
+ loss.item(),
+ )
logging.info(msg)
- niter = epoch * len(train_loader) + batch_idx
+ niter = epoch * len(train_loader) + batch_idx # noqa: F841
def test(args, model, device, test_loader, epoch, hpt):
@@ -74,24 +77,31 @@ def test(args, model, device, test_loader, epoch, hpt):
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
- test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
- pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
+ test_loss += F.nll_loss(
+ output, target, reduction="sum"
+ ).item() # sum up batch loss
+ pred = output.max(1, keepdim=True)[
+ 1
+ ] # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
test_accuracy = float(correct) / len(test_loader.dataset)
- logging.info("{{metricName: accuracy, metricValue: {:.4f}}};{{metricName: loss, metricValue: {:.4f}}}\n".format(
- test_accuracy, test_loss))
+
+ logging.info(
+ "{{metricName: accuracy, metricValue: {:.4f}}};"
+ "{{metricName: loss, metricValue: {:.4f}}}\n".format(test_accuracy, test_loss)
+ )
if args.logger == "hypertune":
hpt.report_hyperparameter_tuning_metric(
- hyperparameter_metric_tag='loss',
- metric_value=test_loss,
- global_step=epoch)
+ hyperparameter_metric_tag="loss", metric_value=test_loss, global_step=epoch
+ )
hpt.report_hyperparameter_tuning_metric(
- hyperparameter_metric_tag='accuracy',
+ hyperparameter_metric_tag="accuracy",
metric_value=test_accuracy,
- global_step=epoch)
+ global_step=epoch,
+ )
def should_distribute():
@@ -105,33 +115,82 @@ def is_distributed():
def main():
# Training settings
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
- parser.add_argument("--batch-size", type=int, default=64, metavar="N",
- help="input batch size for training (default: 64)")
- parser.add_argument("--test-batch-size", type=int, default=1000, metavar="N",
- help="input batch size for testing (default: 1000)")
- parser.add_argument("--epochs", type=int, default=10, metavar="N",
- help="number of epochs to train (default: 10)")
- parser.add_argument("--lr", type=float, default=0.01, metavar="LR",
- help="learning rate (default: 0.01)")
- parser.add_argument("--momentum", type=float, default=0.5, metavar="M",
- help="SGD momentum (default: 0.5)")
- parser.add_argument("--no-cuda", action="store_true", default=False,
- help="disables CUDA training")
- parser.add_argument("--seed", type=int, default=1, metavar="S",
- help="random seed (default: 1)")
- parser.add_argument("--log-interval", type=int, default=10, metavar="N",
- help="how many batches to wait before logging training status")
- parser.add_argument("--log-path", type=str, default="",
- help="Path to save logs. Print to StdOut if log-path is not set")
- parser.add_argument("--save-model", action="store_true", default=False,
- help="For Saving the current Model")
- parser.add_argument("--logger", type=str, choices=["standard", "hypertune"],
- help="Logger", default="standard")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10,
+ metavar="N",
+ help="number of epochs to train (default: 10)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=0.01,
+ metavar="LR",
+ help="learning rate (default: 0.01)",
+ )
+ parser.add_argument(
+ "--momentum",
+ type=float,
+ default=0.5,
+ metavar="M",
+ help="SGD momentum (default: 0.5)",
+ )
+ parser.add_argument(
+ "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+ )
+ parser.add_argument(
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=10,
+ metavar="N",
+ help="how many batches to wait before logging training status",
+ )
+ parser.add_argument(
+ "--log-path",
+ type=str,
+ default="",
+ help="Path to save logs. Print to StdOut if log-path is not set",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=False,
+ help="For Saving the current Model",
+ )
+ parser.add_argument(
+ "--logger",
+ type=str,
+ choices=["standard", "hypertune"],
+ help="Logger",
+ default="standard",
+ )
if dist.is_available():
- parser.add_argument("--backend", type=str, help="Distributed backend",
- choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI],
- default=dist.Backend.GLOO)
+ parser.add_argument(
+ "--backend",
+ type=str,
+ help="Distributed backend",
+ choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI],
+ default=dist.Backend.GLOO,
+ )
args = parser.parse_args()
# Use this format (%Y-%m-%dT%H:%M:%SZ) to record timestamp of the metrics.
@@ -140,16 +199,18 @@ def main():
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
datefmt="%Y-%m-%dT%H:%M:%SZ",
- level=logging.DEBUG)
+ level=logging.DEBUG,
+ )
else:
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
datefmt="%Y-%m-%dT%H:%M:%SZ",
level=logging.DEBUG,
- filename=args.log_path)
+ filename=args.log_path,
+ )
if args.logger == "hypertune" and args.log_path != "":
- os.environ['CLOUD_ML_HP_METRIC_FILE'] = args.log_path
+ os.environ["CLOUD_ML_HP_METRIC_FILE"] = args.log_path
# For JSON logging
hpt = hypertune.HyperTune()
@@ -169,27 +230,34 @@ def main():
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
- datasets.FashionMNIST("./data",
- train=True,
- download=True,
- transform=transforms.Compose([
- transforms.ToTensor()
- ])),
- batch_size=args.batch_size, shuffle=True, **kwargs)
+ datasets.FashionMNIST(
+ "./data",
+ train=True,
+ download=True,
+ transform=transforms.Compose([transforms.ToTensor()]),
+ ),
+ batch_size=args.batch_size,
+ shuffle=True,
+ **kwargs,
+ )
test_loader = torch.utils.data.DataLoader(
- datasets.FashionMNIST("./data",
- train=False,
- transform=transforms.Compose([
- transforms.ToTensor()
- ])),
- batch_size=args.test_batch_size, shuffle=False, **kwargs)
+ datasets.FashionMNIST(
+ "./data", train=False, transform=transforms.Compose([transforms.ToTensor()])
+ ),
+ batch_size=args.test_batch_size,
+ shuffle=False,
+ **kwargs,
+ )
model = Net().to(device)
if is_distributed():
- Distributor = nn.parallel.DistributedDataParallel if use_cuda \
+ Distributor = (
+ nn.parallel.DistributedDataParallel
+ if use_cuda
else nn.parallel.DistributedDataParallelCPU
+ )
model = Distributor(model)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
@@ -198,7 +266,7 @@ def main():
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader, epoch, hpt)
- if (args.save_model):
+ if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
diff --git a/examples/v1beta1/trial-images/tf-mnist-with-summaries/mnist.py b/examples/v1beta1/trial-images/tf-mnist-with-summaries/mnist.py
index fea1ed575f4..ca65ff5bbe6 100644
--- a/examples/v1beta1/trial-images/tf-mnist-with-summaries/mnist.py
+++ b/examples/v1beta1/trial-images/tf-mnist-with-summaries/mnist.py
@@ -17,17 +17,15 @@
import tensorflow as tf
from tensorflow.keras import Model
-from tensorflow.keras.layers import Conv2D
-from tensorflow.keras.layers import Dense
-from tensorflow.keras.layers import Flatten
+from tensorflow.keras.layers import Conv2D, Dense, Flatten
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
- self.conv1 = Conv2D(32, 3, activation='relu')
+ self.conv1 = Conv2D(32, 3, activation="relu")
self.flatten = Flatten()
- self.d1 = Dense(128, activation='relu')
+ self.d1 = Dense(128, activation="relu")
self.d2 = Dense(10)
def call(self, x):
@@ -37,7 +35,17 @@ def call(self, x):
return self.d2(x)
-def train_step(args, model, optimizer, train_ds, epoch, loss_object, train_summary_writer, train_loss, train_accuracy):
+def train_step(
+ args,
+ model,
+ optimizer,
+ train_ds,
+ epoch,
+ loss_object,
+ train_summary_writer,
+ train_loss,
+ train_accuracy,
+):
for step, (images, labels) in enumerate(train_ds):
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
@@ -51,18 +59,25 @@ def train_step(args, model, optimizer, train_ds, epoch, loss_object, train_summa
train_accuracy(labels, predictions)
if step % args.log_interval == 0:
- print("Train Epoch: {} [{}/60000 ({:.0f}%)]\tloss={:.4f}, accuracy={:.4f}".format(
- epoch + 1, step * args.batch_size, 100. * step * args.batch_size / 60000,
- train_loss.result(), train_accuracy.result() * 100)
+ print(
+ "Train Epoch: {} [{}/60000 ({:.0f}%)]\tloss={:.4f}, accuracy={:.4f}".format(
+ epoch + 1,
+ step * args.batch_size,
+ 100.0 * step * args.batch_size / 60000,
+ train_loss.result(),
+ train_accuracy.result() * 100,
+ )
)
with train_summary_writer.as_default():
- tf.summary.scalar('loss', train_loss.result(), step=epoch)
- tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)
+ tf.summary.scalar("loss", train_loss.result(), step=epoch)
+ tf.summary.scalar("accuracy", train_accuracy.result(), step=epoch)
-def test_step(model, test_ds, epoch, loss_object, test_summary_writer, test_loss, test_accuracy):
- for (images, labels) in test_ds:
+def test_step(
+ model, test_ds, epoch, loss_object, test_summary_writer, test_loss, test_accuracy
+):
+ for images, labels in test_ds:
# training=False is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
predictions = model(images, training=False)
@@ -72,30 +87,53 @@ def test_step(model, test_ds, epoch, loss_object, test_summary_writer, test_loss
test_accuracy(labels, predictions)
with test_summary_writer.as_default():
- tf.summary.scalar('loss', test_loss.result(), step=epoch)
- tf.summary.scalar('accuracy', test_accuracy.result(), step=epoch)
+ tf.summary.scalar("loss", test_loss.result(), step=epoch)
+ tf.summary.scalar("accuracy", test_accuracy.result(), step=epoch)
- print("Test Loss: {:.4f}, Test Accuracy: {:.4f}\n".format(
- test_loss.result(), test_accuracy.result() * 100)
+ print(
+ "Test Loss: {:.4f}, Test Accuracy: {:.4f}\n".format(
+ test_loss.result(), test_accuracy.result() * 100
+ )
)
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('--batch-size', type=int, default=64,
- help='input batch size for training (default: 64)')
- parser.add_argument('--learning-rate', type=float, default=0.001,
- help='learning rate (default: 0.001)')
- parser.add_argument("--epochs", type=int, default=10, metavar="N",
- help="number of epochs to train (default: 10)")
- parser.add_argument("--log-interval", type=int, default=100, metavar="N",
- help="how many batches to wait before logging training status (default: 100)")
parser.add_argument(
- '--log-path',
+ "--batch-size",
+ type=int,
+ default=64,
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--learning-rate",
+ type=float,
+ default=0.001,
+ help="learning rate (default: 0.001)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10,
+ metavar="N",
+ help="number of epochs to train (default: 10)",
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=100,
+ metavar="N",
+ help="how many batches to wait before logging training status (default: 100)",
+ )
+ parser.add_argument(
+ "--log-path",
type=str,
- default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
- 'tensorflow/mnist/logs/mnist_with_summaries'),
- help='Summaries log PATH')
+ default=os.path.join(
+ os.getenv("TEST_TMPDIR", "/tmp"),
+ "tensorflow/mnist/logs/mnist_with_summaries",
+ ),
+ help="Summaries log PATH",
+ )
args = parser.parse_args()
# Setup dataset
@@ -105,12 +143,18 @@ def main():
# Add a channels dimension
x_train = x_train[..., tf.newaxis].astype("float32")
x_test = x_test[..., tf.newaxis].astype("float32")
- train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(args.batch_size)
- test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(args.batch_size)
+ train_ds = (
+ tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ .shuffle(10000)
+ .batch(args.batch_size)
+ )
+ test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(
+ args.batch_size
+ )
# Setup tensorflow summaries
- train_log_dir = os.path.join(args.log_path, 'train')
- test_log_dir = os.path.join(args.log_path, 'test')
+ train_log_dir = os.path.join(args.log_path, "train")
+ test_log_dir = os.path.join(args.log_path, "test")
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)
@@ -119,20 +163,37 @@ def main():
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
- train_loss = tf.keras.metrics.Mean(name='train_loss')
- train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
+ train_loss = tf.keras.metrics.Mean(name="train_loss")
+ train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")
- test_loss = tf.keras.metrics.Mean(name='test_loss')
- test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
+ test_loss = tf.keras.metrics.Mean(name="test_loss")
+ test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")
for epoch in range(args.epochs):
# Reset the metrics at the start of the next epoch
train_summary_writer.flush()
test_summary_writer.flush()
- train_step(args, model, optimizer, train_ds, epoch, loss_object, train_summary_writer,
- train_loss, train_accuracy)
- test_step(model, test_ds, epoch, loss_object, test_summary_writer, test_loss, test_accuracy)
+ train_step(
+ args,
+ model,
+ optimizer,
+ train_ds,
+ epoch,
+ loss_object,
+ train_summary_writer,
+ train_loss,
+ train_accuracy,
+ )
+ test_step(
+ model,
+ test_ds,
+ epoch,
+ loss_object,
+ test_summary_writer,
+ test_loss,
+ test_accuracy,
+ )
if __name__ == "__main__":
diff --git a/go.mod b/go.mod
index 2a32eb38b00..f972490f407 100644
--- a/go.mod
+++ b/go.mod
@@ -69,7 +69,7 @@ require (
github.com/dimchansky/utfbom v1.1.1 // indirect
github.com/docker/cli v24.0.0+incompatible // indirect
github.com/docker/distribution v2.8.2+incompatible // indirect
- github.com/docker/docker v24.0.9+incompatible // indirect
+ github.com/docker/docker v26.1.5+incompatible // indirect
github.com/docker/docker-credential-helpers v0.7.0 // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/evanphx/json-patch v5.6.0+incompatible // indirect
diff --git a/go.sum b/go.sum
index f52bcc63ccf..d5ab6ae2cd1 100644
--- a/go.sum
+++ b/go.sum
@@ -180,8 +180,8 @@ github.com/docker/cli v24.0.0+incompatible h1:0+1VshNwBQzQAx9lOl+OYCTCEAD8fKs/qe
github.com/docker/cli v24.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8=
github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
-github.com/docker/docker v24.0.9+incompatible h1:HPGzNmwfLZWdxHqK9/II92pyi1EpYKsAqcl4G0Of9v0=
-github.com/docker/docker v24.0.9+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
+github.com/docker/docker v26.1.5+incompatible h1:NEAxTwEjxV6VbBMBoGG3zPqbiJosIApZjxlbrG9q3/g=
+github.com/docker/docker v26.1.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/docker-credential-helpers v0.7.0 h1:xtCHsjxogADNZcdv1pKUHXryefjlVRqWqIhk/uXJp0A=
github.com/docker/docker-credential-helpers v0.7.0/go.mod h1:rETQfLdHNT3foU5kuNkFR1R1V12OJRRO5lzt2D1b5X0=
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
diff --git a/hack/gen-python-sdk/post_gen.py b/hack/gen-python-sdk/post_gen.py
index b7bda1f251f..70eab3a2595 100644
--- a/hack/gen-python-sdk/post_gen.py
+++ b/hack/gen-python-sdk/post_gen.py
@@ -18,14 +18,14 @@
IGNORE_LINES = [
"from kubeflow.katib.models.v1_unstructured_unstructured import V1UnstructuredUnstructured",
- "from kubeflow.katib.models.v1_time import V1Time"
+ "from kubeflow.katib.models.v1_time import V1Time",
]
def _rewrite_helper(input_file, output_file, rewrite_rules):
rules = rewrite_rules or []
lines = []
- with open(input_file, 'r') as f:
+ with open(input_file, "r") as f:
while True:
line = f.readline()
if not line:
@@ -34,11 +34,11 @@ def _rewrite_helper(input_file, output_file, rewrite_rules):
for rule in rules:
line = rule(line)
# Remove ignored lines.
- if not any(l in line for l in IGNORE_LINES):
+ if not any(li in line for li in IGNORE_LINES):
lines.append(line)
# Add Katib APIs to the init file.
- if (output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py"):
+ if output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py":
lines.append("# Import Katib API client.\n")
lines.append("from kubeflow.katib.api.katib_client import KatibClient\n")
lines.append("# Import Katib report metrics functions\n")
@@ -46,62 +46,79 @@ def _rewrite_helper(input_file, output_file, rewrite_rules):
lines.append("# Import Katib helper functions.\n")
lines.append("import kubeflow.katib.api.search as search\n")
lines.append("# Import Katib helper constants.\n")
- lines.append("from kubeflow.katib.constants.constants import BASE_IMAGE_TENSORFLOW\n")
- lines.append("from kubeflow.katib.constants.constants import BASE_IMAGE_TENSORFLOW_GPU\n")
- lines.append("from kubeflow.katib.constants.constants import BASE_IMAGE_PYTORCH\n")
- lines.append("from kubeflow.katib.constants.constants import BASE_IMAGE_MXNET\n")
+ lines.append(
+ "from kubeflow.katib.constants.constants import BASE_IMAGE_TENSORFLOW\n"
+ )
+ lines.append(
+ "from kubeflow.katib.constants.constants import BASE_IMAGE_TENSORFLOW_GPU\n"
+ )
+ lines.append(
+ "from kubeflow.katib.constants.constants import BASE_IMAGE_PYTORCH\n"
+ )
+ lines.append(
+ "from kubeflow.katib.constants.constants import BASE_IMAGE_MXNET\n"
+ )
# Add Kubernetes models to proper deserialization of Katib models.
- if (output_file == "sdk/python/v1beta1/kubeflow/katib/models/__init__.py"):
+ if output_file == "sdk/python/v1beta1/kubeflow/katib/models/__init__.py":
lines.append("\n")
lines.append("# Import Kubernetes models.\n")
lines.append("from kubernetes.client import *\n")
- with open(output_file, 'w') as f:
+ with open(output_file, "w") as f:
f.writelines(lines)
-def update_python_sdk(src, dest, versions=('v1beta1')):
+def update_python_sdk(src, dest, versions=("v1beta1")):
# tiny transformers to refine generated codes
rewrite_rules = [
# Models rules.
- lambda l: l.replace('import katib', 'import kubeflow.katib'),
- lambda l: l.replace('from katib', 'from kubeflow.katib'),
+ lambda line: line.replace("import katib", "import kubeflow.katib"),
+ lambda line: line.replace("from katib", "from kubeflow.katib"),
# For the api_client.py.
- lambda l: l.replace('klass = getattr(katib.models, klass)', 'klass = getattr(kubeflow.katib.models, klass)'),
+ lambda line: line.replace(
+ "klass = getattr(katib.models, klass)",
+ "klass = getattr(kubeflow.katib.models, klass)",
+ ),
# Doc rules.
- lambda l: l.replace('[**datetime**](V1Time.md)', '**datetime**'),
- lambda l: l.replace('[**object**](V1UnstructuredUnstructured.md)', '**object**'),
-
- lambda l: l.replace('[**V1Container**](V1Container.md)',
- '[**V1Container**](https://github.com/kubernetes-client/'
- 'python/blob/master/kubernetes/docs/V1Container.md)'),
-
- lambda l: l.replace('[**V1ObjectMeta**](V1ObjectMeta.md)',
- '[**V1ObjectMeta**](https://github.com/kubernetes-client/'
- 'python/blob/master/kubernetes/docs/V1ObjectMeta.md)'),
-
- lambda l: l.replace('[**V1ListMeta**](V1ListMeta.md)',
- '[**V1ListMeta**](https://github.com/kubernetes-client/'
- 'python/blob/master/kubernetes/docs/V1ListMeta.md)'),
-
- lambda l: l.replace('[**V1HTTPGetAction**](V1HTTPGetAction.md)',
- '[**V1HTTPGetAction**](https://github.com/kubernetes-client/'
- 'python/blob/master/kubernetes/docs/V1HTTPGetAction.md)')
+ lambda line: line.replace("[**datetime**](V1Time.md)", "**datetime**"),
+ lambda line: line.replace(
+ "[**object**](V1UnstructuredUnstructured.md)", "**object**"
+ ),
+ lambda line: line.replace(
+ "[**V1Container**](V1Container.md)",
+ "[**V1Container**](https://github.com/kubernetes-client/"
+ "python/blob/master/kubernetes/docs/V1Container.md)",
+ ),
+ lambda line: line.replace(
+ "[**V1ObjectMeta**](V1ObjectMeta.md)",
+ "[**V1ObjectMeta**](https://github.com/kubernetes-client/"
+ "python/blob/master/kubernetes/docs/V1ObjectMeta.md)",
+ ),
+ lambda line: line.replace(
+ "[**V1ListMeta**](V1ListMeta.md)",
+ "[**V1ListMeta**](https://github.com/kubernetes-client/"
+ "python/blob/master/kubernetes/docs/V1ListMeta.md)",
+ ),
+ lambda line: line.replace(
+ "[**V1HTTPGetAction**](V1HTTPGetAction.md)",
+ "[**V1HTTPGetAction**](https://github.com/kubernetes-client/"
+ "python/blob/master/kubernetes/docs/V1HTTPGetAction.md)",
+ ),
]
# TODO (andreyvelich): Currently test can't be generated properly.
src_dirs = [
- os.path.join(src, 'katib'),
- os.path.join(src, 'katib', 'models'),
+ os.path.join(src, "katib"),
+ os.path.join(src, "katib", "models"),
# os.path.join(src, 'test'),
- os.path.join(src, 'docs')
+ os.path.join(src, "docs"),
]
dest_dirs = [
- os.path.join(dest, 'kubeflow', 'katib'),
- os.path.join(dest, 'kubeflow', 'katib', 'models'),
+ os.path.join(dest, "kubeflow", "katib"),
+ os.path.join(dest, "kubeflow", "katib", "models"),
# os.path.join(dest, 'test'),
- os.path.join(dest, 'docs')
+ os.path.join(dest, "docs"),
]
for src_dir, dest_dir in zip(src_dirs, dest_dirs):
@@ -128,13 +145,13 @@ def update_python_sdk(src, dest, versions=('v1beta1')):
update_buffer = []
# Get data from generated doc
- with open(os.path.join(src, 'README.md'), 'r') as src_f:
+ with open(os.path.join(src, "README.md"), "r") as src_f:
anchor = 0
for line in src_f.readlines():
- if line.startswith('## Documentation For Models'):
+ if line.startswith("## Documentation For Models"):
if anchor == 0:
anchor = 1
- elif line.startswith('##') and anchor == 1:
+ elif line.startswith("##") and anchor == 1:
anchor = 2
if anchor == 0:
continue
@@ -148,24 +165,24 @@ def update_python_sdk(src, dest, versions=('v1beta1')):
update_buffer = update_buffer[:-1]
# Update README with new models
- with open(os.path.join(dest, 'README.md'), 'r') as dest_f:
+ with open(os.path.join(dest, "README.md"), "r") as dest_f:
anchor = 0
for line in dest_f.readlines():
- if line.startswith('## Documentation For Models'):
+ if line.startswith("## Documentation For Models"):
if anchor == 0:
buffer.extend(update_buffer)
anchor = 1
- elif line.startswith('##') and anchor == 1:
+ elif line.startswith("##") and anchor == 1:
anchor = 2
if anchor == 1:
continue
buffer.append(line)
- with open(os.path.join(dest, 'README.md'), 'w') as dest_f:
+ with open(os.path.join(dest, "README.md"), "w") as dest_f:
dest_f.writelines(buffer)
# Clear working dictionary
shutil.rmtree(src)
-if __name__ == '__main__':
+if __name__ == "__main__":
update_python_sdk(src=sys.argv[1], dest=sys.argv[2])
diff --git a/pkg/apis/controller/experiments/v1beta1/experiment_types.go b/pkg/apis/controller/experiments/v1beta1/experiment_types.go
index 37498f24442..2d8e1d6968b 100644
--- a/pkg/apis/controller/experiments/v1beta1/experiment_types.go
+++ b/pkg/apis/controller/experiments/v1beta1/experiment_types.go
@@ -207,12 +207,23 @@ const (
)
type FeasibleSpace struct {
- Max string `json:"max,omitempty"`
- Min string `json:"min,omitempty"`
- List []string `json:"list,omitempty"`
- Step string `json:"step,omitempty"`
+ Max string `json:"max,omitempty"`
+ Min string `json:"min,omitempty"`
+ List []string `json:"list,omitempty"`
+ Step string `json:"step,omitempty"`
+ Distribution Distribution `json:"distribution,omitempty"`
}
+type Distribution string
+
+const (
+ DistributionUniform Distribution = "uniform"
+ DistributionLogUniform Distribution = "logUniform"
+ DistributionNormal Distribution = "normal"
+ DistributionLogNormal Distribution = "logNormal"
+ DistributionUnknown Distribution = "unknown"
+)
+
// TrialTemplate describes structure of trial template
type TrialTemplate struct {
// Retain indicates that trial resources must be not cleanup
diff --git a/pkg/apis/manager/v1beta1/api.pb.go b/pkg/apis/manager/v1beta1/api.pb.go
index 8ca56b39563..862db115141 100644
--- a/pkg/apis/manager/v1beta1/api.pb.go
+++ b/pkg/apis/manager/v1beta1/api.pb.go
@@ -85,10 +85,11 @@ func (ParameterType) EnumDescriptor() ([]byte, []int) {
type Distribution int32
const (
- Distribution_UNIFORM Distribution = 0
- Distribution_LOG_UNIFORM Distribution = 1
- Distribution_NORMAL Distribution = 2
- Distribution_LOG_NORMAL Distribution = 3
+ Distribution_UNIFORM Distribution = 0
+ Distribution_LOG_UNIFORM Distribution = 1
+ Distribution_NORMAL Distribution = 2
+ Distribution_LOG_NORMAL Distribution = 3
+ Distribution_DISTRIBUTION_UNKNOWN Distribution = 4
)
// Enum value maps for Distribution.
@@ -98,12 +99,14 @@ var (
1: "LOG_UNIFORM",
2: "NORMAL",
3: "LOG_NORMAL",
+ 4: "DISTRIBUTION_UNKNOWN",
}
Distribution_value = map[string]int32{
- "UNIFORM": 0,
- "LOG_UNIFORM": 1,
- "NORMAL": 2,
- "LOG_NORMAL": 3,
+ "UNIFORM": 0,
+ "LOG_UNIFORM": 1,
+ "NORMAL": 2,
+ "LOG_NORMAL": 3,
+ "DISTRIBUTION_UNKNOWN": 4,
}
)
@@ -3034,82 +3037,83 @@ var file_api_proto_rawDesc = []byte{
0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x44, 0x4f, 0x55, 0x42, 0x4c, 0x45, 0x10, 0x01, 0x12, 0x07,
0x0a, 0x03, 0x49, 0x4e, 0x54, 0x10, 0x02, 0x12, 0x0c, 0x0a, 0x08, 0x44, 0x49, 0x53, 0x43, 0x52,
0x45, 0x54, 0x45, 0x10, 0x03, 0x12, 0x0f, 0x0a, 0x0b, 0x43, 0x41, 0x54, 0x45, 0x47, 0x4f, 0x52,
- 0x49, 0x43, 0x41, 0x4c, 0x10, 0x04, 0x2a, 0x48, 0x0a, 0x0c, 0x44, 0x69, 0x73, 0x74, 0x72, 0x69,
+ 0x49, 0x43, 0x41, 0x4c, 0x10, 0x04, 0x2a, 0x62, 0x0a, 0x0c, 0x44, 0x69, 0x73, 0x74, 0x72, 0x69,
0x62, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x49, 0x46, 0x4f, 0x52,
0x4d, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x4c, 0x4f, 0x47, 0x5f, 0x55, 0x4e, 0x49, 0x46, 0x4f,
0x52, 0x4d, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x4e, 0x4f, 0x52, 0x4d, 0x41, 0x4c, 0x10, 0x02,
0x12, 0x0e, 0x0a, 0x0a, 0x4c, 0x4f, 0x47, 0x5f, 0x4e, 0x4f, 0x52, 0x4d, 0x41, 0x4c, 0x10, 0x03,
- 0x2a, 0x38, 0x0a, 0x0d, 0x4f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x69, 0x76, 0x65, 0x54, 0x79, 0x70,
- 0x65, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0c,
- 0x0a, 0x08, 0x4d, 0x49, 0x4e, 0x49, 0x4d, 0x49, 0x5a, 0x45, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08,
- 0x4d, 0x41, 0x58, 0x49, 0x4d, 0x49, 0x5a, 0x45, 0x10, 0x02, 0x2a, 0x4a, 0x0a, 0x0e, 0x43, 0x6f,
- 0x6d, 0x70, 0x61, 0x72, 0x69, 0x73, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x12,
- 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x5f, 0x43, 0x4f, 0x4d, 0x50, 0x41, 0x52, 0x49, 0x53,
- 0x4f, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x51, 0x55, 0x41, 0x4c, 0x10, 0x01, 0x12,
- 0x08, 0x0a, 0x04, 0x4c, 0x45, 0x53, 0x53, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x52, 0x45,
- 0x41, 0x54, 0x45, 0x52, 0x10, 0x03, 0x32, 0xc6, 0x02, 0x0a, 0x09, 0x44, 0x42, 0x4d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x72, 0x12, 0x6a, 0x0a, 0x14, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, 0x62,
- 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x12, 0x29, 0x2e, 0x61,
- 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x52, 0x65, 0x70, 0x6f,
- 0x72, 0x74, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67,
- 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x27, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31,
- 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, 0x62, 0x73,
- 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79,
- 0x12, 0x61, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69,
- 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x12, 0x26, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62,
- 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74,
- 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e,
- 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47, 0x65, 0x74,
- 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x52, 0x65,
- 0x70, 0x6c, 0x79, 0x12, 0x6a, 0x0a, 0x14, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x4f, 0x62, 0x73,
- 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x12, 0x29, 0x2e, 0x61, 0x70,
- 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74,
- 0x65, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x52,
- 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x27, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e,
- 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x4f, 0x62, 0x73, 0x65,
- 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x32,
- 0xe1, 0x01, 0x0a, 0x0a, 0x53, 0x75, 0x67, 0x67, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x58,
- 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x53, 0x75, 0x67, 0x67, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x73,
+ 0x12, 0x18, 0x0a, 0x14, 0x44, 0x49, 0x53, 0x54, 0x52, 0x49, 0x42, 0x55, 0x54, 0x49, 0x4f, 0x4e,
+ 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x04, 0x2a, 0x38, 0x0a, 0x0d, 0x4f, 0x62,
+ 0x6a, 0x65, 0x63, 0x74, 0x69, 0x76, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, 0x07, 0x55,
+ 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0c, 0x0a, 0x08, 0x4d, 0x49, 0x4e, 0x49,
+ 0x4d, 0x49, 0x5a, 0x45, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08, 0x4d, 0x41, 0x58, 0x49, 0x4d, 0x49,
+ 0x5a, 0x45, 0x10, 0x02, 0x2a, 0x4a, 0x0a, 0x0e, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x72, 0x69, 0x73,
+ 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x12, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57,
+ 0x4e, 0x5f, 0x43, 0x4f, 0x4d, 0x50, 0x41, 0x52, 0x49, 0x53, 0x4f, 0x4e, 0x10, 0x00, 0x12, 0x09,
+ 0x0a, 0x05, 0x45, 0x51, 0x55, 0x41, 0x4c, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x4c, 0x45, 0x53,
+ 0x53, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x52, 0x45, 0x41, 0x54, 0x45, 0x52, 0x10, 0x03,
+ 0x32, 0xc6, 0x02, 0x0a, 0x09, 0x44, 0x42, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x12, 0x6a,
+ 0x0a, 0x14, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74,
+ 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x12, 0x29, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e,
+ 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, 0x62, 0x73, 0x65,
+ 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
+ 0x74, 0x1a, 0x27, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31,
+ 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69,
+ 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x61, 0x0a, 0x11, 0x47, 0x65,
+ 0x74, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x12,
+ 0x26, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47,
+ 0x65, 0x74, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67,
+ 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31,
+ 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76,
+ 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x6a, 0x0a,
+ 0x14, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69,
+ 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x12, 0x29, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62,
+ 0x65, 0x74, 0x61, 0x31, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x4f, 0x62, 0x73, 0x65, 0x72,
+ 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
+ 0x1a, 0x27, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e,
+ 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x4f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f,
+ 0x6e, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x32, 0xe1, 0x01, 0x0a, 0x0a, 0x53, 0x75,
+ 0x67, 0x67, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x58, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x53,
+ 0x75, 0x67, 0x67, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x23, 0x2e, 0x61, 0x70, 0x69,
+ 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x53, 0x75, 0x67,
+ 0x67, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
+ 0x21, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47,
+ 0x65, 0x74, 0x53, 0x75, 0x67, 0x67, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x65, 0x70,
+ 0x6c, 0x79, 0x12, 0x79, 0x0a, 0x19, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x6c,
+ 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12,
+ 0x2e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x56,
+ 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d,
+ 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
+ 0x2c, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x56,
+ 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d,
+ 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x32, 0xe0, 0x02,
+ 0x0a, 0x0d, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x53, 0x74, 0x6f, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12,
+ 0x6d, 0x0a, 0x15, 0x47, 0x65, 0x74, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x53, 0x74, 0x6f, 0x70, 0x70,
+ 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x2a, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76,
+ 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x61, 0x72, 0x6c, 0x79,
+ 0x53, 0x74, 0x6f, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x71,
+ 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65,
+ 0x74, 0x61, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x53, 0x74, 0x6f, 0x70,
+ 0x70, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x58,
+ 0x0a, 0x0e, 0x53, 0x65, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73,
0x12, 0x23, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e,
- 0x47, 0x65, 0x74, 0x53, 0x75, 0x67, 0x67, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x65,
+ 0x53, 0x65, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62,
- 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x53, 0x75, 0x67, 0x67, 0x65, 0x73, 0x74, 0x69,
- 0x6f, 0x6e, 0x73, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x79, 0x0a, 0x19, 0x56, 0x61, 0x6c, 0x69,
- 0x64, 0x61, 0x74, 0x65, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x53, 0x65, 0x74,
- 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x2e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62,
- 0x65, 0x74, 0x61, 0x31, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x6c, 0x67,
- 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65,
- 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2c, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62,
- 0x65, 0x74, 0x61, 0x31, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x41, 0x6c, 0x67,
- 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65,
- 0x70, 0x6c, 0x79, 0x32, 0xe0, 0x02, 0x0a, 0x0d, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x53, 0x74, 0x6f,
- 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x6d, 0x0a, 0x15, 0x47, 0x65, 0x74, 0x45, 0x61, 0x72, 0x6c,
- 0x79, 0x53, 0x74, 0x6f, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x2a,
- 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47, 0x65,
- 0x74, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x53, 0x74, 0x6f, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x75,
- 0x6c, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x61, 0x70, 0x69,
- 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x61, 0x72,
- 0x6c, 0x79, 0x53, 0x74, 0x6f, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x52,
- 0x65, 0x70, 0x6c, 0x79, 0x12, 0x58, 0x0a, 0x0e, 0x53, 0x65, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c,
- 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x23, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e,
- 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x74,
- 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x61, 0x70,
- 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x72,
- 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x85,
- 0x01, 0x0a, 0x1d, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x61, 0x72, 0x6c, 0x79,
- 0x53, 0x74, 0x6f, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73,
- 0x12, 0x32, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e,
- 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x53, 0x74, 0x6f,
- 0x70, 0x70, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71,
- 0x75, 0x65, 0x73, 0x74, 0x1a, 0x30, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65,
- 0x74, 0x61, 0x31, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x61, 0x72, 0x6c,
- 0x79, 0x53, 0x74, 0x6f, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67,
- 0x73, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x42, 0x41, 0x5a, 0x3f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
- 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6b, 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x6b, 0x61,
- 0x74, 0x69, 0x62, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x61, 0x70, 0x69, 0x73, 0x2f, 0x6d, 0x61, 0x6e,
- 0x61, 0x67, 0x65, 0x72, 0x2f, 0x76, 0x31, 0x62, 0x65, 0x74, 0x61, 0x31, 0x3b, 0x61, 0x70, 0x69,
- 0x5f, 0x76, 0x31, 0x5f, 0x62, 0x65, 0x74, 0x61, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
- 0x33,
+ 0x65, 0x74, 0x61, 0x31, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x72, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61,
+ 0x74, 0x75, 0x73, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x85, 0x01, 0x0a, 0x1d, 0x56, 0x61, 0x6c,
+ 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x53, 0x74, 0x6f, 0x70, 0x70, 0x69,
+ 0x6e, 0x67, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x32, 0x2e, 0x61, 0x70, 0x69,
+ 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61,
+ 0x74, 0x65, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x53, 0x74, 0x6f, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x53,
+ 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x30,
+ 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x76, 0x31, 0x2e, 0x62, 0x65, 0x74, 0x61, 0x31, 0x2e, 0x56, 0x61,
+ 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x61, 0x72, 0x6c, 0x79, 0x53, 0x74, 0x6f, 0x70, 0x70,
+ 0x69, 0x6e, 0x67, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x70, 0x6c, 0x79,
+ 0x42, 0x41, 0x5a, 0x3f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6b,
+ 0x75, 0x62, 0x65, 0x66, 0x6c, 0x6f, 0x77, 0x2f, 0x6b, 0x61, 0x74, 0x69, 0x62, 0x2f, 0x70, 0x6b,
+ 0x67, 0x2f, 0x61, 0x70, 0x69, 0x73, 0x2f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2f, 0x76,
+ 0x31, 0x62, 0x65, 0x74, 0x61, 0x31, 0x3b, 0x61, 0x70, 0x69, 0x5f, 0x76, 0x31, 0x5f, 0x62, 0x65,
+ 0x74, 0x61, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
diff --git a/pkg/apis/manager/v1beta1/api.proto b/pkg/apis/manager/v1beta1/api.proto
index 553531918cc..f3fabf977dc 100644
--- a/pkg/apis/manager/v1beta1/api.proto
+++ b/pkg/apis/manager/v1beta1/api.proto
@@ -105,6 +105,7 @@ enum Distribution {
LOG_UNIFORM = 1;
NORMAL = 2;
LOG_NORMAL = 3;
+ DISTRIBUTION_UNKNOWN = 4;
}
/**
diff --git a/pkg/apis/manager/v1beta1/python/api_pb2.py b/pkg/apis/manager/v1beta1/python/api_pb2.py
index 5ad1e167b6c..3a95a64b82b 100644
--- a/pkg/apis/manager/v1beta1/python/api_pb2.py
+++ b/pkg/apis/manager/v1beta1/python/api_pb2.py
@@ -14,7 +14,7 @@
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tapi.proto\x12\x0c\x61pi.v1.beta1\"R\n\nExperiment\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x30\n\x04spec\x18\x02 \x01(\x0b\x32\x1c.api.v1.beta1.ExperimentSpecR\x04spec\"\x85\x04\n\x0e\x45xperimentSpec\x12T\n\x0fparameter_specs\x18\x01 \x01(\x0b\x32+.api.v1.beta1.ExperimentSpec.ParameterSpecsR\x0eparameterSpecs\x12\x39\n\tobjective\x18\x02 \x01(\x0b\x32\x1b.api.v1.beta1.ObjectiveSpecR\tobjective\x12\x39\n\talgorithm\x18\x03 \x01(\x0b\x32\x1b.api.v1.beta1.AlgorithmSpecR\talgorithm\x12\x46\n\x0e\x65\x61rly_stopping\x18\x04 \x01(\x0b\x32\x1f.api.v1.beta1.EarlyStoppingSpecR\rearlyStopping\x12\x30\n\x14parallel_trial_count\x18\x05 \x01(\x05R\x12parallelTrialCount\x12&\n\x0fmax_trial_count\x18\x06 \x01(\x05R\rmaxTrialCount\x12\x36\n\nnas_config\x18\x07 \x01(\x0b\x32\x17.api.v1.beta1.NasConfigR\tnasConfig\x1aM\n\x0eParameterSpecs\x12;\n\nparameters\x18\x01 \x03(\x0b\x32\x1b.api.v1.beta1.ParameterSpecR\nparameters\"\xab\x01\n\rParameterSpec\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x42\n\x0eparameter_type\x18\x02 \x01(\x0e\x32\x1b.api.v1.beta1.ParameterTypeR\rparameterType\x12\x42\n\x0e\x66\x65\x61sible_space\x18\x03 \x01(\x0b\x32\x1b.api.v1.beta1.FeasibleSpaceR\rfeasibleSpace\"\x9b\x01\n\rFeasibleSpace\x12\x10\n\x03max\x18\x01 \x01(\tR\x03max\x12\x10\n\x03min\x18\x02 \x01(\tR\x03min\x12\x12\n\x04list\x18\x03 \x03(\tR\x04list\x12\x12\n\x04step\x18\x04 \x01(\tR\x04step\x12>\n\x0c\x64istribution\x18\x05 \x01(\x0e\x32\x1a.api.v1.beta1.DistributionR\x0c\x64istribution\"\xc0\x01\n\rObjectiveSpec\x12/\n\x04type\x18\x01 \x01(\x0e\x32\x1b.api.v1.beta1.ObjectiveTypeR\x04type\x12\x12\n\x04goal\x18\x02 \x01(\x01R\x04goal\x12\x32\n\x15objective_metric_name\x18\x03 \x01(\tR\x13objectiveMetricName\x12\x36\n\x17\x61\x64\x64itional_metric_names\x18\x04 \x03(\tR\x15\x61\x64\x64itionalMetricNames\"\x85\x01\n\rAlgorithmSpec\x12%\n\x0e\x61lgorithm_name\x18\x01 
\x01(\tR\ralgorithmName\x12M\n\x12\x61lgorithm_settings\x18\x02 \x03(\x0b\x32\x1e.api.v1.beta1.AlgorithmSettingR\x11\x61lgorithmSettings\"<\n\x10\x41lgorithmSetting\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\"\x8d\x01\n\x11\x45\x61rlyStoppingSpec\x12%\n\x0e\x61lgorithm_name\x18\x01 \x01(\tR\ralgorithmName\x12Q\n\x12\x61lgorithm_settings\x18\x02 \x03(\x0b\x32\".api.v1.beta1.EarlyStoppingSettingR\x11\x61lgorithmSettings\"@\n\x14\x45\x61rlyStoppingSetting\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\"\xd2\x01\n\tNasConfig\x12<\n\x0cgraph_config\x18\x01 \x01(\x0b\x32\x19.api.v1.beta1.GraphConfigR\x0bgraphConfig\x12\x42\n\noperations\x18\x02 \x01(\x0b\x32\".api.v1.beta1.NasConfig.OperationsR\noperations\x1a\x43\n\nOperations\x12\x35\n\toperation\x18\x01 \x03(\x0b\x32\x17.api.v1.beta1.OperationR\toperation\"p\n\x0bGraphConfig\x12\x1d\n\nnum_layers\x18\x01 \x01(\x05R\tnumLayers\x12\x1f\n\x0binput_sizes\x18\x02 \x03(\x05R\ninputSizes\x12!\n\x0coutput_sizes\x18\x03 \x03(\x05R\x0boutputSizes\"\xd2\x01\n\tOperation\x12%\n\x0eoperation_type\x18\x01 \x01(\tR\roperationType\x12O\n\x0fparameter_specs\x18\x02 \x01(\x0b\x32&.api.v1.beta1.Operation.ParameterSpecsR\x0eparameterSpecs\x1aM\n\x0eParameterSpecs\x12;\n\nparameters\x18\x01 \x03(\x0b\x32\x1b.api.v1.beta1.ParameterSpecR\nparameters\"{\n\x05Trial\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12+\n\x04spec\x18\x02 \x01(\x0b\x32\x17.api.v1.beta1.TrialSpecR\x04spec\x12\x31\n\x06status\x18\x03 \x01(\x0b\x32\x19.api.v1.beta1.TrialStatusR\x06status\"\xfe\x02\n\tTrialSpec\x12\x39\n\tobjective\x18\x02 \x01(\x0b\x32\x1b.api.v1.beta1.ObjectiveSpecR\tobjective\x12\x61\n\x15parameter_assignments\x18\x03 \x01(\x0b\x32,.api.v1.beta1.TrialSpec.ParameterAssignmentsR\x14parameterAssignments\x12;\n\x06labels\x18\x04 \x03(\x0b\x32#.api.v1.beta1.TrialSpec.LabelsEntryR\x06labels\x1a[\n\x14ParameterAssignments\x12\x43\n\x0b\x61ssignments\x18\x01 
\x03(\x0b\x32!.api.v1.beta1.ParameterAssignmentR\x0b\x61ssignments\x1a\x39\n\x0bLabelsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"?\n\x13ParameterAssignment\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\"\xed\x02\n\x0bTrialStatus\x12\x1d\n\nstart_time\x18\x01 \x01(\tR\tstartTime\x12\'\n\x0f\x63ompletion_time\x18\x02 \x01(\tR\x0e\x63ompletionTime\x12J\n\tcondition\x18\x03 \x01(\x0e\x32,.api.v1.beta1.TrialStatus.TrialConditionTypeR\tcondition\x12;\n\x0bobservation\x18\x04 \x01(\x0b\x32\x19.api.v1.beta1.ObservationR\x0bobservation\"\x8c\x01\n\x12TrialConditionType\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSUCCEEDED\x10\x02\x12\n\n\x06KILLED\x10\x03\x12\n\n\x06\x46\x41ILED\x10\x04\x12\x16\n\x12METRICSUNAVAILABLE\x10\x05\x12\x10\n\x0c\x45\x41RLYSTOPPED\x10\x06\x12\x0b\n\x07UNKNOWN\x10\x07\"=\n\x0bObservation\x12.\n\x07metrics\x18\x01 \x03(\x0b\x32\x14.api.v1.beta1.MetricR\x07metrics\"2\n\x06Metric\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\"\x83\x01\n\x1bReportObservationLogRequest\x12\x1d\n\ntrial_name\x18\x01 \x01(\tR\ttrialName\x12\x45\n\x0fobservation_log\x18\x02 \x01(\x0b\x32\x1c.api.v1.beta1.ObservationLogR\x0eobservationLog\"\x1b\n\x19ReportObservationLogReply\"J\n\x0eObservationLog\x12\x38\n\x0bmetric_logs\x18\x01 \x03(\x0b\x32\x17.api.v1.beta1.MetricLogR\nmetricLogs\"X\n\tMetricLog\x12\x1d\n\ntime_stamp\x18\x01 \x01(\tR\ttimeStamp\x12,\n\x06metric\x18\x02 \x01(\x0b\x32\x14.api.v1.beta1.MetricR\x06metric\"\x94\x01\n\x18GetObservationLogRequest\x12\x1d\n\ntrial_name\x18\x01 \x01(\tR\ttrialName\x12\x1f\n\x0bmetric_name\x18\x02 \x01(\tR\nmetricName\x12\x1d\n\nstart_time\x18\x03 \x01(\tR\tstartTime\x12\x19\n\x08\x65nd_time\x18\x04 \x01(\tR\x07\x65ndTime\"_\n\x16GetObservationLogReply\x12\x45\n\x0fobservation_log\x18\x01 
\x01(\x0b\x32\x1c.api.v1.beta1.ObservationLogR\x0eobservationLog\"<\n\x1b\x44\x65leteObservationLogRequest\x12\x1d\n\ntrial_name\x18\x01 \x01(\tR\ttrialName\"\x1b\n\x19\x44\x65leteObservationLogReply\"\xe6\x01\n\x15GetSuggestionsRequest\x12\x38\n\nexperiment\x18\x01 \x01(\x0b\x32\x18.api.v1.beta1.ExperimentR\nexperiment\x12+\n\x06trials\x18\x02 \x03(\x0b\x32\x13.api.v1.beta1.TrialR\x06trials\x12\x34\n\x16\x63urrent_request_number\x18\x04 \x01(\x05R\x14\x63urrentRequestNumber\x12\x30\n\x14total_request_number\x18\x05 \x01(\x05R\x12totalRequestNumber\"\xa4\x04\n\x13GetSuggestionsReply\x12k\n\x15parameter_assignments\x18\x01 \x03(\x0b\x32\x36.api.v1.beta1.GetSuggestionsReply.ParameterAssignmentsR\x14parameterAssignments\x12\x39\n\talgorithm\x18\x02 \x01(\x0b\x32\x1b.api.v1.beta1.AlgorithmSpecR\talgorithm\x12Q\n\x14\x65\x61rly_stopping_rules\x18\x03 \x03(\x0b\x32\x1f.api.v1.beta1.EarlyStoppingRuleR\x12\x65\x61rlyStoppingRules\x1a\x91\x02\n\x14ParameterAssignments\x12\x43\n\x0b\x61ssignments\x18\x01 \x03(\x0b\x32!.api.v1.beta1.ParameterAssignmentR\x0b\x61ssignments\x12\x1d\n\ntrial_name\x18\x02 \x01(\tR\ttrialName\x12Z\n\x06labels\x18\x03 \x03(\x0b\x32\x42.api.v1.beta1.GetSuggestionsReply.ParameterAssignments.LabelsEntryR\x06labels\x1a\x39\n\x0bLabelsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\\\n ValidateAlgorithmSettingsRequest\x12\x38\n\nexperiment\x18\x01 \x01(\x0b\x32\x18.api.v1.beta1.ExperimentR\nexperiment\" \n\x1eValidateAlgorithmSettingsReply\"\xb3\x01\n\x1cGetEarlyStoppingRulesRequest\x12\x38\n\nexperiment\x18\x01 \x01(\x0b\x32\x18.api.v1.beta1.ExperimentR\nexperiment\x12+\n\x06trials\x18\x02 \x03(\x0b\x32\x13.api.v1.beta1.TrialR\x06trials\x12,\n\x12\x64\x62_manager_address\x18\x03 \x01(\tR\x10\x64\x62ManagerAddress\"o\n\x1aGetEarlyStoppingRulesReply\x12Q\n\x14\x65\x61rly_stopping_rules\x18\x01 
\x03(\x0b\x32\x1f.api.v1.beta1.EarlyStoppingRuleR\x12\x65\x61rlyStoppingRules\"\x9a\x01\n\x11\x45\x61rlyStoppingRule\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\x12<\n\ncomparison\x18\x03 \x01(\x0e\x32\x1c.api.v1.beta1.ComparisonTypeR\ncomparison\x12\x1d\n\nstart_step\x18\x04 \x01(\x05R\tstartStep\"n\n$ValidateEarlyStoppingSettingsRequest\x12\x46\n\x0e\x65\x61rly_stopping\x18\x01 \x01(\x0b\x32\x1f.api.v1.beta1.EarlyStoppingSpecR\rearlyStopping\"$\n\"ValidateEarlyStoppingSettingsReply\"6\n\x15SetTrialStatusRequest\x12\x1d\n\ntrial_name\x18\x01 \x01(\tR\ttrialName\"\x15\n\x13SetTrialStatusReply*U\n\rParameterType\x12\x10\n\x0cUNKNOWN_TYPE\x10\x00\x12\n\n\x06\x44OUBLE\x10\x01\x12\x07\n\x03INT\x10\x02\x12\x0c\n\x08\x44ISCRETE\x10\x03\x12\x0f\n\x0b\x43\x41TEGORICAL\x10\x04*H\n\x0c\x44istribution\x12\x0b\n\x07UNIFORM\x10\x00\x12\x0f\n\x0bLOG_UNIFORM\x10\x01\x12\n\n\x06NORMAL\x10\x02\x12\x0e\n\nLOG_NORMAL\x10\x03*8\n\rObjectiveType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0c\n\x08MINIMIZE\x10\x01\x12\x0c\n\x08MAXIMIZE\x10\x02*J\n\x0e\x43omparisonType\x12\x16\n\x12UNKNOWN_COMPARISON\x10\x00\x12\t\n\x05\x45QUAL\x10\x01\x12\x08\n\x04LESS\x10\x02\x12\x0b\n\x07GREATER\x10\x03\x32\xc6\x02\n\tDBManager\x12j\n\x14ReportObservationLog\x12).api.v1.beta1.ReportObservationLogRequest\x1a\'.api.v1.beta1.ReportObservationLogReply\x12\x61\n\x11GetObservationLog\x12&.api.v1.beta1.GetObservationLogRequest\x1a$.api.v1.beta1.GetObservationLogReply\x12j\n\x14\x44\x65leteObservationLog\x12).api.v1.beta1.DeleteObservationLogRequest\x1a\'.api.v1.beta1.DeleteObservationLogReply2\xe1\x01\n\nSuggestion\x12X\n\x0eGetSuggestions\x12#.api.v1.beta1.GetSuggestionsRequest\x1a!.api.v1.beta1.GetSuggestionsReply\x12y\n\x19ValidateAlgorithmSettings\x12..api.v1.beta1.ValidateAlgorithmSettingsRequest\x1a,.api.v1.beta1.ValidateAlgorithmSettingsReply2\xe0\x02\n\rEarlyStopping\x12m\n\x15GetEarlyStoppingRules\x12*.api.v1.beta1.GetEarlyStoppingRulesRequest\x1a(.api.v1.beta1.GetEa
rlyStoppingRulesReply\x12X\n\x0eSetTrialStatus\x12#.api.v1.beta1.SetTrialStatusRequest\x1a!.api.v1.beta1.SetTrialStatusReply\x12\x85\x01\n\x1dValidateEarlyStoppingSettings\x12\x32.api.v1.beta1.ValidateEarlyStoppingSettingsRequest\x1a\x30.api.v1.beta1.ValidateEarlyStoppingSettingsReplyBAZ?github.com/kubeflow/katib/pkg/apis/manager/v1beta1;api_v1_beta1b\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tapi.proto\x12\x0c\x61pi.v1.beta1\"R\n\nExperiment\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x30\n\x04spec\x18\x02 \x01(\x0b\x32\x1c.api.v1.beta1.ExperimentSpecR\x04spec\"\x85\x04\n\x0e\x45xperimentSpec\x12T\n\x0fparameter_specs\x18\x01 \x01(\x0b\x32+.api.v1.beta1.ExperimentSpec.ParameterSpecsR\x0eparameterSpecs\x12\x39\n\tobjective\x18\x02 \x01(\x0b\x32\x1b.api.v1.beta1.ObjectiveSpecR\tobjective\x12\x39\n\talgorithm\x18\x03 \x01(\x0b\x32\x1b.api.v1.beta1.AlgorithmSpecR\talgorithm\x12\x46\n\x0e\x65\x61rly_stopping\x18\x04 \x01(\x0b\x32\x1f.api.v1.beta1.EarlyStoppingSpecR\rearlyStopping\x12\x30\n\x14parallel_trial_count\x18\x05 \x01(\x05R\x12parallelTrialCount\x12&\n\x0fmax_trial_count\x18\x06 \x01(\x05R\rmaxTrialCount\x12\x36\n\nnas_config\x18\x07 \x01(\x0b\x32\x17.api.v1.beta1.NasConfigR\tnasConfig\x1aM\n\x0eParameterSpecs\x12;\n\nparameters\x18\x01 \x03(\x0b\x32\x1b.api.v1.beta1.ParameterSpecR\nparameters\"\xab\x01\n\rParameterSpec\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x42\n\x0eparameter_type\x18\x02 \x01(\x0e\x32\x1b.api.v1.beta1.ParameterTypeR\rparameterType\x12\x42\n\x0e\x66\x65\x61sible_space\x18\x03 \x01(\x0b\x32\x1b.api.v1.beta1.FeasibleSpaceR\rfeasibleSpace\"\x9b\x01\n\rFeasibleSpace\x12\x10\n\x03max\x18\x01 \x01(\tR\x03max\x12\x10\n\x03min\x18\x02 \x01(\tR\x03min\x12\x12\n\x04list\x18\x03 \x03(\tR\x04list\x12\x12\n\x04step\x18\x04 \x01(\tR\x04step\x12>\n\x0c\x64istribution\x18\x05 \x01(\x0e\x32\x1a.api.v1.beta1.DistributionR\x0c\x64istribution\"\xc0\x01\n\rObjectiveSpec\x12/\n\x04type\x18\x01 \x01(\x0e\x32\x1b.api.v1.beta1.ObjectiveTypeR\x04type\x12\x12\n\x04goal\x18\x02 \x01(\x01R\x04goal\x12\x32\n\x15objective_metric_name\x18\x03 \x01(\tR\x13objectiveMetricName\x12\x36\n\x17\x61\x64\x64itional_metric_names\x18\x04 \x03(\tR\x15\x61\x64\x64itionalMetricNames\"\x85\x01\n\rAlgorithmSpec\x12%\n\x0e\x61lgorithm_name\x18\x01 
\x01(\tR\ralgorithmName\x12M\n\x12\x61lgorithm_settings\x18\x02 \x03(\x0b\x32\x1e.api.v1.beta1.AlgorithmSettingR\x11\x61lgorithmSettings\"<\n\x10\x41lgorithmSetting\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\"\x8d\x01\n\x11\x45\x61rlyStoppingSpec\x12%\n\x0e\x61lgorithm_name\x18\x01 \x01(\tR\ralgorithmName\x12Q\n\x12\x61lgorithm_settings\x18\x02 \x03(\x0b\x32\".api.v1.beta1.EarlyStoppingSettingR\x11\x61lgorithmSettings\"@\n\x14\x45\x61rlyStoppingSetting\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\"\xd2\x01\n\tNasConfig\x12<\n\x0cgraph_config\x18\x01 \x01(\x0b\x32\x19.api.v1.beta1.GraphConfigR\x0bgraphConfig\x12\x42\n\noperations\x18\x02 \x01(\x0b\x32\".api.v1.beta1.NasConfig.OperationsR\noperations\x1a\x43\n\nOperations\x12\x35\n\toperation\x18\x01 \x03(\x0b\x32\x17.api.v1.beta1.OperationR\toperation\"p\n\x0bGraphConfig\x12\x1d\n\nnum_layers\x18\x01 \x01(\x05R\tnumLayers\x12\x1f\n\x0binput_sizes\x18\x02 \x03(\x05R\ninputSizes\x12!\n\x0coutput_sizes\x18\x03 \x03(\x05R\x0boutputSizes\"\xd2\x01\n\tOperation\x12%\n\x0eoperation_type\x18\x01 \x01(\tR\roperationType\x12O\n\x0fparameter_specs\x18\x02 \x01(\x0b\x32&.api.v1.beta1.Operation.ParameterSpecsR\x0eparameterSpecs\x1aM\n\x0eParameterSpecs\x12;\n\nparameters\x18\x01 \x03(\x0b\x32\x1b.api.v1.beta1.ParameterSpecR\nparameters\"{\n\x05Trial\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12+\n\x04spec\x18\x02 \x01(\x0b\x32\x17.api.v1.beta1.TrialSpecR\x04spec\x12\x31\n\x06status\x18\x03 \x01(\x0b\x32\x19.api.v1.beta1.TrialStatusR\x06status\"\xfe\x02\n\tTrialSpec\x12\x39\n\tobjective\x18\x02 \x01(\x0b\x32\x1b.api.v1.beta1.ObjectiveSpecR\tobjective\x12\x61\n\x15parameter_assignments\x18\x03 \x01(\x0b\x32,.api.v1.beta1.TrialSpec.ParameterAssignmentsR\x14parameterAssignments\x12;\n\x06labels\x18\x04 \x03(\x0b\x32#.api.v1.beta1.TrialSpec.LabelsEntryR\x06labels\x1a[\n\x14ParameterAssignments\x12\x43\n\x0b\x61ssignments\x18\x01 
\x03(\x0b\x32!.api.v1.beta1.ParameterAssignmentR\x0b\x61ssignments\x1a\x39\n\x0bLabelsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"?\n\x13ParameterAssignment\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\"\xed\x02\n\x0bTrialStatus\x12\x1d\n\nstart_time\x18\x01 \x01(\tR\tstartTime\x12\'\n\x0f\x63ompletion_time\x18\x02 \x01(\tR\x0e\x63ompletionTime\x12J\n\tcondition\x18\x03 \x01(\x0e\x32,.api.v1.beta1.TrialStatus.TrialConditionTypeR\tcondition\x12;\n\x0bobservation\x18\x04 \x01(\x0b\x32\x19.api.v1.beta1.ObservationR\x0bobservation\"\x8c\x01\n\x12TrialConditionType\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSUCCEEDED\x10\x02\x12\n\n\x06KILLED\x10\x03\x12\n\n\x06\x46\x41ILED\x10\x04\x12\x16\n\x12METRICSUNAVAILABLE\x10\x05\x12\x10\n\x0c\x45\x41RLYSTOPPED\x10\x06\x12\x0b\n\x07UNKNOWN\x10\x07\"=\n\x0bObservation\x12.\n\x07metrics\x18\x01 \x03(\x0b\x32\x14.api.v1.beta1.MetricR\x07metrics\"2\n\x06Metric\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\"\x83\x01\n\x1bReportObservationLogRequest\x12\x1d\n\ntrial_name\x18\x01 \x01(\tR\ttrialName\x12\x45\n\x0fobservation_log\x18\x02 \x01(\x0b\x32\x1c.api.v1.beta1.ObservationLogR\x0eobservationLog\"\x1b\n\x19ReportObservationLogReply\"J\n\x0eObservationLog\x12\x38\n\x0bmetric_logs\x18\x01 \x03(\x0b\x32\x17.api.v1.beta1.MetricLogR\nmetricLogs\"X\n\tMetricLog\x12\x1d\n\ntime_stamp\x18\x01 \x01(\tR\ttimeStamp\x12,\n\x06metric\x18\x02 \x01(\x0b\x32\x14.api.v1.beta1.MetricR\x06metric\"\x94\x01\n\x18GetObservationLogRequest\x12\x1d\n\ntrial_name\x18\x01 \x01(\tR\ttrialName\x12\x1f\n\x0bmetric_name\x18\x02 \x01(\tR\nmetricName\x12\x1d\n\nstart_time\x18\x03 \x01(\tR\tstartTime\x12\x19\n\x08\x65nd_time\x18\x04 \x01(\tR\x07\x65ndTime\"_\n\x16GetObservationLogReply\x12\x45\n\x0fobservation_log\x18\x01 
\x01(\x0b\x32\x1c.api.v1.beta1.ObservationLogR\x0eobservationLog\"<\n\x1b\x44\x65leteObservationLogRequest\x12\x1d\n\ntrial_name\x18\x01 \x01(\tR\ttrialName\"\x1b\n\x19\x44\x65leteObservationLogReply\"\xe6\x01\n\x15GetSuggestionsRequest\x12\x38\n\nexperiment\x18\x01 \x01(\x0b\x32\x18.api.v1.beta1.ExperimentR\nexperiment\x12+\n\x06trials\x18\x02 \x03(\x0b\x32\x13.api.v1.beta1.TrialR\x06trials\x12\x34\n\x16\x63urrent_request_number\x18\x04 \x01(\x05R\x14\x63urrentRequestNumber\x12\x30\n\x14total_request_number\x18\x05 \x01(\x05R\x12totalRequestNumber\"\xa4\x04\n\x13GetSuggestionsReply\x12k\n\x15parameter_assignments\x18\x01 \x03(\x0b\x32\x36.api.v1.beta1.GetSuggestionsReply.ParameterAssignmentsR\x14parameterAssignments\x12\x39\n\talgorithm\x18\x02 \x01(\x0b\x32\x1b.api.v1.beta1.AlgorithmSpecR\talgorithm\x12Q\n\x14\x65\x61rly_stopping_rules\x18\x03 \x03(\x0b\x32\x1f.api.v1.beta1.EarlyStoppingRuleR\x12\x65\x61rlyStoppingRules\x1a\x91\x02\n\x14ParameterAssignments\x12\x43\n\x0b\x61ssignments\x18\x01 \x03(\x0b\x32!.api.v1.beta1.ParameterAssignmentR\x0b\x61ssignments\x12\x1d\n\ntrial_name\x18\x02 \x01(\tR\ttrialName\x12Z\n\x06labels\x18\x03 \x03(\x0b\x32\x42.api.v1.beta1.GetSuggestionsReply.ParameterAssignments.LabelsEntryR\x06labels\x1a\x39\n\x0bLabelsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\\\n ValidateAlgorithmSettingsRequest\x12\x38\n\nexperiment\x18\x01 \x01(\x0b\x32\x18.api.v1.beta1.ExperimentR\nexperiment\" \n\x1eValidateAlgorithmSettingsReply\"\xb3\x01\n\x1cGetEarlyStoppingRulesRequest\x12\x38\n\nexperiment\x18\x01 \x01(\x0b\x32\x18.api.v1.beta1.ExperimentR\nexperiment\x12+\n\x06trials\x18\x02 \x03(\x0b\x32\x13.api.v1.beta1.TrialR\x06trials\x12,\n\x12\x64\x62_manager_address\x18\x03 \x01(\tR\x10\x64\x62ManagerAddress\"o\n\x1aGetEarlyStoppingRulesReply\x12Q\n\x14\x65\x61rly_stopping_rules\x18\x01 
\x03(\x0b\x32\x1f.api.v1.beta1.EarlyStoppingRuleR\x12\x65\x61rlyStoppingRules\"\x9a\x01\n\x11\x45\x61rlyStoppingRule\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\x12<\n\ncomparison\x18\x03 \x01(\x0e\x32\x1c.api.v1.beta1.ComparisonTypeR\ncomparison\x12\x1d\n\nstart_step\x18\x04 \x01(\x05R\tstartStep\"n\n$ValidateEarlyStoppingSettingsRequest\x12\x46\n\x0e\x65\x61rly_stopping\x18\x01 \x01(\x0b\x32\x1f.api.v1.beta1.EarlyStoppingSpecR\rearlyStopping\"$\n\"ValidateEarlyStoppingSettingsReply\"6\n\x15SetTrialStatusRequest\x12\x1d\n\ntrial_name\x18\x01 \x01(\tR\ttrialName\"\x15\n\x13SetTrialStatusReply*U\n\rParameterType\x12\x10\n\x0cUNKNOWN_TYPE\x10\x00\x12\n\n\x06\x44OUBLE\x10\x01\x12\x07\n\x03INT\x10\x02\x12\x0c\n\x08\x44ISCRETE\x10\x03\x12\x0f\n\x0b\x43\x41TEGORICAL\x10\x04*b\n\x0c\x44istribution\x12\x0b\n\x07UNIFORM\x10\x00\x12\x0f\n\x0bLOG_UNIFORM\x10\x01\x12\n\n\x06NORMAL\x10\x02\x12\x0e\n\nLOG_NORMAL\x10\x03\x12\x18\n\x14\x44ISTRIBUTION_UNKNOWN\x10\x04*8\n\rObjectiveType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0c\n\x08MINIMIZE\x10\x01\x12\x0c\n\x08MAXIMIZE\x10\x02*J\n\x0e\x43omparisonType\x12\x16\n\x12UNKNOWN_COMPARISON\x10\x00\x12\t\n\x05\x45QUAL\x10\x01\x12\x08\n\x04LESS\x10\x02\x12\x0b\n\x07GREATER\x10\x03\x32\xc6\x02\n\tDBManager\x12j\n\x14ReportObservationLog\x12).api.v1.beta1.ReportObservationLogRequest\x1a\'.api.v1.beta1.ReportObservationLogReply\x12\x61\n\x11GetObservationLog\x12&.api.v1.beta1.GetObservationLogRequest\x1a$.api.v1.beta1.GetObservationLogReply\x12j\n\x14\x44\x65leteObservationLog\x12).api.v1.beta1.DeleteObservationLogRequest\x1a\'.api.v1.beta1.DeleteObservationLogReply2\xe1\x01\n\nSuggestion\x12X\n\x0eGetSuggestions\x12#.api.v1.beta1.GetSuggestionsRequest\x1a!.api.v1.beta1.GetSuggestionsReply\x12y\n\x19ValidateAlgorithmSettings\x12..api.v1.beta1.ValidateAlgorithmSettingsRequest\x1a,.api.v1.beta1.ValidateAlgorithmSettingsReply2\xe0\x02\n\rEarlyStopping\x12m\n\x15GetEarlyStoppingRules\x12*.api.v1.beta1.GetEarl
yStoppingRulesRequest\x1a(.api.v1.beta1.GetEarlyStoppingRulesReply\x12X\n\x0eSetTrialStatus\x12#.api.v1.beta1.SetTrialStatusRequest\x1a!.api.v1.beta1.SetTrialStatusReply\x12\x85\x01\n\x1dValidateEarlyStoppingSettings\x12\x32.api.v1.beta1.ValidateEarlyStoppingSettingsRequest\x1a\x30.api.v1.beta1.ValidateEarlyStoppingSettingsReplyBAZ?github.com/kubeflow/katib/pkg/apis/manager/v1beta1;api_v1_beta1b\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -29,11 +29,11 @@
_globals['_PARAMETERTYPE']._serialized_start=5425
_globals['_PARAMETERTYPE']._serialized_end=5510
_globals['_DISTRIBUTION']._serialized_start=5512
- _globals['_DISTRIBUTION']._serialized_end=5584
- _globals['_OBJECTIVETYPE']._serialized_start=5586
- _globals['_OBJECTIVETYPE']._serialized_end=5642
- _globals['_COMPARISONTYPE']._serialized_start=5644
- _globals['_COMPARISONTYPE']._serialized_end=5718
+ _globals['_DISTRIBUTION']._serialized_end=5610
+ _globals['_OBJECTIVETYPE']._serialized_start=5612
+ _globals['_OBJECTIVETYPE']._serialized_end=5668
+ _globals['_COMPARISONTYPE']._serialized_start=5670
+ _globals['_COMPARISONTYPE']._serialized_end=5744
_globals['_EXPERIMENT']._serialized_start=27
_globals['_EXPERIMENT']._serialized_end=109
_globals['_EXPERIMENTSPEC']._serialized_start=112
@@ -124,10 +124,10 @@
_globals['_SETTRIALSTATUSREQUEST']._serialized_end=5400
_globals['_SETTRIALSTATUSREPLY']._serialized_start=5402
_globals['_SETTRIALSTATUSREPLY']._serialized_end=5423
- _globals['_DBMANAGER']._serialized_start=5721
- _globals['_DBMANAGER']._serialized_end=6047
- _globals['_SUGGESTION']._serialized_start=6050
- _globals['_SUGGESTION']._serialized_end=6275
- _globals['_EARLYSTOPPING']._serialized_start=6278
- _globals['_EARLYSTOPPING']._serialized_end=6630
+ _globals['_DBMANAGER']._serialized_start=5747
+ _globals['_DBMANAGER']._serialized_end=6073
+ _globals['_SUGGESTION']._serialized_start=6076
+ _globals['_SUGGESTION']._serialized_end=6301
+ _globals['_EARLYSTOPPING']._serialized_start=6304
+ _globals['_EARLYSTOPPING']._serialized_end=6656
# @@protoc_insertion_point(module_scope)
diff --git a/pkg/apis/manager/v1beta1/python/api_pb2.pyi b/pkg/apis/manager/v1beta1/python/api_pb2.pyi
index 47976f7059a..21cbb161d77 100644
--- a/pkg/apis/manager/v1beta1/python/api_pb2.pyi
+++ b/pkg/apis/manager/v1beta1/python/api_pb2.pyi
@@ -20,6 +20,7 @@ class Distribution(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
LOG_UNIFORM: _ClassVar[Distribution]
NORMAL: _ClassVar[Distribution]
LOG_NORMAL: _ClassVar[Distribution]
+ DISTRIBUTION_UNKNOWN: _ClassVar[Distribution]
class ObjectiveType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
@@ -42,6 +43,7 @@ UNIFORM: Distribution
LOG_UNIFORM: Distribution
NORMAL: Distribution
LOG_NORMAL: Distribution
+DISTRIBUTION_UNKNOWN: Distribution
UNKNOWN: ObjectiveType
MINIMIZE: ObjectiveType
MAXIMIZE: ObjectiveType
diff --git a/pkg/apis/v1beta1/openapi_generated.go b/pkg/apis/v1beta1/openapi_generated.go
index de764c3d5ce..503dbc64b33 100644
--- a/pkg/apis/v1beta1/openapi_generated.go
+++ b/pkg/apis/v1beta1/openapi_generated.go
@@ -1090,6 +1090,12 @@ func schema_apis_controller_experiments_v1beta1_FeasibleSpace(ref common.Referen
Format: "",
},
},
+ "distribution": {
+ SchemaProps: spec.SchemaProps{
+ Type: []string{"string"},
+ Format: "",
+ },
+ },
},
},
},
diff --git a/pkg/apis/v1beta1/swagger.json b/pkg/apis/v1beta1/swagger.json
index 2c2ab54f76e..6bf22ab0433 100644
--- a/pkg/apis/v1beta1/swagger.json
+++ b/pkg/apis/v1beta1/swagger.json
@@ -781,6 +781,9 @@
"v1beta1.FeasibleSpace": {
"type": "object",
"properties": {
+ "distribution": {
+ "type": "string"
+ },
"list": {
"type": "array",
"items": {
diff --git a/pkg/client/controller/applyconfiguration/experiments/v1beta1/feasiblespace.go b/pkg/client/controller/applyconfiguration/experiments/v1beta1/feasiblespace.go
index 0026c91b45b..5c01e8aaf1e 100644
--- a/pkg/client/controller/applyconfiguration/experiments/v1beta1/feasiblespace.go
+++ b/pkg/client/controller/applyconfiguration/experiments/v1beta1/feasiblespace.go
@@ -18,13 +18,18 @@ limitations under the License.
package v1beta1
+import (
+ v1beta1 "github.com/kubeflow/katib/pkg/apis/controller/experiments/v1beta1"
+)
+
// FeasibleSpaceApplyConfiguration represents an declarative configuration of the FeasibleSpace type for use
// with apply.
type FeasibleSpaceApplyConfiguration struct {
- Max *string `json:"max,omitempty"`
- Min *string `json:"min,omitempty"`
- List []string `json:"list,omitempty"`
- Step *string `json:"step,omitempty"`
+ Max *string `json:"max,omitempty"`
+ Min *string `json:"min,omitempty"`
+ List []string `json:"list,omitempty"`
+ Step *string `json:"step,omitempty"`
+ Distribution *v1beta1.Distribution `json:"distribution,omitempty"`
}
// FeasibleSpaceApplyConfiguration constructs an declarative configuration of the FeasibleSpace type for use with
@@ -66,3 +71,11 @@ func (b *FeasibleSpaceApplyConfiguration) WithStep(value string) *FeasibleSpaceA
b.Step = &value
return b
}
+
+// WithDistribution sets the Distribution field in the declarative configuration to the given value
+// and returns the receiver, so that objects can be built by chaining "With" function invocations.
+// If called multiple times, the Distribution field is set to the value of the last call.
+func (b *FeasibleSpaceApplyConfiguration) WithDistribution(value v1beta1.Distribution) *FeasibleSpaceApplyConfiguration {
+ b.Distribution = &value
+ return b
+}
diff --git a/pkg/controller.v1beta1/experiment/manifest/generator.go b/pkg/controller.v1beta1/experiment/manifest/generator.go
index a41cc2f0c89..c3d0a2a14bc 100644
--- a/pkg/controller.v1beta1/experiment/manifest/generator.go
+++ b/pkg/controller.v1beta1/experiment/manifest/generator.go
@@ -17,6 +17,7 @@ limitations under the License.
package manifest
import (
+ "errors"
"fmt"
"regexp"
"strings"
@@ -33,6 +34,15 @@ import (
"github.com/kubeflow/katib/pkg/util/v1beta1/katibconfig"
)
+var (
+ errConfigMapNotFound = errors.New("configMap not found")
+ errConvertStringToUnstructuredFailed = errors.New("failed to convert string to unstructured")
+ errConvertUnstructuredToStringFailed = errors.New("failed to convert unstructured to string")
+ errParamNotFoundInParameterAssignment = errors.New("unable to find non-meta parameter from TrialParameters in ParameterAssignment")
+ errParamNotFoundInTrialParameters = errors.New("unable to find parameter from ParameterAssignment in TrialParameters")
+ errTrialTemplateNotFound = errors.New("unable to find trial template in ConfigMap")
+)
+
// Generator is the type for manifests Generator.
type Generator interface {
InjectClient(c client.Client)
@@ -86,7 +96,7 @@ func (g *DefaultGenerator) GetRunSpecWithHyperParameters(experiment *experiments
// Convert Trial template to unstructured
runSpec, err := util.ConvertStringToUnstructured(replacedTemplate)
if err != nil {
- return nil, fmt.Errorf("ConvertStringToUnstructured failed: %v", err)
+ return nil, fmt.Errorf("%w: %w", errConvertStringToUnstructuredFailed, err)
}
// Set name and namespace for Run Spec
@@ -108,7 +118,7 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi
if trialSpec == nil {
trialSpec, err = util.ConvertStringToUnstructured(trialTemplate)
if err != nil {
- return "", fmt.Errorf("ConvertStringToUnstructured failed: %v", err)
+ return "", fmt.Errorf("%w: %w", errConvertStringToUnstructuredFailed, err)
}
}
@@ -131,7 +141,7 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi
nonMetaParamCount += 1
continue
} else {
- return "", fmt.Errorf("Unable to find parameter: %v in parameter assignment %v", param.Reference, assignmentsMap)
+ return "", fmt.Errorf("%w: parameter: %v, parameter assignment: %v", errParamNotFoundInParameterAssignment, param.Reference, assignmentsMap)
}
}
metaRefKey = sub[1]
@@ -172,9 +182,10 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi
}
}
- // Number of parameters must be equal
+ // Number of assignment parameters must be equal to the number of non-meta trial parameters
+ // i.e. all parameters in ParameterAssignment must be in TrialParameters
if len(assignments) != nonMetaParamCount {
- return "", fmt.Errorf("Number of TrialAssignment: %v != number of nonMetaTrialParameters in TrialSpec: %v", len(assignments), nonMetaParamCount)
+ return "", fmt.Errorf("%w: parameter assignments: %v, non-meta trial parameter count: %v", errParamNotFoundInTrialParameters, assignments, nonMetaParamCount)
}
// Replacing placeholders with parameter values
@@ -194,7 +205,7 @@ func (g *DefaultGenerator) GetTrialTemplate(instance *experimentsv1beta1.Experim
if trialSource.TrialSpec != nil {
trialTemplateString, err = util.ConvertUnstructuredToString(trialSource.TrialSpec)
if err != nil {
- return "", fmt.Errorf("ConvertUnstructuredToString failed: %v", err)
+ return "", fmt.Errorf("%w: %w", errConvertUnstructuredToStringFailed, err)
}
} else {
configMapNS := trialSource.ConfigMap.ConfigMapNamespace
@@ -202,12 +213,12 @@ func (g *DefaultGenerator) GetTrialTemplate(instance *experimentsv1beta1.Experim
templatePath := trialSource.ConfigMap.TemplatePath
configMap, err := g.client.GetConfigMap(configMapName, configMapNS)
if err != nil {
- return "", fmt.Errorf("GetConfigMap failed: %v", err)
+ return "", fmt.Errorf("%w: %w", errConfigMapNotFound, err)
}
var ok bool
trialTemplateString, ok = configMap[templatePath]
if !ok {
- return "", fmt.Errorf("TemplatePath: %v not found in configMap: %v", templatePath, configMap)
+ return "", fmt.Errorf("%w: TemplatePath: %v, ConfigMap: %v", errTrialTemplateNotFound, templatePath, configMap)
}
}
diff --git a/pkg/controller.v1beta1/experiment/manifest/generator_test.go b/pkg/controller.v1beta1/experiment/manifest/generator_test.go
index dabd2631063..57e2789a2ac 100644
--- a/pkg/controller.v1beta1/experiment/manifest/generator_test.go
+++ b/pkg/controller.v1beta1/experiment/manifest/generator_test.go
@@ -17,11 +17,11 @@ limitations under the License.
package manifest
import (
- "errors"
"math"
- "reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"go.uber.org/mock/gomock"
batchv1 "k8s.io/api/batch/v1"
v1 "k8s.io/api/core/v1"
@@ -88,23 +88,18 @@ func TestGetRunSpecWithHP(t *testing.T) {
t.Errorf("ConvertObjectToUnstructured failed: %v", err)
}
- tcs := []struct {
- instance *experimentsv1beta1.Experiment
- parameterAssignments []commonapiv1beta1.ParameterAssignment
- expectedRunSpec *unstructured.Unstructured
- err bool
- testDescription string
+ cases := map[string]struct {
+ instance *experimentsv1beta1.Experiment
+ parameterAssignments []commonapiv1beta1.ParameterAssignment
+ wantRunSpecWithHyperParameters *unstructured.Unstructured
+ wantError error
}{
- // Valid run
- {
- instance: newFakeInstance(),
- parameterAssignments: newFakeParameterAssignment(),
- expectedRunSpec: expectedRunSpec,
- err: false,
- testDescription: "Run with valid parameters",
+ "Run with valid parameters": {
+ instance: newFakeInstance(),
+ parameterAssignments: newFakeParameterAssignment(),
+ wantRunSpecWithHyperParameters: expectedRunSpec,
},
- // Invalid JSON in unstructured
- {
+ "Invalid JSON in Unstructured Trial template": {
instance: func() *experimentsv1beta1.Experiment {
i := newFakeInstance()
trialSpec := i.Spec.TrialTemplate.TrialSource.TrialSpec
@@ -114,48 +109,45 @@ func TestGetRunSpecWithHP(t *testing.T) {
return i
}(),
parameterAssignments: newFakeParameterAssignment(),
- err: true,
- testDescription: "Invalid JSON in Trial template",
+ wantError: errConvertUnstructuredToStringFailed,
},
- // len(parameterAssignment) != len(trialParameters)
- {
+ "Non-meta parameter from TrialParameters not found in ParameterAssignment": {
instance: newFakeInstance(),
parameterAssignments: func() []commonapiv1beta1.ParameterAssignment {
pa := newFakeParameterAssignment()
- pa = pa[1:]
+ pa[0] = commonapiv1beta1.ParameterAssignment{
+ Name: "invalid-name",
+ Value: "invalid-value",
+ }
return pa
}(),
- err: true,
- testDescription: "Number of parameter assignments is not equal to number of Trial parameters",
+ wantError: errParamNotFoundInParameterAssignment,
},
- // Parameter from assignments not found in Trial parameters
- {
+ // case in which the lengths of trial parameters and parameter assignments are different
+ "Parameter from ParameterAssignment not found in TrialParameters": {
instance: newFakeInstance(),
parameterAssignments: func() []commonapiv1beta1.ParameterAssignment {
pa := newFakeParameterAssignment()
- pa[0] = commonapiv1beta1.ParameterAssignment{
- Name: "invalid-name",
- Value: "invalid-value",
- }
+ pa = append(pa, commonapiv1beta1.ParameterAssignment{
+ Name: "extra-name",
+ Value: "extra-value",
+ })
return pa
}(),
- err: true,
- testDescription: "Trial parameters don't have parameter from assignments",
+ wantError: errParamNotFoundInTrialParameters,
},
}
- for _, tc := range tcs {
- actualRunSpec, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments)
-
- if tc.err && err == nil {
- t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription)
- } else if !tc.err {
- if err != nil {
- t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err)
- } else if !reflect.DeepEqual(tc.expectedRunSpec, actualRunSpec) {
- t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, tc.expectedRunSpec.Object, actualRunSpec.Object)
+ for name, tc := range cases {
+ t.Run(name, func(t *testing.T) {
+ got, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments)
+ if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
+ t.Errorf("Unexpected error from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff)
}
- }
+ if diff := cmp.Diff(tc.wantRunSpecWithHyperParameters, got); len(diff) != 0 {
+ t.Errorf("Unexpected run spec from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff)
+ }
+ })
}
}
@@ -204,25 +196,6 @@ spec:
- --momentum=${trialParameters.momentum}
- --invalidParameter={'num_layers': 2, 'input_sizes': [32, 32, 3]}`
- validGetConfigMap1 := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return(
- map[string]string{templatePath: trialSpec}, nil)
-
- invalidConfigMapName := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return(
- nil, errors.New("Unable to get ConfigMap"))
-
- validGetConfigMap3 := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return(
- map[string]string{templatePath: trialSpec}, nil)
-
- invalidTemplate := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return(
- map[string]string{templatePath: invalidTrialSpec}, nil)
-
- gomock.InOrder(
- validGetConfigMap1,
- invalidConfigMapName,
- validGetConfigMap3,
- invalidTemplate,
- )
-
// We can't compare structures, because in ConfigMap trialSpec is a string and creationTimestamp was not added
expectedStr := `apiVersion: batch/v1
kind: Job
@@ -244,19 +217,23 @@ spec:
- "--momentum=0.9"`
expectedRunSpec, err := util.ConvertStringToUnstructured(expectedStr)
- if err != nil {
- t.Errorf("ConvertStringToUnstructured failed: %v", err)
+ if diff := cmp.Diff(nil, err, cmpopts.EquateErrors()); len(diff) != 0 {
+ t.Errorf("ConvertStringToUnstructured failed (-want,+got):\n%s", diff)
}
- tcs := []struct {
- instance *experimentsv1beta1.Experiment
- parameterAssignments []commonapiv1beta1.ParameterAssignment
- err bool
- testDescription string
+ cases := map[string]struct {
+ mockConfigMapGetter func() *gomock.Call
+ instance *experimentsv1beta1.Experiment
+ parameterAssignments []commonapiv1beta1.ParameterAssignment
+ wantRunSpecWithHyperParameters *unstructured.Unstructured
+ wantError error
}{
- // Valid run
- // validGetConfigMap1 case
- {
+ "Run with valid parameters": {
+ mockConfigMapGetter: func() *gomock.Call {
+ return c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return(
+ map[string]string{templatePath: trialSpec}, nil,
+ )
+ },
instance: func() *experimentsv1beta1.Experiment {
i := newFakeInstance()
i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{
@@ -268,13 +245,15 @@ spec:
}
return i
}(),
- parameterAssignments: newFakeParameterAssignment(),
- err: false,
- testDescription: "Run with valid parameters",
+ parameterAssignments: newFakeParameterAssignment(),
+ wantRunSpecWithHyperParameters: expectedRunSpec,
},
- // Invalid ConfigMap name
- // invalidConfigMapName case
- {
+ "Invalid ConfigMap name": {
+ mockConfigMapGetter: func() *gomock.Call {
+ return c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return(
+ nil, errConfigMapNotFound,
+ )
+ },
instance: func() *experimentsv1beta1.Experiment {
i := newFakeInstance()
i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{
@@ -285,12 +264,14 @@ spec:
return i
}(),
parameterAssignments: newFakeParameterAssignment(),
- err: true,
- testDescription: "Invalid ConfigMap name",
+ wantError: errConfigMapNotFound,
},
- // Invalid template path in ConfigMap name
- // validGetConfigMap3 case
- {
+ "Invalid template path in ConfigMap name": {
+ mockConfigMapGetter: func() *gomock.Call {
+ return c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return(
+ map[string]string{templatePath: trialSpec}, nil,
+ )
+ },
instance: func() *experimentsv1beta1.Experiment {
i := newFakeInstance()
i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{
@@ -303,14 +284,16 @@ spec:
return i
}(),
parameterAssignments: newFakeParameterAssignment(),
- err: true,
- testDescription: "Invalid template path in ConfigMap",
+ wantError: errTrialTemplateNotFound,
},
- // Invalid Trial template spec in ConfigMap
// Trial template is a string in ConfigMap
// Because of that, user can specify not valid unstructured template
- // invalidTemplate case
- {
+ "Invalid trial spec in ConfigMap": {
+ mockConfigMapGetter: func() *gomock.Call {
+ return c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return(
+ map[string]string{templatePath: invalidTrialSpec}, nil,
+ )
+ },
instance: func() *experimentsv1beta1.Experiment {
i := newFakeInstance()
i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{
@@ -323,22 +306,21 @@ spec:
return i
}(),
parameterAssignments: newFakeParameterAssignment(),
- err: true,
- testDescription: "Invalid Trial spec in ConfigMap",
+ wantError: errConvertStringToUnstructuredFailed,
},
}
- for _, tc := range tcs {
- actualRunSpec, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments)
- if tc.err && err == nil {
- t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription)
- } else if !tc.err {
- if err != nil {
- t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err)
- } else if !reflect.DeepEqual(expectedRunSpec, actualRunSpec) {
- t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, expectedRunSpec.Object, actualRunSpec.Object)
+ for name, tc := range cases {
+ t.Run(name, func(t *testing.T) {
+ tc.mockConfigMapGetter()
+ got, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments)
+ if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
+ t.Errorf("Unexpected error from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff)
+ }
+ if diff := cmp.Diff(tc.wantRunSpecWithHyperParameters, got); len(diff) != 0 {
+ t.Errorf("Unexpected run spec from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff)
}
- }
+ })
}
}
diff --git a/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient.go b/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient.go
index 8db6f3b82f3..b77949ac9e5 100644
--- a/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient.go
+++ b/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient.go
@@ -532,13 +532,38 @@ func convertParameterType(typ experimentsv1beta1.ParameterType) suggestionapi.Pa
}
func convertFeasibleSpace(fs experimentsv1beta1.FeasibleSpace) *suggestionapi.FeasibleSpace {
- res := &suggestionapi.FeasibleSpace{
- Max: fs.Max,
- Min: fs.Min,
- List: fs.List,
- Step: fs.Step,
+ distribution := convertDistribution(fs.Distribution)
+ if distribution == suggestionapi.Distribution_DISTRIBUTION_UNKNOWN {
+ return &suggestionapi.FeasibleSpace{
+ Max: fs.Max,
+ Min: fs.Min,
+ List: fs.List,
+ Step: fs.Step,
+ }
+ }
+
+ return &suggestionapi.FeasibleSpace{
+ Max: fs.Max,
+ Min: fs.Min,
+ List: fs.List,
+ Step: fs.Step,
+ Distribution: distribution,
+ }
+}
+
+func convertDistribution(typ experimentsv1beta1.Distribution) suggestionapi.Distribution {
+ switch typ {
+ case experimentsv1beta1.DistributionUniform:
+ return suggestionapi.Distribution_UNIFORM
+ case experimentsv1beta1.DistributionLogUniform:
+ return suggestionapi.Distribution_LOG_UNIFORM
+ case experimentsv1beta1.DistributionNormal:
+ return suggestionapi.Distribution_NORMAL
+ case experimentsv1beta1.DistributionLogNormal:
+ return suggestionapi.Distribution_LOG_NORMAL
+ default:
+ return suggestionapi.Distribution_DISTRIBUTION_UNKNOWN
}
- return res
}
func convertComparison(comparison suggestionapi.ComparisonType) commonapiv1beta1.ComparisonType {
diff --git a/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient_test.go b/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient_test.go
index b51e9f6cb29..c4df08bbfaa 100644
--- a/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient_test.go
+++ b/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient_test.go
@@ -23,6 +23,8 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"github.com/onsi/gomega"
"go.uber.org/mock/gomock"
"google.golang.org/grpc"
@@ -539,6 +541,96 @@ func TestConvertParameterType(t *testing.T) {
}
}
+func TestConvertFeasibleSpace(t *testing.T) {
+
+ tcs := []struct {
+ inFeasibleSpace experimentsv1beta1.FeasibleSpace
+ expectedFeasibleSpace *suggestionapi.FeasibleSpace
+ testDescription string
+ }{
+ {
+ inFeasibleSpace: experimentsv1beta1.FeasibleSpace{
+ Max: "10",
+ Min: "1",
+ List: []string{"1", "2", "3"},
+ Step: "1",
+ Distribution: experimentsv1beta1.DistributionUnknown,
+ },
+ expectedFeasibleSpace: &suggestionapi.FeasibleSpace{
+ Max: "10",
+ Min: "1",
+ List: []string{"1", "2", "3"},
+ Step: "1",
+ },
+ testDescription: "Convert feasible space with unknown distribution",
+ },
+ {
+ inFeasibleSpace: experimentsv1beta1.FeasibleSpace{
+ Max: "100",
+ Min: "10",
+ Step: "10",
+ Distribution: experimentsv1beta1.DistributionUniform,
+ },
+ expectedFeasibleSpace: &suggestionapi.FeasibleSpace{
+ Max: "100",
+ Min: "10",
+ Step: "10",
+ Distribution: suggestionapi.Distribution_UNIFORM,
+ },
+ testDescription: "Convert feasible space with uniform distribution",
+ },
+ }
+
+ for _, tc := range tcs {
+ actualFeasibleSpace := convertFeasibleSpace(tc.inFeasibleSpace)
+ if diff := cmp.Diff(tc.expectedFeasibleSpace, actualFeasibleSpace, cmpopts.IgnoreUnexported(suggestionapi.FeasibleSpace{})); diff != "" {
+ t.Errorf("Case: %v failed. Unexpected difference (-want +got):\n%s", tc.testDescription, diff)
+ }
+ }
+}
+
+func TestConvertDistribution(t *testing.T) {
+
+ tcs := []struct {
+ inDistribution experimentsv1beta1.Distribution
+ expectedDistribution suggestionapi.Distribution
+ testDescription string
+ }{
+ {
+ inDistribution: experimentsv1beta1.DistributionUniform,
+ expectedDistribution: suggestionapi.Distribution_UNIFORM,
+ testDescription: "Convert uniform distribution",
+ },
+ {
+ inDistribution: experimentsv1beta1.DistributionLogUniform,
+ expectedDistribution: suggestionapi.Distribution_LOG_UNIFORM,
+ testDescription: "Convert log-uniform distribution",
+ },
+ {
+ inDistribution: experimentsv1beta1.DistributionNormal,
+ expectedDistribution: suggestionapi.Distribution_NORMAL,
+ testDescription: "Convert normal distribution",
+ },
+ {
+ inDistribution: experimentsv1beta1.DistributionLogNormal,
+ expectedDistribution: suggestionapi.Distribution_LOG_NORMAL,
+ testDescription: "Convert log-normal distribution",
+ },
+ {
+ inDistribution: experimentsv1beta1.DistributionUnknown,
+ expectedDistribution: suggestionapi.Distribution_DISTRIBUTION_UNKNOWN,
+ testDescription: "Convert unknown distribution",
+ },
+ }
+
+ for _, tc := range tcs {
+ actualDistribution := convertDistribution(tc.inDistribution)
+ if actualDistribution != tc.expectedDistribution {
+ t.Errorf("Case: %v failed. Expected distribution %v, got %v", tc.testDescription, tc.expectedDistribution, actualDistribution)
+ }
+ }
+}
+
func TestConvertTrialObservation(t *testing.T) {
tcs := []struct {
diff --git a/pkg/controller.v1beta1/trial/util/job_util_test.go b/pkg/controller.v1beta1/trial/util/job_util_test.go
index d2a018967c3..c1908144df9 100644
--- a/pkg/controller.v1beta1/trial/util/job_util_test.go
+++ b/pkg/controller.v1beta1/trial/util/job_util_test.go
@@ -17,7 +17,6 @@ limitations under the License.
package util
import (
- "reflect"
"testing"
batchv1 "k8s.io/api/batch/v1"
@@ -25,6 +24,8 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
trialsv1beta1 "github.com/kubeflow/katib/pkg/apis/controller/trials/v1beta1"
"github.com/kubeflow/katib/pkg/controller.v1beta1/util"
)
@@ -39,14 +40,13 @@ func TestGetDeployedJobStatus(t *testing.T) {
successCondition := "status.conditions.#(type==\"Complete\")#|#(status==\"True\")#"
failureCondition := "status.conditions.#(type==\"Failed\")#|#(status==\"True\")#"
- tcs := []struct {
- trial *trialsv1beta1.Trial
- deployedJob *unstructured.Unstructured
- expectedTrialJobStatus *TrialJobStatus
- err bool
- testDescription string
+ cases := map[string]struct {
+ trial *trialsv1beta1.Trial
+ deployedJob *unstructured.Unstructured
+ wantTrialJobStatus *TrialJobStatus
+ wantError error
}{
- {
+ "Job status is running": {
trial: newFakeTrial(successCondition, failureCondition),
deployedJob: func() *unstructured.Unstructured {
job := newFakeJob()
@@ -54,28 +54,24 @@ func TestGetDeployedJobStatus(t *testing.T) {
job.Status.Conditions[1].Status = corev1.ConditionFalse
return newFakeDeployedJob(job)
}(),
- expectedTrialJobStatus: func() *TrialJobStatus {
+ wantTrialJobStatus: func() *TrialJobStatus {
return &TrialJobStatus{
Condition: JobRunning,
}
}(),
- err: false,
- testDescription: "Job status is running",
},
- {
+ "Job status is succeeded, reason and message must be returned": {
trial: newFakeTrial(successCondition, failureCondition),
deployedJob: newFakeDeployedJob(newFakeJob()),
- expectedTrialJobStatus: func() *TrialJobStatus {
+ wantTrialJobStatus: func() *TrialJobStatus {
return &TrialJobStatus{
Condition: JobSucceeded,
Message: testMessage,
Reason: testReason,
}
}(),
- err: false,
- testDescription: "Job status is succeeded, reason and message must be returned",
},
- {
+ "Job status is failed, reason and message must be returned": {
trial: newFakeTrial(successCondition, failureCondition),
deployedJob: func() *unstructured.Unstructured {
job := newFakeJob()
@@ -83,41 +79,35 @@ func TestGetDeployedJobStatus(t *testing.T) {
job.Status.Conditions[1].Status = corev1.ConditionFalse
return newFakeDeployedJob(job)
}(),
- expectedTrialJobStatus: func() *TrialJobStatus {
+ wantTrialJobStatus: func() *TrialJobStatus {
return &TrialJobStatus{
Condition: JobFailed,
Message: testMessage,
Reason: testReason,
}
}(),
- err: false,
- testDescription: "Job status is failed, reason and message must be returned",
},
- {
+ "Job status is succeeded because status.succeeded = 1": {
trial: newFakeTrial("status.[@this].#(succeeded==1)", failureCondition),
deployedJob: newFakeDeployedJob(newFakeJob()),
- expectedTrialJobStatus: func() *TrialJobStatus {
+ wantTrialJobStatus: func() *TrialJobStatus {
return &TrialJobStatus{
Condition: JobSucceeded,
}
}(),
- err: false,
- testDescription: "Job status is succeeded because status.succeeded = 1",
},
}
- for _, tc := range tcs {
- actualTrialJobStatus, err := GetDeployedJobStatus(tc.trial, tc.deployedJob)
-
- if tc.err && err == nil {
- t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription)
- } else if !tc.err {
- if err != nil {
- t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err)
- } else if !reflect.DeepEqual(tc.expectedTrialJobStatus, actualTrialJobStatus) {
- t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, tc.expectedTrialJobStatus, actualTrialJobStatus)
+ for name, tc := range cases {
+ t.Run(name, func(t *testing.T) {
+ got, err := GetDeployedJobStatus(tc.trial, tc.deployedJob)
+ if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
+ t.Errorf("Unexpected error from GetDeployedJobStatus() (-want,+got):\n%s", diff)
+ }
+ if diff := cmp.Diff(tc.wantTrialJobStatus, got); len(diff) != 0 {
+ t.Errorf("Unexpected trial job status from GetDeployedJobStatus() (-want,+got):\n%s", diff)
}
- }
+ })
}
}
@@ -154,6 +144,7 @@ func newFakeJob() *batchv1.Job {
},
}
}
+
func newFakeDeployedJob(job interface{}) *unstructured.Unstructured {
jobUnstructured, _ := util.ConvertObjectToUnstructured(job)
diff --git a/pkg/earlystopping/v1beta1/medianstop/service.py b/pkg/earlystopping/v1beta1/medianstop/service.py
index 2e4d02acfc2..94c9fdc6bda 100644
--- a/pkg/earlystopping/v1beta1/medianstop/service.py
+++ b/pkg/earlystopping/v1beta1/medianstop/service.py
@@ -12,17 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from datetime import datetime
import logging
import multiprocessing
+from datetime import datetime
from typing import Iterable, Optional
import grpc
-from kubernetes import client
-from kubernetes import config
+from kubernetes import client, config
-from pkg.apis.manager.v1beta1.python import api_pb2
-from pkg.apis.manager.v1beta1.python import api_pb2_grpc
+from pkg.apis.manager.v1beta1.python import api_pb2, api_pb2_grpc
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
@@ -39,7 +37,6 @@
class MedianStopService(api_pb2_grpc.EarlyStoppingServicer):
-
def __init__(self):
super(MedianStopService, self).__init__()
self.is_first_run = True
@@ -52,22 +49,30 @@ def __init__(self):
# Assume that Trial namespace = Suggestion namespace.
try:
- with open('/var/run/secrets/kubernetes.io/serviceaccount/namespace', 'r') as f:
+ with open(
+ "/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r"
+ ) as f:
self.namespace = f.readline()
# Set config and api instance for k8s client.
config.load_incluster_config()
# This is used when service is not running in k8s, e.g. for unit tests.
except Exception as e:
- logger.info("{}. Service is not running in Kubernetes Pod, \"{}\" namespace is used".format(
- e, DEFAULT_NAMESPACE
- ))
+ logger.info(
+ '{}. Service is not running in Kubernetes Pod, "{}" namespace is used'.format(
+ e, DEFAULT_NAMESPACE
+ )
+ )
self.namespace = DEFAULT_NAMESPACE
# Set config and api instance for k8s client.
config.load_kube_config()
self.api_instance = client.CustomObjectsApi()
- def ValidateEarlyStoppingSettings(self, request: api_pb2.ValidateEarlyStoppingSettingsRequest, context: grpc.ServicerContext) -> api_pb2.ValidateEarlyStoppingSettingsReply:
+ def ValidateEarlyStoppingSettings(
+ self,
+ request: api_pb2.ValidateEarlyStoppingSettingsRequest,
+ context: grpc.ServicerContext,
+ ) -> api_pb2.ValidateEarlyStoppingSettingsReply:
is_valid, message = self.validate_early_stopping_spec(request.early_stopping)
if not is_valid:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
@@ -78,7 +83,9 @@ def ValidateEarlyStoppingSettings(self, request: api_pb2.ValidateEarlyStoppingSe
def validate_early_stopping_spec(self, early_stopping_spec):
algorithm_name = early_stopping_spec.algorithm_name
if algorithm_name == "medianstop":
- return self.validate_medianstop_setting(early_stopping_spec.algorithm_settings)
+ return self.validate_medianstop_setting(
+ early_stopping_spec.algorithm_settings
+ )
else:
return False, "unknown algorithm name {}".format(algorithm_name)
@@ -88,59 +95,90 @@ def validate_medianstop_setting(early_stopping_settings):
try:
if setting.name == "min_trials_required":
if not (int(setting.value) > 0):
- return False, "min_trials_required must be greater than zero (>0)"
+ return (
+ False,
+ "min_trials_required must be greater than zero (>0)",
+ )
elif setting.name == "start_step":
if not (int(setting.value) >= 1):
- return False, "start_step must be greater or equal than one (>=1)"
+ return (
+ False,
+ "start_step must be greater or equal than one (>=1)",
+ )
else:
- return False, "unknown setting {} for algorithm medianstop".format(setting.name)
+ return False, "unknown setting {} for algorithm medianstop".format(
+ setting.name
+ )
except Exception as e:
- return False, "failed to validate {}({}): {}".format(setting.name, setting.value, e)
+ return False, "failed to validate {}({}): {}".format(
+ setting.name, setting.value, e
+ )
return True, ""
- def GetEarlyStoppingRules(self, request: api_pb2.GetEarlyStoppingRulesRequest, context: grpc.ServicerContext) -> api_pb2.GetSuggestionsReply:
+ def GetEarlyStoppingRules(
+ self,
+ request: api_pb2.GetEarlyStoppingRulesRequest,
+ context: grpc.ServicerContext,
+ ) -> api_pb2.GetSuggestionsReply:
logger.info("Get new early stopping rules")
# Get required values for the first call.
if self.is_first_run:
self.is_first_run = False
# Get early stopping settings.
- self.get_early_stopping_settings(request.experiment.spec.early_stopping.algorithm_settings)
- logger.info("Median stopping settings are: min_trials_required: {}, start_step: {}".format(
- self.min_trials_required, self.start_step))
+ self.get_early_stopping_settings(
+ request.experiment.spec.early_stopping.algorithm_settings
+ )
+ logger.info(
+ "Median stopping settings are: min_trials_required: {}, start_step: {}".format(
+ self.min_trials_required, self.start_step
+ )
+ )
# Get comparison type and objective metric
if request.experiment.spec.objective.type == api_pb2.MAXIMIZE:
self.comparison = api_pb2.LESS
else:
self.comparison = api_pb2.GREATER
- self.objective_metric = request.experiment.spec.objective.objective_metric_name
+ self.objective_metric = (
+ request.experiment.spec.objective.objective_metric_name
+ )
# Get DB manager address. It should have host and port.
# For example: katib-db-manager.kubeflow:6789 - default one.
- self.db_manager_address = request.db_manager_address.split(':')
+ self.db_manager_address = request.db_manager_address.split(":")
if len(self.db_manager_address) != 2:
- raise Exception("Invalid Katib DB manager service address: {}".format(self.db_manager_address))
+ raise Exception(
+ "Invalid Katib DB manager service address: {}".format(
+ self.db_manager_address
+ )
+ )
early_stopping_rules = []
median = self.get_median_value(request.trials)
if median is not None:
- early_stopping_rules.append(api_pb2.EarlyStoppingRule(
- name=self.objective_metric,
- value=str(median),
- comparison=self.comparison,
- start_step=self.start_step,
- ))
-
- logger.info("New early stopping rules are:\n {}\n\n".format(early_stopping_rules))
+ early_stopping_rules.append(
+ api_pb2.EarlyStoppingRule(
+ name=self.objective_metric,
+ value=str(median),
+ comparison=self.comparison,
+ start_step=self.start_step,
+ )
+ )
+
+ logger.info(
+ "New early stopping rules are:\n {}\n\n".format(early_stopping_rules)
+ )
return api_pb2.GetEarlyStoppingRulesReply(
early_stopping_rules=early_stopping_rules
)
- def get_early_stopping_settings(self, early_stopping_settings: Iterable[api_pb2.EarlyStoppingSetting]):
+ def get_early_stopping_settings(
+ self, early_stopping_settings: Iterable[api_pb2.EarlyStoppingSetting]
+ ):
for setting in early_stopping_settings:
if setting.name == "min_trials_required":
self.min_trials_required = int(setting.value)
@@ -168,8 +206,11 @@ def get_median_value(self, trials: Iterable[api_pb2.Trial]) -> Optional[float]:
)
# Get only first start_step metrics.
- # Since metrics are collected consistently and ordered by time, we slice top start_step metrics.
- first_x_logs = get_log_response.observation_log.metric_logs[:self.start_step]
+ # Since metrics are collected consistently and ordered by time,
+ # we slice top start_step metrics.
+ first_x_logs = get_log_response.observation_log.metric_logs[
+ : self.start_step
+ ]
metric_sum = 0
for log in first_x_logs:
metric_sum += float(log.metric.value)
@@ -177,22 +218,33 @@ def get_median_value(self, trials: Iterable[api_pb2.Trial]) -> Optional[float]:
# Get average metric value for the Trial.
new_average = metric_sum / len(first_x_logs)
self.trials_avg_history[trial.name] = new_average
- logger.info("Adding new succeeded Trial: {} with average metrics value: {}".format(
- trial.name, new_average))
- logger.info("Trials average log history: {}".format(self.trials_avg_history))
+ logger.info(
+ "Adding new succeeded Trial: {} with average metrics value: {}".format(
+ trial.name, new_average
+ )
+ )
+ logger.info(
+ "Trials average log history: {}".format(self.trials_avg_history)
+ )
# If count of succeeded Trials is greater than min_trials_required, calculate median.
if len(self.trials_avg_history) >= self.min_trials_required:
- median = sum(list(self.trials_avg_history.values())) / len(self.trials_avg_history)
+ median = sum(list(self.trials_avg_history.values())) / len(
+ self.trials_avg_history
+ )
logger.info("Generate new Median value: {}".format(median))
return median
# Else, return None.
- logger.info("Count of succeeded Trials: {} is less than min_trials_required: {}".format(
- len(self.trials_avg_history), self.min_trials_required
- ))
+ logger.info(
+ "Count of succeeded Trials: {} is less than min_trials_required: {}".format(
+ len(self.trials_avg_history), self.min_trials_required
+ )
+ )
return None
- def SetTrialStatus(self, request: api_pb2.SetTrialStatusRequest, context: grpc.ServicerContext) -> api_pb2.SetTrialStatusReply:
+ def SetTrialStatus(
+ self, request: api_pb2.SetTrialStatusRequest, context: grpc.ServicerContext
+ ) -> api_pb2.SetTrialStatusReply:
trial_name = request.trial_name
logger.info("Update status for Trial: {}".format(trial_name))
@@ -205,7 +257,8 @@ def SetTrialStatus(self, request: api_pb2.SetTrialStatusRequest, context: grpc.S
self.namespace,
TRIAL_PLURAL,
trial_name,
- async_req=True)
+ async_req=True,
+ )
trial = None
try:
@@ -214,7 +267,10 @@ def SetTrialStatus(self, request: api_pb2.SetTrialStatusRequest, context: grpc.S
raise Exception("Timeout trying to get Katib Trial")
except Exception as e:
raise Exception(
- "Get Trial: {} in namespace: {} failed. Exception: {}".format(trial_name, self.namespace, e))
+ "Get Trial: {} in namespace: {} failed. Exception: {}".format(
+ trial_name, self.namespace, e
+ )
+ )
time_now = datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ")
@@ -237,13 +293,19 @@ def SetTrialStatus(self, request: api_pb2.SetTrialStatusRequest, context: grpc.S
TRIAL_PLURAL,
trial_name,
trial,
- async_req=True)
+ async_req=True,
+ )
except Exception as e:
raise Exception(
"Update status for Trial: {} in namespace: {} failed. Exception: {}".format(
- trial_name, self.namespace, e))
-
- logger.info("Changed status to: {} for Trial: {} in namespace: {}\n\n".format(
- STATUS_EARLY_STOPPED, trial_name, self.namespace))
+ trial_name, self.namespace, e
+ )
+ )
+
+ logger.info(
+ "Changed status to: {} for Trial: {} in namespace: {}\n\n".format(
+ STATUS_EARLY_STOPPED, trial_name, self.namespace
+ )
+ )
return api_pb2.SetTrialStatusReply()
diff --git a/pkg/metricscollector/v1beta1/common/pns.py b/pkg/metricscollector/v1beta1/common/pns.py
index 86f5563bbec..24f0882b72b 100644
--- a/pkg/metricscollector/v1beta1/common/pns.py
+++ b/pkg/metricscollector/v1beta1/common/pns.py
@@ -25,17 +25,20 @@ def WaitMainProcesses(pool_interval, timout, wait_all, completed_marked_dir):
Hold metrics collector parser until required pids are finished
"""
- if not sys.platform.startswith('linux'):
+ if not sys.platform.startswith("linux"):
raise Exception("Platform '{}' unsupported".format(sys.platform))
pids, main_pid = GetMainProcesses(completed_marked_dir)
- return WaitPIDs(pids, main_pid, pool_interval, timout, wait_all, completed_marked_dir)
+ return WaitPIDs(
+ pids, main_pid, pool_interval, timout, wait_all, completed_marked_dir
+ )
def GetMainProcesses(completed_marked_dir):
"""
- Return array with all running processes pids and main process pid which metrics collector is waiting.
+ Return array with all running processes pids
+ and main process pid which metrics collector is waiting.
"""
pids = set()
main_pid = 0
@@ -59,7 +62,10 @@ def GetMainProcesses(completed_marked_dir):
# In addition to that, command line contains completed marker for the main pid.
# For example: echo completed > /var/log/katib/$$$$.pid
# completed_marked_dir is the directory for completed marker, e.g. /var/log/katib
- if main_pid == 0 or ("echo {} > {}".format(const.TRAINING_COMPLETED, completed_marked_dir) in cmd_lind):
+ if main_pid == 0 or (
+ "echo {} > {}".format(const.TRAINING_COMPLETED, completed_marked_dir)
+ in cmd_lind
+ ):
main_pid = pid
pids.add(pid)
@@ -92,16 +98,25 @@ def WaitPIDs(pids, main_pid, pool_interval, timout, wait_all, completed_marked_d
path = "/proc/{}".format(pid)
if not os.path.exists(path):
if pid == main_pid:
- # For main_pid we check if file with "completed" marker exists if completed_marked_dir is set
+ # For main_pid we check if file with "completed"
+ # marker exists if completed_marked_dir is set
if completed_marked_dir:
- mark_file = os.path.join(completed_marked_dir, "{}.pid".format(pid))
+ mark_file = os.path.join(
+ completed_marked_dir, "{}.pid".format(pid)
+ )
# Check if file contains "completed" marker
with open(mark_file) as file_obj:
contents = file_obj.read()
if contents.strip() != const.TRAINING_COMPLETED:
raise Exception(
- "Unable to find marker: {} in file: {} with contents: {} for pid: {}".format(
- const.TRAINING_COMPLETED, mark_file, contents, pid))
+ "Unable to find marker: {} in file: {} with contents: {} "
+ "for pid: {}".format(
+ const.TRAINING_COMPLETED,
+ mark_file,
+ contents,
+ pid,
+ )
+ )
# Add main pid to finished pids set
finished_pids.add(pid)
# Exit loop if wait all is false because main pid is finished
diff --git a/pkg/metricscollector/v1beta1/tfevent-metricscollector/tfevent_loader.py b/pkg/metricscollector/v1beta1/tfevent-metricscollector/tfevent_loader.py
index 931014d1c15..f41597f9237 100644
--- a/pkg/metricscollector/v1beta1/tfevent-metricscollector/tfevent_loader.py
+++ b/pkg/metricscollector/v1beta1/tfevent-metricscollector/tfevent_loader.py
@@ -13,7 +13,8 @@
# limitations under the License.
# TFEventFileParser parses tfevent files and returns an ObservationLog of the metrics specified.
-# When the event file is under a directory(e.g. test dir), please specify "{{dirname}}/{{metrics name}}"
+# When the event file is under a directory(e.g. test dir), please specify
+# "{{dirname}}/{{metrics name}}"
# For example, in the Tensorflow MNIST Classification With Summaries:
# https://github.com/kubeflow/katib/blob/master/examples/v1beta1/trial-images/tf-mnist-with-summaries/mnist.py.
# The "accuracy" and "loss" metric is saved under "train" and "test" directories.
@@ -21,19 +22,15 @@
# Check TFJob example for more information:
# https://github.com/kubeflow/katib/blob/master/examples/v1beta1/kubeflow-training-operator/tfjob-mnist-with-summaries.yaml#L16-L22
-from datetime import datetime
-from logging import getLogger
-from logging import INFO
-from logging import StreamHandler
import os
+from datetime import datetime
+from logging import INFO, StreamHandler, getLogger
import api_pb2
import rfc3339
-from tensorboard.backend.event_processing.event_accumulator import \
- EventAccumulator
-from tensorboard.backend.event_processing.event_accumulator import TensorEvent
-from tensorboard.backend.event_processing.tag_types import TENSORS
import tensorflow as tf
+from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+from tensorboard.backend.event_processing.tag_types import TENSORS
from pkg.metricscollector.v1beta1.common import const
@@ -54,19 +51,25 @@ def parse_summary(self, tfefile):
event_accumulator.Reload()
for tag in event_accumulator.Tags()[TENSORS]:
for m in self.metric_names:
-
- tfefile_parent_dir = os.path.dirname(m) if len(m.split("/")) >= 2 else os.path.dirname(tfefile)
+ tfefile_parent_dir = (
+ os.path.dirname(m)
+ if len(m.split("/")) >= 2
+ else os.path.dirname(tfefile)
+ )
basedir_name = os.path.dirname(tfefile)
- if not tag.startswith(m.split("/")[-1]) or not basedir_name.endswith(tfefile_parent_dir):
+ if not tag.startswith(m.split("/")[-1]) or not basedir_name.endswith(
+ tfefile_parent_dir
+ ):
continue
for tensor in event_accumulator.Tensors(tag):
ml = api_pb2.MetricLog(
- time_stamp=rfc3339.rfc3339(datetime.fromtimestamp(tensor.wall_time)),
+ time_stamp=rfc3339.rfc3339(
+ datetime.fromtimestamp(tensor.wall_time)
+ ),
metric=api_pb2.Metric(
- name=m,
- value=str(tf.make_ndarray(tensor.tensor_proto))
- )
+ name=m, value=str(tf.make_ndarray(tensor.tensor_proto))
+ ),
)
metric_logs.append(ml)
@@ -109,12 +112,14 @@ def parse_file(self, directory):
api_pb2.MetricLog(
time_stamp=rfc3339.rfc3339(datetime.now()),
metric=api_pb2.Metric(
- name=self.metrics[0],
- value=const.UNAVAILABLE_METRIC_VALUE
- )
+ name=self.metrics[0], value=const.UNAVAILABLE_METRIC_VALUE
+ ),
)
]
- self.logger.info("Objective metric {} is not found in training logs, {} value is reported".format(
- self.metrics[0], const.UNAVAILABLE_METRIC_VALUE))
+ self.logger.info(
+ "Objective metric {} is not found in training logs, {} value is reported".format(
+ self.metrics[0], const.UNAVAILABLE_METRIC_VALUE
+ )
+ )
return api_pb2.ObservationLog(metric_logs=mls)
diff --git a/pkg/suggestion/v1beta1/goptuna/converter_test.go b/pkg/suggestion/v1beta1/goptuna/converter_test.go
index 1e3189a773d..d92b0f391ca 100644
--- a/pkg/suggestion/v1beta1/goptuna/converter_test.go
+++ b/pkg/suggestion/v1beta1/goptuna/converter_test.go
@@ -17,48 +17,44 @@ limitations under the License.
package suggestion_goptuna_v1beta1
import (
- "reflect"
"testing"
"github.com/c-bata/goptuna"
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
api_v1_beta1 "github.com/kubeflow/katib/pkg/apis/manager/v1beta1"
)
func Test_toGoptunaDirection(t *testing.T) {
- for _, tt := range []struct {
- name string
+ for name, tc := range map[string]struct {
objectiveType api_v1_beta1.ObjectiveType
- expected goptuna.StudyDirection
+ wantDirection goptuna.StudyDirection
}{
- {
- name: "minimize",
+ "minimize": {
objectiveType: api_v1_beta1.ObjectiveType_MINIMIZE,
- expected: goptuna.StudyDirectionMinimize,
+ wantDirection: goptuna.StudyDirectionMinimize,
},
- {
- name: "maximize",
+ "maximize": {
objectiveType: api_v1_beta1.ObjectiveType_MAXIMIZE,
- expected: goptuna.StudyDirectionMaximize,
+ wantDirection: goptuna.StudyDirectionMaximize,
},
} {
- t.Run(tt.name, func(t *testing.T) {
- got := toGoptunaDirection(tt.objectiveType)
- if got != tt.expected {
- t.Errorf("toGoptunaDirection() got = %v, want %v", got, tt.expected)
+ t.Run(name, func(t *testing.T) {
+ got := toGoptunaDirection(tc.objectiveType)
+ if diff := cmp.Diff(tc.wantDirection, got); len(diff) != 0 {
+ t.Errorf("Unexpected direction from toGoptunaDirection (-want,+got):\n%s", diff)
}
})
}
}
func Test_toGoptunaSearchSpace(t *testing.T) {
- tests := []struct {
- name string
- parameters []*api_v1_beta1.ParameterSpec
- want map[string]interface{}
- wantErr bool
+ cases := map[string]struct {
+ parameters []*api_v1_beta1.ParameterSpec
+ wantSearchSpace map[string]interface{}
+ wantError error
}{
- {
- name: "Double parameter type",
+ "Double parameter type": {
parameters: []*api_v1_beta1.ParameterSpec{
{
Name: "param-double",
@@ -69,16 +65,14 @@ func Test_toGoptunaSearchSpace(t *testing.T) {
},
},
},
- want: map[string]interface{}{
+ wantSearchSpace: map[string]interface{}{
"param-double": goptuna.UniformDistribution{
High: 5.5,
Low: 1.5,
},
},
- wantErr: false,
},
- {
- name: "Double parameter type with step",
+ "Double parameter type with step": {
parameters: []*api_v1_beta1.ParameterSpec{
{
Name: "param-double",
@@ -90,17 +84,15 @@ func Test_toGoptunaSearchSpace(t *testing.T) {
},
},
},
- want: map[string]interface{}{
+ wantSearchSpace: map[string]interface{}{
"param-double": goptuna.DiscreteUniformDistribution{
High: 5.5,
Low: 1.5,
Q: 0.5,
},
},
- wantErr: false,
},
- {
- name: "Int parameter type",
+ "Int parameter type": {
parameters: []*api_v1_beta1.ParameterSpec{
{
Name: "param-int",
@@ -111,16 +103,14 @@ func Test_toGoptunaSearchSpace(t *testing.T) {
},
},
},
- want: map[string]interface{}{
+ wantSearchSpace: map[string]interface{}{
"param-int": goptuna.IntUniformDistribution{
High: 5,
Low: 1,
},
},
- wantErr: false,
},
- {
- name: "Int parameter type with step",
+ "Int parameter type with step": {
parameters: []*api_v1_beta1.ParameterSpec{
{
Name: "param-int",
@@ -132,17 +122,15 @@ func Test_toGoptunaSearchSpace(t *testing.T) {
},
},
},
- want: map[string]interface{}{
+ wantSearchSpace: map[string]interface{}{
"param-int": goptuna.StepIntUniformDistribution{
High: 5,
Low: 1,
Step: 2,
},
},
- wantErr: false,
},
- {
- name: "Discrete parameter type",
+ "Discrete parameter type": {
parameters: []*api_v1_beta1.ParameterSpec{
{
Name: "param-discrete",
@@ -152,15 +140,13 @@ func Test_toGoptunaSearchSpace(t *testing.T) {
},
},
},
- want: map[string]interface{}{
+ wantSearchSpace: map[string]interface{}{
"param-discrete": goptuna.CategoricalDistribution{
Choices: []string{"3", "2", "6"},
},
},
- wantErr: false,
},
- {
- name: "Categorical parameter type",
+ "Categorical parameter type": {
parameters: []*api_v1_beta1.ParameterSpec{
{
Name: "param-categorical",
@@ -170,23 +156,21 @@ func Test_toGoptunaSearchSpace(t *testing.T) {
},
},
},
- want: map[string]interface{}{
+ wantSearchSpace: map[string]interface{}{
"param-categorical": goptuna.CategoricalDistribution{
Choices: []string{"cat1", "cat2", "cat3"},
},
},
- wantErr: false,
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := toGoptunaSearchSpace(tt.parameters)
- if (err != nil) != tt.wantErr {
- t.Errorf("toGoptunaSearchSpace() error = %v, wantErr %v", err, tt.wantErr)
- return
+ for name, tc := range cases {
+ t.Run(name, func(t *testing.T) {
+ got, err := toGoptunaSearchSpace(tc.parameters)
+ if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
+ t.Errorf("Unexpected error from toGoptunaSearchSpace (-want,+got):\n%s", diff)
}
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("toGoptunaSearchSpace() got = %v, want %v", got, tt.want)
+ if diff := cmp.Diff(tc.wantSearchSpace, got); len(diff) != 0 {
+ t.Errorf("Unexpected search space from toGoptunaSearchSpace (-want,+got):\n%s", diff)
}
})
}
diff --git a/pkg/suggestion/v1beta1/hyperband/parameter.py b/pkg/suggestion/v1beta1/hyperband/parameter.py
index b6713ab748f..89a58318b7e 100644
--- a/pkg/suggestion/v1beta1/hyperband/parameter.py
+++ b/pkg/suggestion/v1beta1/hyperband/parameter.py
@@ -42,8 +42,17 @@ class ParameterConfig:
{"name": "cat_param", "values": ["true", "false"], "number": 2}.
"""
- def __init__(self, name_ids, dim, lower_bounds, upper_bounds,
- parameter_types, names, discrete_info, categorical_info):
+ def __init__(
+ self,
+ name_ids,
+ dim,
+ lower_bounds,
+ upper_bounds,
+ parameter_types,
+ names,
+ discrete_info,
+ categorical_info,
+ ):
self.name_ids = name_ids
self.dim = dim
self.lower_bounds = np.array(lower_bounds).reshape((1, dim))
@@ -62,6 +71,7 @@ def create_scaler(self):
return scaler
def random_sample(self):
- new_sample = np.random.uniform(self.lower_bounds, self.upper_bounds,
- size=(1, self.dim))
+ new_sample = np.random.uniform(
+ self.lower_bounds, self.upper_bounds, size=(1, self.dim)
+ )
return new_sample
diff --git a/pkg/suggestion/v1beta1/hyperband/parsing_util.py b/pkg/suggestion/v1beta1/hyperband/parsing_util.py
index 4136fe7246b..81860a7fbe3 100644
--- a/pkg/suggestion/v1beta1/hyperband/parsing_util.py
+++ b/pkg/suggestion/v1beta1/hyperband/parsing_util.py
@@ -26,14 +26,14 @@
def _deal_with_discrete(feasible_values, current_value):
- """ function to embed the current values to the feasible discrete space"""
+ """function to embed the current values to the feasible discrete space"""
diff = np.subtract(feasible_values, current_value)
diff = np.absolute(diff)
return feasible_values[np.argmin(diff)]
def _deal_with_categorical(feasible_values, one_hot_values):
- """ function to do the one hot encoding of the categorical values """
+ """function to do the one hot encoding of the categorical values"""
index = np.argmax(one_hot_values)
return feasible_values[int(index)]
@@ -61,17 +61,18 @@ def parse_parameter_configs(parameter_configs):
discrete_values = [int(x) for x in param.feasible_space.list]
new_lower = min(discrete_values)
new_upper = max(discrete_values)
- discrete_info.append(
- {"name": param.name, "values": discrete_values})
+ discrete_info.append({"name": param.name, "values": discrete_values})
elif param.parameter_type == api_pb2.CATEGORICAL:
num_feasible = len(param.feasible_space.list)
new_lower = [0 for _ in range(num_feasible)]
new_upper = [1 for _ in range(num_feasible)]
- categorical_info.append({
- "name": param.name,
- "values": param.feasible_space.list,
- "number": num_feasible,
- })
+ categorical_info.append(
+ {
+ "name": param.name,
+ "values": param.feasible_space.list,
+ "number": num_feasible,
+ }
+ )
if isinstance(new_lower, Iterable): # handles categorical parameters
lower_bounds.extend(new_lower)
upper_bounds.extend(new_upper)
@@ -80,14 +81,16 @@ def parse_parameter_configs(parameter_configs):
lower_bounds.append(new_lower)
upper_bounds.append(new_upper)
dim += 1
- parsed_config = ParameterConfig(name_ids,
- dim,
- lower_bounds,
- upper_bounds,
- parameter_types,
- names,
- discrete_info,
- categorical_info)
+ parsed_config = ParameterConfig(
+ name_ids,
+ dim,
+ lower_bounds,
+ upper_bounds,
+ parameter_types,
+ names,
+ discrete_info,
+ categorical_info,
+ )
return parsed_config
@@ -97,8 +100,7 @@ def parse_previous_observations(parameters_list, dim, name_id, types, categorica
offset = 0
for p in parameters:
map_id = name_id[p.name]
- if types[map_id] in [api_pb2.DOUBLE, api_pb2.INT,
- api_pb2.DISCRETE]:
+ if types[map_id] in [api_pb2.DOUBLE, api_pb2.INT, api_pb2.DISCRETE]:
parsed_X[row_idx, offset] = float(p.value)
offset += 1
elif types[map_id] == api_pb2.CATEGORICAL:
@@ -120,8 +122,10 @@ def parse_metric(y_train, goal):
return y_array
-def parse_x_next_vector(x_next, param_types, param_names, discrete_info, categorical_info):
- """ parse the next suggestion to the proper format """
+def parse_x_next_vector(
+ x_next, param_types, param_names, discrete_info, categorical_info
+):
+ """parse the next suggestion to the proper format"""
counter = 0
result = []
if isinstance(x_next, np.ndarray):
@@ -136,8 +140,7 @@ def parse_x_next_vector(x_next, param_types, param_names, discrete_info, categor
elif par_type == api_pb2.DISCRETE:
for param in discrete_info:
if param["name"] == par_name:
- value = _deal_with_discrete(param["values"],
- x_next[counter])
+ value = _deal_with_discrete(param["values"], x_next[counter])
counter = counter + 1
break
elif par_type == api_pb2.CATEGORICAL:
@@ -145,7 +148,7 @@ def parse_x_next_vector(x_next, param_types, param_names, discrete_info, categor
if param["name"] == par_name:
value = _deal_with_categorical(
feasible_values=param["values"],
- one_hot_values=x_next[counter:counter + param["number"]],
+ one_hot_values=x_next[counter : counter + param["number"]],
)
counter = counter + param["number"]
break
diff --git a/pkg/suggestion/v1beta1/hyperband/service.py b/pkg/suggestion/v1beta1/hyperband/service.py
index 50919826d0a..e69a15865f0 100644
--- a/pkg/suggestion/v1beta1/hyperband/service.py
+++ b/pkg/suggestion/v1beta1/hyperband/service.py
@@ -13,21 +13,18 @@
# limitations under the License.
import logging
-from logging import DEBUG
-from logging import getLogger
-from logging import StreamHandler
import math
import traceback
+from logging import DEBUG, StreamHandler, getLogger
import grpc
-from pkg.apis.manager.v1beta1.python import api_pb2
-from pkg.apis.manager.v1beta1.python import api_pb2_grpc
+from pkg.apis.manager.v1beta1.python import api_pb2, api_pb2_grpc
from pkg.suggestion.v1beta1.hyperband import parsing_util
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
logger = getLogger(__name__)
-FORMAT = '%(asctime)-15s Experiment %(experiment_name)s %(message)s'
+FORMAT = "%(asctime)-15s Experiment %(experiment_name)s %(message)s"
logging.basicConfig(format=FORMAT)
handler = StreamHandler()
handler.setLevel(DEBUG)
@@ -55,12 +52,17 @@ def GetSuggestions(self, request, context):
trials = self._make_bracket(experiment, param)
for trial in trials:
- reply.parameter_assignments.add(assignments=trial.parameter_assignments.assignments)
+ reply.parameter_assignments.add(
+ assignments=trial.parameter_assignments.assignments
+ )
reply.algorithm.CopyFrom(HyperBandParam.generate(param))
return reply
except Exception as e:
- logger.error("Fail to generate trials: \n%s",
- traceback.format_exc(), extra={"experiment_name": experiment.name})
+ logger.error(
+ "Fail to generate trials: \n%s",
+ traceback.format_exc(),
+ extra={"experiment_name": experiment.name},
+ )
raise e
def _update_hbParameters(self, param):
@@ -73,10 +75,13 @@ def _new_hbParameters(self, param):
param.current_i = 0
if param.current_s >= 0:
# when param.current_s < 0, hyperband algorithm reaches the end
- param.n = int(math.ceil(float(param.s_max + 1) * (
- float(param.eta**param.current_s) / float(param.current_s+1))))
- param.r = param.r_l * \
- param.eta**(-param.current_s)
+ param.n = int(
+ math.ceil(
+ float(param.s_max + 1)
+ * (float(param.eta**param.current_s) / float(param.current_s + 1))
+ )
+ )
+ param.r = param.r_l * param.eta ** (-param.current_s)
def _make_bracket(self, experiment, param):
if param.evaluating_trials == 0:
@@ -88,48 +93,76 @@ def _make_bracket(self, experiment, param):
else:
param.evaluating_trials = 0
- logger.info("HyperBand Param eta %d.",
- param.eta, extra={"experiment_name": experiment.name})
- logger.info("HyperBand Param R %d.",
- param.r_l, extra={"experiment_name": experiment.name})
- logger.info("HyperBand Param sMax %d.",
- param.s_max, extra={"experiment_name": experiment.name})
- logger.info("HyperBand Param B %d.",
- param.b_l, extra={"experiment_name": experiment.name})
- logger.info("HyperBand Param n %d.",
- param.n, extra={"experiment_name": experiment.name})
- logger.info("HyperBand Param r %d.",
- param.r, extra={"experiment_name": experiment.name})
- logger.info("HyperBand Param s %d.",
- param.current_s, extra={"experiment_name": experiment.name})
- logger.info("HyperBand Param i %d.",
- param.current_i, extra={"experiment_name": experiment.name})
- logger.info("HyperBand evaluating trials count %d.",
- param.evaluating_trials, extra={"experiment_name": experiment.name})
- logger.info("HyperBand budget resource name %s.",
- param.resource_name, extra={"experiment_name": experiment.name})
+ logger.info(
+ "HyperBand Param eta %d.",
+ param.eta,
+ extra={"experiment_name": experiment.name},
+ )
+ logger.info(
+ "HyperBand Param R %d.",
+ param.r_l,
+ extra={"experiment_name": experiment.name},
+ )
+ logger.info(
+ "HyperBand Param sMax %d.",
+ param.s_max,
+ extra={"experiment_name": experiment.name},
+ )
+ logger.info(
+ "HyperBand Param B %d.",
+ param.b_l,
+ extra={"experiment_name": experiment.name},
+ )
+ logger.info(
+ "HyperBand Param n %d.", param.n, extra={"experiment_name": experiment.name}
+ )
+ logger.info(
+ "HyperBand Param r %d.", param.r, extra={"experiment_name": experiment.name}
+ )
+ logger.info(
+ "HyperBand Param s %d.",
+ param.current_s,
+ extra={"experiment_name": experiment.name},
+ )
+ logger.info(
+ "HyperBand Param i %d.",
+ param.current_i,
+ extra={"experiment_name": experiment.name},
+ )
+ logger.info(
+ "HyperBand evaluating trials count %d.",
+ param.evaluating_trials,
+ extra={"experiment_name": experiment.name},
+ )
+ logger.info(
+ "HyperBand budget resource name %s.",
+ param.resource_name,
+ extra={"experiment_name": experiment.name},
+ )
if param.evaluating_trials == 0:
self._new_hbParameters(param)
return trialSpecs
def _make_child_bracket(self, experiment, param):
- n_i = math.ceil(param.n * param.eta**(-param.current_i))
+ n_i = math.ceil(param.n * param.eta ** (-param.current_i))
top_trials_num = int(math.ceil(n_i / param.eta))
self._update_hbParameters(param)
r_i = int(param.r * param.eta**param.current_i)
last_trials = self._get_top_trial(
- param.evaluating_trials, top_trials_num, experiment)
- trialSpecs = self._copy_trials(
- last_trials, r_i, param.resource_name)
+ param.evaluating_trials, top_trials_num, experiment
+ )
+ trialSpecs = self._copy_trials(last_trials, r_i, param.resource_name)
- logger.info("Generate %d trials by child bracket.",
- top_trials_num, extra={"experiment_name": experiment.name})
+ logger.info(
+ "Generate %d trials by child bracket.",
+ top_trials_num,
+ extra={"experiment_name": experiment.name},
+ )
return trialSpecs
def _get_last_trials(self, all_trials, latest_trials_num):
- sorted_trials = sorted(
- all_trials, key=lambda trial: trial.status.start_time)
+ sorted_trials = sorted(all_trials, key=lambda trial: trial.status.start_time)
if len(sorted_trials) > latest_trials_num:
return sorted_trials[-latest_trials_num:]
else:
@@ -151,11 +184,14 @@ def get_objective_value(t):
for t in latest_trials:
if t.status.condition != api_pb2.TrialStatus.TrialConditionType.SUCCEEDED:
raise Exception(
- "There are some trials which are not completed yet for experiment %s." % experiment.name)
+ "There are some trials which are not completed yet for experiment %s."
+ % experiment.name
+ )
if objective_type == api_pb2.MAXIMIZE:
top_trials.extend(
- sorted(latest_trials, key=get_objective_value, reverse=True))
+ sorted(latest_trials, key=get_objective_value, reverse=True)
+ )
else:
top_trials.extend(sorted(latest_trials, key=get_objective_value))
return top_trials[:top_trials_num]
@@ -169,8 +205,9 @@ def _copy_trials(self, trials, r_i, resourceName):
value = str(r_i)
else:
value = assignment.value
- trial_spec.parameter_assignments.assignments.add(name=assignment.name,
- value=value)
+ trial_spec.parameter_assignments.assignments.add(
+ name=assignment.name, value=value
+ )
trialSpecs.append(trial_spec)
return trialSpecs
@@ -178,7 +215,8 @@ def _make_master_bracket(self, experiment, param):
n = param.n
r = int(param.r)
parameter_config = parsing_util.parse_parameter_configs(
- experiment.spec.parameter_specs.parameters)
+ experiment.spec.parameter_specs.parameters
+ )
trial_specs = []
for _ in range(n):
sample = parameter_config.random_sample()
@@ -187,16 +225,21 @@ def _make_master_bracket(self, experiment, param):
parameter_config.parameter_types,
parameter_config.names,
parameter_config.discrete_info,
- parameter_config.categorical_info)
+ parameter_config.categorical_info,
+ )
trial_spec = api_pb2.TrialSpec()
for hp in suggestion:
- if hp['name'] == param.resource_name:
- hp['value'] = str(r)
- trial_spec.parameter_assignments.assignments.add(name=hp['name'],
- value=str(hp['value']))
+ if hp["name"] == param.resource_name:
+ hp["value"] = str(r)
+ trial_spec.parameter_assignments.assignments.add(
+ name=hp["name"], value=str(hp["value"])
+ )
trial_specs.append(trial_spec)
- logger.info("Generate %d trials by master bracket.",
- n, extra={"experiment_name": experiment.name})
+ logger.info(
+ "Generate %d trials by master bracket.",
+ n,
+ extra={"experiment_name": experiment.name},
+ )
return trial_specs
def _set_validate_context_error(self, context, error_message):
@@ -212,14 +255,20 @@ def ValidateAlgorithmSettings(self, request, context):
for setting in settings:
setting_dict[setting.name] = setting.value
if "r_l" not in setting_dict or "resource_name" not in setting_dict:
- return self._set_validate_context_error(context, "r_l and resource_name must be set.")
+ return self._set_validate_context_error(
+ context, "r_l and resource_name must be set."
+ )
try:
rl = float(setting_dict["r_l"])
except Exception:
- return self._set_validate_context_error(context, "r_l must be a positive float number.")
+ return self._set_validate_context_error(
+ context, "r_l must be a positive float number."
+ )
else:
if rl < 0:
- return self._set_validate_context_error(context, "r_l must be a positive float number.")
+ return self._set_validate_context_error(
+ context, "r_l must be a positive float number."
+ )
if "eta" in setting_dict:
eta = int(float(setting_dict["eta"]))
@@ -228,11 +277,12 @@ def ValidateAlgorithmSettings(self, request, context):
else:
eta = 3
- smax = int(math.log(rl)/math.log(eta))
+ smax = int(math.log(rl) / math.log(eta))
max_parallel = int(math.ceil(eta**smax))
if request.experiment.spec.parallel_trial_count < max_parallel:
- return self._set_validate_context_error(context,
- "parallelTrialCount must be not less than %d." % max_parallel)
+ return self._set_validate_context_error(
+ context, "parallelTrialCount must be not less than %d." % max_parallel
+ )
valid_resourceName = False
for param in params:
@@ -240,17 +290,27 @@ def ValidateAlgorithmSettings(self, request, context):
valid_resourceName = True
break
if not valid_resourceName:
- return self._set_validate_context_error(context,
- "value of resource_name setting must be in parameters.")
+ return self._set_validate_context_error(
+ context, "value of resource_name setting must be in parameters."
+ )
return api_pb2.ValidateAlgorithmSettingsReply()
class HyperBandParam(object):
- def __init__(self, eta=3, s_max=-1, r_l=-1,
- b_l=-1, r=-1, n=-1, current_s=-2,
- current_i=-1, resource_name="",
- evaluating_trials=0):
+ def __init__(
+ self,
+ eta=3,
+ s_max=-1,
+ r_l=-1,
+ b_l=-1,
+ r=-1,
+ n=-1,
+ current_s=-2,
+ current_i=-1,
+ resource_name="",
+ evaluating_trials=0,
+ ):
self.eta = eta
self.s_max = s_max
self.r_l = r_l
@@ -265,45 +325,24 @@ def __init__(self, eta=3, s_max=-1, r_l=-1,
@staticmethod
def generate(param):
algorithm_settings = [
+ api_pb2.AlgorithmSetting(name="eta", value=str(param.eta)),
+ api_pb2.AlgorithmSetting(name="s_max", value=str(param.s_max)),
+ api_pb2.AlgorithmSetting(name="r_l", value=str(param.r_l)),
+ api_pb2.AlgorithmSetting(name="b_l", value=str(param.b_l)),
+ api_pb2.AlgorithmSetting(name="r", value=str(param.r)),
+ api_pb2.AlgorithmSetting(name="n", value=str(param.n)),
+ api_pb2.AlgorithmSetting(name="current_s", value=str(param.current_s)),
+ api_pb2.AlgorithmSetting(name="current_i", value=str(param.current_i)),
+ api_pb2.AlgorithmSetting(name="resource_name", value=param.resource_name),
api_pb2.AlgorithmSetting(
- name="eta",
- value=str(param.eta)
- ), api_pb2.AlgorithmSetting(
- name="s_max",
- value=str(param.s_max)
- ), api_pb2.AlgorithmSetting(
- name="r_l",
- value=str(param.r_l)
- ), api_pb2.AlgorithmSetting(
- name="b_l",
- value=str(param.b_l)
- ), api_pb2.AlgorithmSetting(
- name="r",
- value=str(param.r)
- ), api_pb2.AlgorithmSetting(
- name="n",
- value=str(param.n)
- ), api_pb2.AlgorithmSetting(
- name="current_s",
- value=str(param.current_s)
- ), api_pb2.AlgorithmSetting(
- name="current_i",
- value=str(param.current_i)
- ), api_pb2.AlgorithmSetting(
- name="resource_name",
- value=param.resource_name
- ), api_pb2.AlgorithmSetting(
- name="evaluating_trials",
- value=str(param.evaluating_trials)
- )]
- return api_pb2.AlgorithmSpec(
- algorithm_settings=algorithm_settings
- )
+ name="evaluating_trials", value=str(param.evaluating_trials)
+ ),
+ ]
+ return api_pb2.AlgorithmSpec(algorithm_settings=algorithm_settings)
@staticmethod
def convert(alg_settings):
- """Convert the algorithm settings to HyperBandParam.
- """
+ """Convert the algorithm settings to HyperBandParam."""
param = HyperBandParam()
# Set the param from the algorithm settings.
for setting in alg_settings:
@@ -328,8 +367,7 @@ def convert(alg_settings):
elif setting.name == "resource_name":
param.resource_name = setting.value
else:
- logger.info(
- "Unknown HyperBand Param %s, ignore it", setting.name)
+ logger.info("Unknown HyperBand Param %s, ignore it", setting.name)
if param.current_s == -1:
# Hyperband outlerloop has finished
logger.info("HyperBand outlerloop has finished.")
@@ -339,8 +377,7 @@ def convert(alg_settings):
if param.eta <= 0:
param.eta = 3
if param.s_max < 0:
- param.s_max = int(
- math.log(param.r_l) / math.log(param.eta))
+ param.s_max = int(math.log(param.r_l) / math.log(param.eta))
if param.b_l < 0:
param.b_l = (param.s_max + 1) * param.r_l
if param.current_s < 0:
@@ -348,10 +385,13 @@ def convert(alg_settings):
if param.current_i < 0:
param.current_i = 0
if param.n < 0:
- param.n = int(math.ceil(float(param.s_max + 1) * (
- float(param.eta**param.current_s) / float(param.current_s+1))))
+ param.n = int(
+ math.ceil(
+ float(param.s_max + 1)
+ * (float(param.eta**param.current_s) / float(param.current_s + 1))
+ )
+ )
if param.r < 0:
- param.r = param.r_l * \
- param.eta**(-param.current_s)
+ param.r = param.r_l * param.eta ** (-param.current_s)
return param
diff --git a/pkg/suggestion/v1beta1/hyperopt/base_service.py b/pkg/suggestion/v1beta1/hyperopt/base_service.py
index 894a98637bf..c794ae6fb60 100644
--- a/pkg/suggestion/v1beta1/hyperopt/base_service.py
+++ b/pkg/suggestion/v1beta1/hyperopt/base_service.py
@@ -17,11 +17,13 @@
import hyperopt
import numpy as np
-from pkg.suggestion.v1beta1.internal.constant import CATEGORICAL
-from pkg.suggestion.v1beta1.internal.constant import DISCRETE
-from pkg.suggestion.v1beta1.internal.constant import DOUBLE
-from pkg.suggestion.v1beta1.internal.constant import INTEGER
-from pkg.suggestion.v1beta1.internal.constant import MAX_GOAL
+from pkg.suggestion.v1beta1.internal.constant import (
+ CATEGORICAL,
+ DISCRETE,
+ DOUBLE,
+ INTEGER,
+ MAX_GOAL,
+)
from pkg.suggestion.v1beta1.internal.trial import Assignment
logger = logging.getLogger(__name__)
@@ -31,14 +33,13 @@
class BaseHyperoptService(object):
- def __init__(self,
- algorithm_name=TPE_ALGORITHM_NAME,
- algorithm_conf=None,
- search_space=None):
+ def __init__(
+ self, algorithm_name=TPE_ALGORITHM_NAME, algorithm_conf=None, search_space=None
+ ):
self.algorithm_name = algorithm_name
self.algorithm_conf = algorithm_conf or {}
# pop common configurations
- random_state = self.algorithm_conf.pop('random_state', None)
+ random_state = self.algorithm_conf.pop("random_state", None)
if self.algorithm_name == TPE_ALGORITHM_NAME:
self.hyperopt_algorithm = hyperopt.tpe.suggest
@@ -57,26 +58,26 @@ def __init__(self,
self.is_first_run = True
def create_hyperopt_domain(self):
- # Construct search space, example: {"x": hyperopt.hp.uniform('x', -10, 10), "x2": hyperopt.hp.uniform('x2', -10, 10)}
+ # Construct search space, example: {"x": hyperopt.hp.uniform('x', -10, 10), "x2":
+ # hyperopt.hp.uniform('x2', -10, 10)}
hyperopt_search_space = {}
for param in self.search_space.params:
if param.type == INTEGER:
hyperopt_search_space[param.name] = hyperopt.hp.quniform(
- param.name,
- float(param.min),
- float(param.max),
- float(param.step))
+ param.name, float(param.min), float(param.max), float(param.step)
+ )
elif param.type == DOUBLE:
hyperopt_search_space[param.name] = hyperopt.hp.uniform(
- param.name,
- float(param.min),
- float(param.max))
+ param.name, float(param.min), float(param.max)
+ )
elif param.type == CATEGORICAL or param.type == DISCRETE:
hyperopt_search_space[param.name] = hyperopt.hp.choice(
- param.name, param.list)
+ param.name, param.list
+ )
self.hyperopt_domain = hyperopt.Domain(
- None, hyperopt_search_space, pass_expr_memo_ctrl=None)
+ None, hyperopt_search_space, pass_expr_memo_ctrl=None
+ )
def create_fmin(self):
self.fmin = hyperopt.FMinIter(
@@ -85,7 +86,8 @@ def create_fmin(self):
trials=hyperopt.Trials(),
max_evals=-1,
rstate=self.hyperopt_rstate,
- verbose=False)
+ verbose=False,
+ )
self.fmin.catch_eval_exceptions = False
@@ -107,12 +109,16 @@ def getSuggestions(self, trials, current_request_number):
new_id = self.fmin.trials.new_trial_ids(1)
hyperopt_trial_new_ids.append(new_id[0])
hyperopt_trial_miscs_idxs = {}
- # Example: {'l1_normalization': [0.1], 'learning_rate': [0.1], 'hidden2': [1], 'optimizer': [1]}
+ # Example: {'l1_normalization': [0.1], 'learning_rate': [0.1],
+ # 'hidden2': [1], 'optimizer': [1]}
hyperopt_trial_miscs_vals = {}
# Insert Trial assignment to the misc
hyperopt_trial_misc = dict(
- tid=new_id[0], cmd=self.hyperopt_domain.cmd, workdir=self.hyperopt_domain.workdir)
+ tid=new_id[0],
+ cmd=self.hyperopt_domain.cmd,
+ workdir=self.hyperopt_domain.workdir,
+ )
for param in self.search_space.params:
parameter_value = None
for assignment in trial.assignments:
@@ -135,9 +141,7 @@ def getSuggestions(self, trials, current_request_number):
hyperopt_trial_miscs.append(hyperopt_trial_misc)
# Insert Trial name to the spec
- hyperopt_trial_spec = {
- "trial-name": trial.name
- }
+ hyperopt_trial_spec = {"trial-name": trial.name}
hyperopt_trial_specs.append(hyperopt_trial_spec)
# Insert Trial result to the result
@@ -145,22 +149,23 @@ def getSuggestions(self, trials, current_request_number):
# TODO: Do we need to analyse additional_metrics?
objective_for_hyperopt = float(trial.target_metric.value)
if self.search_space.goal == MAX_GOAL:
- # Now hyperopt only supports fmin and we need to reverse objective value for maximization
+ # Now hyperopt only supports fmin and we need to reverse
+ # objective value for maximization
objective_for_hyperopt = -1 * objective_for_hyperopt
hyperopt_trial_result = {
"loss": objective_for_hyperopt,
- "status": hyperopt.STATUS_OK
+ "status": hyperopt.STATUS_OK,
}
hyperopt_trial_results.append(hyperopt_trial_result)
if len(trials) > 0:
-
# Create new Trial doc
hyperopt_trials = hyperopt.Trials().new_trial_docs(
tids=hyperopt_trial_new_ids,
specs=hyperopt_trial_specs,
results=hyperopt_trial_results,
- miscs=hyperopt_trial_miscs)
+ miscs=hyperopt_trial_miscs,
+ )
for i, _ in enumerate(hyperopt_trials):
hyperopt_trials[i]["state"] = hyperopt.JOB_STATE_DONE
@@ -208,7 +213,8 @@ def getSuggestions(self, trials, current_request_number):
new_ids=hyperopt_trial_new_ids,
domain=self.fmin.domain,
trials=self.fmin.trials,
- seed=random_state)
+ seed=random_state,
+ )
elif self.algorithm_name == TPE_ALGORITHM_NAME:
# n_startup_jobs indicates for how many Trials we run random suggestion
# This must be current_request_number value
@@ -222,24 +228,30 @@ def getSuggestions(self, trials, current_request_number):
trials=self.fmin.trials,
seed=random_state,
n_startup_jobs=current_request_number,
- **self.algorithm_conf)
+ **self.algorithm_conf,
+ )
self.is_first_run = False
else:
for i in range(current_request_number):
# hyperopt_algorithm always returns one new Trial
- new_trials.append(self.hyperopt_algorithm(
- new_ids=[hyperopt_trial_new_ids[i]],
- domain=self.fmin.domain,
- trials=self.fmin.trials,
- seed=random_state,
- n_startup_jobs=current_request_number,
- **self.algorithm_conf)[0])
+ new_trials.append(
+ self.hyperopt_algorithm(
+ new_ids=[hyperopt_trial_new_ids[i]],
+ domain=self.fmin.domain,
+ trials=self.fmin.trials,
+ seed=random_state,
+ n_startup_jobs=current_request_number,
+ **self.algorithm_conf,
+ )[0]
+ )
# Construct return advisor Trials from new hyperopt Trials
list_of_assignments = []
for trial in new_trials:
- vals = trial['misc']['vals']
- list_of_assignments.append(BaseHyperoptService.convert(self.search_space, vals))
+ vals = trial["misc"]["vals"]
+ list_of_assignments.append(
+ BaseHyperoptService.convert(self.search_space, vals)
+ )
if len(list_of_assignments) > 0:
logger.info("GetSuggestions returns {} new Trial\n".format(len(new_trials)))
@@ -256,5 +268,6 @@ def convert(search_space, vals):
assignments.append(Assignment(param.name, vals[param.name][0]))
elif param.type == CATEGORICAL or param.type == DISCRETE:
assignments.append(
- Assignment(param.name, param.list[vals[param.name][0]]))
+ Assignment(param.name, param.list[vals[param.name][0]])
+ )
return assignments
diff --git a/pkg/suggestion/v1beta1/hyperopt/service.py b/pkg/suggestion/v1beta1/hyperopt/service.py
index d3a9cdfe7c5..ab837173375 100644
--- a/pkg/suggestion/v1beta1/hyperopt/service.py
+++ b/pkg/suggestion/v1beta1/hyperopt/service.py
@@ -16,14 +16,11 @@
import grpc
-from pkg.apis.manager.v1beta1.python import api_pb2
-from pkg.apis.manager.v1beta1.python import api_pb2_grpc
+from pkg.apis.manager.v1beta1.python import api_pb2, api_pb2_grpc
from pkg.suggestion.v1beta1.hyperopt.base_service import BaseHyperoptService
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
-from pkg.suggestion.v1beta1.internal.search_space import \
- HyperParameterSearchSpace
-from pkg.suggestion.v1beta1.internal.trial import Assignment
-from pkg.suggestion.v1beta1.internal.trial import Trial
+from pkg.suggestion.v1beta1.internal.search_space import HyperParameterSearchSpace
+from pkg.suggestion.v1beta1.internal.trial import Assignment, Trial
logger = logging.getLogger(__name__)
@@ -40,25 +37,28 @@ def GetSuggestions(self, request, context):
Main function to provide suggestion.
"""
name, config = OptimizerConfiguration.convert_algorithm_spec(
- request.experiment.spec.algorithm)
+ request.experiment.spec.algorithm
+ )
if self.is_first_run:
search_space = HyperParameterSearchSpace.convert(request.experiment)
self.base_service = BaseHyperoptService(
- algorithm_name=name,
- algorithm_conf=config,
- search_space=search_space)
+ algorithm_name=name, algorithm_conf=config, search_space=search_space
+ )
self.is_first_run = False
trials = Trial.convert(request.trials)
- new_assignments = self.base_service.getSuggestions(trials, request.current_request_number)
+ new_assignments = self.base_service.getSuggestions(
+ trials, request.current_request_number
+ )
return api_pb2.GetSuggestionsReply(
parameter_assignments=Assignment.generate(new_assignments)
)
def ValidateAlgorithmSettings(self, request, context):
is_valid, message = OptimizerConfiguration.validate_algorithm_spec(
- request.experiment.spec.algorithm)
+ request.experiment.spec.algorithm
+ )
if not is_valid:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details(message)
@@ -68,15 +68,15 @@ def ValidateAlgorithmSettings(self, request, context):
class OptimizerConfiguration:
__conversion_dict = {
- 'tpe': {
- 'gamma': lambda x: float(x),
- 'prior_weight': lambda x: float(x),
- 'n_EI_candidates': lambda x: int(x),
+ "tpe": {
+ "gamma": lambda x: float(x),
+ "prior_weight": lambda x: float(x),
+ "n_EI_candidates": lambda x: int(x),
"random_state": lambda x: int(x),
},
"random": {
"random_state": lambda x: int(x),
- }
+ },
}
@classmethod
@@ -92,9 +92,9 @@ def convert_algorithm_spec(cls, algorithm_spec):
@classmethod
def validate_algorithm_spec(cls, algorithm_spec):
algo_name = algorithm_spec.algorithm_name
- if algo_name == 'tpe':
+ if algo_name == "tpe":
return cls._validate_tpe_setting(algorithm_spec.algorithm_settings)
- elif algo_name == 'random':
+ elif algo_name == "random":
return cls._validate_random_setting(algorithm_spec.algorithm_settings)
else:
return False, "unknown algorithm name {}".format(algo_name)
@@ -103,23 +103,24 @@ def validate_algorithm_spec(cls, algorithm_spec):
def _validate_tpe_setting(cls, algorithm_settings):
for s in algorithm_settings:
try:
- if s.name == 'gamma':
+ if s.name == "gamma":
if not 1 > float(s.value) > 0:
return False, "gamma should be in the range of (0, 1)"
- elif s.name == 'prior_weight':
+ elif s.name == "prior_weight":
if not float(s.value) > 0:
return False, "prior_weight should be great than zero"
- elif s.name == 'n_EI_candidates':
+ elif s.name == "n_EI_candidates":
if not int(s.value) > 0:
return False, "n_EI_candidates should be great than zero"
- elif s.name == 'random_state':
+ elif s.name == "random_state":
if not int(s.value) >= 0:
return False, "random_state should be great or equal than zero"
else:
return False, "unknown setting {} for algorithm tpe".format(s.name)
except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(
- name=s.name, value=s.value, exception=e)
+ name=s.name, value=s.value, exception=e
+ )
return True, ""
@@ -127,13 +128,16 @@ def _validate_tpe_setting(cls, algorithm_settings):
def _validate_random_setting(cls, algorithm_settings):
for s in algorithm_settings:
try:
- if s.name == 'random_state':
+ if s.name == "random_state":
if not (int(s.value) >= 0):
return False, "random_state should be great or equal than zero"
else:
- return False, "unknown setting {} for algorithm random".format(s.name)
+ return False, "unknown setting {} for algorithm random".format(
+ s.name
+ )
except Exception as e:
return False, "failed to validate {name}({value}): {exception}".format(
- name=s.name, value=s.value, exception=e)
+ name=s.name, value=s.value, exception=e
+ )
return True, ""
diff --git a/pkg/suggestion/v1beta1/internal/base_health_service.py b/pkg/suggestion/v1beta1/internal/base_health_service.py
index 3660c884ed8..13950975b66 100644
--- a/pkg/suggestion/v1beta1/internal/base_health_service.py
+++ b/pkg/suggestion/v1beta1/internal/base_health_service.py
@@ -22,7 +22,7 @@
from pkg.apis.manager.health.python import health_pb2 as _health_pb2
from pkg.apis.manager.health.python import health_pb2_grpc as _health_pb2_grpc
-SERVICE_NAME = _health_pb2.DESCRIPTOR.services_by_name['Health'].full_name
+SERVICE_NAME = _health_pb2.DESCRIPTOR.services_by_name["Health"].full_name
class _Watcher:
@@ -74,9 +74,7 @@ def send_response_callback(response):
class HealthServicer(_health_pb2_grpc.HealthServicer):
"""Servicer handling RPCs for service statuses."""
- def __init__(self,
- experimental_non_blocking=True,
- experimental_thread_pool=None):
+ def __init__(self, experimental_non_blocking=True, experimental_thread_pool=None):
self._lock = threading.RLock()
self._server_status = {}
self._send_response_callbacks = {}
@@ -89,8 +87,7 @@ def _on_close_callback(self, send_response_callback, service):
def callback():
with self._lock:
- self._send_response_callbacks[service].remove(
- send_response_callback)
+ self._send_response_callbacks[service].remove(send_response_callback)
send_response_callback(None)
return callback
@@ -114,19 +111,22 @@ def Watch(self, request, context, send_response_callback=None):
# generator.
blocking_watcher = _Watcher()
send_response_callback = _watcher_to_send_response_callback_adapter(
- blocking_watcher)
+ blocking_watcher
+ )
service = request.service
with self._lock:
status = self._server_status.get(service)
if status is None:
- status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member
- send_response_callback(
- _health_pb2.HealthCheckResponse(status=status))
+ status = (
+ _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN
+ ) # pylint: disable=no-member
+ send_response_callback(_health_pb2.HealthCheckResponse(status=status))
if service not in self._send_response_callbacks:
self._send_response_callbacks[service] = set()
self._send_response_callbacks[service].add(send_response_callback)
context.add_callback(
- self._on_close_callback(send_response_callback, service))
+ self._on_close_callback(send_response_callback, service)
+ )
return blocking_watcher
def set(self, service, status):
@@ -144,9 +144,11 @@ def set(self, service, status):
self._server_status[service] = status
if service in self._send_response_callbacks:
for send_response_callback in self._send_response_callbacks[
- service]:
+ service
+ ]:
send_response_callback(
- _health_pb2.HealthCheckResponse(status=status))
+ _health_pb2.HealthCheckResponse(status=status)
+ )
def enter_graceful_shutdown(self):
"""Permanently sets the status of all services to NOT_SERVING.
@@ -162,6 +164,7 @@ def enter_graceful_shutdown(self):
return
else:
for service in self._server_status:
- self.set(service,
- _health_pb2.HealthCheckResponse.NOT_SERVING) # pylint: disable=no-member
+ self.set(
+ service, _health_pb2.HealthCheckResponse.NOT_SERVING
+ ) # pylint: disable=no-member
self._gracefully_shutting_down = True
diff --git a/pkg/suggestion/v1beta1/internal/search_space.py b/pkg/suggestion/v1beta1/internal/search_space.py
index d50955810e5..7e920d7c363 100644
--- a/pkg/suggestion/v1beta1/internal/search_space.py
+++ b/pkg/suggestion/v1beta1/internal/search_space.py
@@ -16,12 +16,14 @@
import numpy as np
-from pkg.apis.manager.v1beta1.python import api_pb2 as api
-from pkg.suggestion.v1beta1.internal.constant import CATEGORICAL
-from pkg.suggestion.v1beta1.internal.constant import DISCRETE
-from pkg.suggestion.v1beta1.internal.constant import DOUBLE
-from pkg.suggestion.v1beta1.internal.constant import INTEGER
import pkg.suggestion.v1beta1.internal.constant as constant
+from pkg.apis.manager.v1beta1.python import api_pb2 as api
+from pkg.suggestion.v1beta1.internal.constant import (
+ CATEGORICAL,
+ DISCRETE,
+ DOUBLE,
+ INTEGER,
+)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@@ -40,8 +42,7 @@ def convert(experiment):
elif experiment.spec.objective.type == api.MINIMIZE:
search_space.goal = constant.MIN_GOAL
for p in experiment.spec.parameter_specs.parameters:
- search_space.params.append(
- HyperParameterSearchSpace.convert_parameter(p))
+ search_space.params.append(HyperParameterSearchSpace.convert_parameter(p))
return search_space
@staticmethod
@@ -50,15 +51,20 @@ def convert_to_combinations(search_space):
for parameter in search_space.params:
if parameter.type == INTEGER:
- combinations[parameter.name] = range(int(parameter.min), int(parameter.max)+1, int(parameter.step))
+ combinations[parameter.name] = range(
+ int(parameter.min), int(parameter.max) + 1, int(parameter.step)
+ )
elif parameter.type == DOUBLE:
if parameter.step == "" or parameter.step is None:
raise Exception(
- "Param {} step is nil; For discrete search space, all parameters must include step".
- format(parameter.name)
+ "Param {} step is nil; For discrete search space, all parameters "
+ "must include step".format(parameter.name)
)
- double_list = np.arange(float(parameter.min), float(parameter.max)+float(parameter.step),
- float(parameter.step))
+ double_list = np.arange(
+ float(parameter.min),
+ float(parameter.max) + float(parameter.step),
+ float(parameter.step),
+ )
if double_list[-1] > float(parameter.max):
double_list = double_list[:-1]
combinations[parameter.name] = double_list
@@ -68,8 +74,11 @@ def convert_to_combinations(search_space):
return combinations
def __str__(self):
- return "HyperParameterSearchSpace(goal: {}, ".format(self.goal) + \
- "params: {})".format(", ".join([element.__str__() for element in self.params]))
+ return "HyperParameterSearchSpace(goal: {}, ".format(
+ self.goal
+ ) + "params: {})".format(
+ ", ".join([element.__str__() for element in self.params])
+ )
@staticmethod
def convert_parameter(p):
@@ -78,16 +87,26 @@ def convert_parameter(p):
step = 1
if p.feasible_space.step is not None and p.feasible_space.step != "":
step = p.feasible_space.step
- return HyperParameter.int(p.name, p.feasible_space.min, p.feasible_space.max, step)
+ return HyperParameter.int(
+ p.name, p.feasible_space.min, p.feasible_space.max, step
+ )
elif p.parameter_type == api.DOUBLE:
- return HyperParameter.double(p.name, p.feasible_space.min, p.feasible_space.max, p.feasible_space.step)
+ return HyperParameter.double(
+ p.name,
+ p.feasible_space.min,
+ p.feasible_space.max,
+ p.feasible_space.step,
+ )
elif p.parameter_type == api.CATEGORICAL:
return HyperParameter.categorical(p.name, p.feasible_space.list)
elif p.parameter_type == api.DISCRETE:
return HyperParameter.discrete(p.name, p.feasible_space.list)
else:
logger.error(
- "Cannot get the type for the parameter: %s (%s)", p.name, p.parameter_type)
+ "Cannot get the type for the parameter: %s (%s)",
+ p.name,
+ p.parameter_type,
+ )
class HyperParameter(object):
@@ -101,11 +120,15 @@ def __init__(self, name, type_, min_, max_, list_, step):
def __str__(self):
if self.type == constant.INTEGER or self.type == constant.DOUBLE:
- return "HyperParameter(name: {}, type: {}, min: {}, max: {}, step: {})".format(
- self.name, self.type, self.min, self.max, self.step)
+ return (
+ "HyperParameter(name: {}, type: {}, min: {}, max: {}, step: {})".format(
+ self.name, self.type, self.min, self.max, self.step
+ )
+ )
else:
return "HyperParameter(name: {}, type: {}, list: {})".format(
- self.name, self.type, ", ".join(self.list))
+ self.name, self.type, ", ".join(self.list)
+ )
@staticmethod
def int(name, min_, max_, step):
@@ -117,7 +140,9 @@ def double(name, min_, max_, step):
@staticmethod
def categorical(name, lst):
- return HyperParameter(name, constant.CATEGORICAL, 0, 0, [str(e) for e in lst], 0)
+ return HyperParameter(
+ name, constant.CATEGORICAL, 0, 0, [str(e) for e in lst], 0
+ )
@staticmethod
def discrete(name, lst):
diff --git a/pkg/suggestion/v1beta1/internal/trial.py b/pkg/suggestion/v1beta1/internal/trial.py
index 906617d0c0f..37d0a6891b0 100644
--- a/pkg/suggestion/v1beta1/internal/trial.py
+++ b/pkg/suggestion/v1beta1/internal/trial.py
@@ -41,8 +41,11 @@ def __init__(
def convert(trials):
res = []
for trial in trials:
- if trial.status.condition == api.TrialStatus.TrialConditionType.SUCCEEDED or \
- trial.status.condition == api.TrialStatus.TrialConditionType.EARLYSTOPPED:
+ if (
+ trial.status.condition == api.TrialStatus.TrialConditionType.SUCCEEDED
+ or trial.status.condition
+ == api.TrialStatus.TrialConditionType.EARLYSTOPPED
+ ):
new_trial = Trial.convertTrial(trial)
if new_trial is not None:
res.append(Trial.convertTrial(trial))
@@ -77,11 +80,14 @@ def __str__(self):
", ".join([str(e) for e in self.assignments])
)
else:
- return "Trial(assignment: {}, metric_name: {}, metric: {}, additional_metrics: {})".format(
- ", ".join([str(e) for e in self.assignments]),
- self.metric_name,
- self.target_metric,
- ", ".join(str(e) for e in self.additional_metrics),
+ return (
+ "Trial(assignment: {}, metric_name: {}, metric: {}, "
+ "additional_metrics: {})".format(
+ ", ".join(str(e) for e in self.assignments),
+ self.metric_name,
+ self.target_metric,
+ ", ".join(str(e) for e in self.additional_metrics),
+ )
)
diff --git a/pkg/suggestion/v1beta1/nas/common/validation.py b/pkg/suggestion/v1beta1/nas/common/validation.py
index 26e510d51a0..8535d0ee985 100644
--- a/pkg/suggestion/v1beta1/nas/common/validation.py
+++ b/pkg/suggestion/v1beta1/nas/common/validation.py
@@ -16,10 +16,8 @@
def validate_operations(operations: list[api_pb2.Operation]) -> (bool, str):
-
# Validate each operation
for operation in operations:
-
# Check OperationType
if not operation.operation_type:
return False, "Missing operationType in Operation:\n{}".format(operation)
@@ -31,33 +29,62 @@ def validate_operations(operations: list[api_pb2.Operation]) -> (bool, str):
# Validate each ParameterConfig in Operation
parameters_list = list(operation.parameter_specs.parameters)
for parameter in parameters_list:
-
# Check Name
if not parameter.name:
return False, "Missing Name in ParameterConfig:\n{}".format(parameter)
# Check ParameterType
if not parameter.parameter_type:
- return False, "Missing ParameterType in ParameterConfig:\n{}".format(parameter)
+ return False, "Missing ParameterType in ParameterConfig:\n{}".format(
+ parameter
+ )
# Check List in Categorical or Discrete Type
- if parameter.parameter_type == api_pb2.CATEGORICAL or parameter.parameter_type == api_pb2.DISCRETE:
+ if (
+ parameter.parameter_type == api_pb2.CATEGORICAL
+ or parameter.parameter_type == api_pb2.DISCRETE
+ ):
if not parameter.feasible_space.list:
- return False, "Missing List in ParameterConfig.feasibleSpace:\n{}".format(parameter)
+ return (
+ False,
+ "Missing List in ParameterConfig.feasibleSpace:\n{}".format(
+ parameter
+ ),
+ )
# Check Max, Min, Step in Int or Double Type
- elif parameter.parameter_type == api_pb2.INT or parameter.parameter_type == api_pb2.DOUBLE:
- if not parameter.feasible_space.min and not parameter.feasible_space.max:
- return False, "Missing Max and Min in ParameterConfig.feasibleSpace:\n{}".format(parameter)
+ elif (
+ parameter.parameter_type == api_pb2.INT
+ or parameter.parameter_type == api_pb2.DOUBLE
+ ):
+ if (
+ not parameter.feasible_space.min
+ and not parameter.feasible_space.max
+ ):
+ return (
+ False,
+ "Missing Max and Min in ParameterConfig.feasibleSpace:\n{}".format(
+ parameter
+ ),
+ )
try:
- if (parameter.parameter_type == api_pb2.DOUBLE and
- (not parameter.feasible_space.step or float(parameter.feasible_space.step) <= 0)):
- return False, \
- "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n{}".format(parameter)
+ if parameter.parameter_type == api_pb2.DOUBLE and (
+ not parameter.feasible_space.step
+ or float(parameter.feasible_space.step) <= 0
+ ):
+ return (
+ False,
+ "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n"
+ "{}".format(parameter),
+ )
except Exception as e:
- return False, \
- "failed to validate ParameterConfig.feasibleSpace \n{parameter}):\n{exception}".format(
- parameter=parameter, exception=e)
+ return (
+ False,
+ (
+ "failed to validate ParameterConfig.feasibleSpace \n"
+ "{parameter}):\n{exception}"
+ ).format(parameter=parameter, exception=e),
+ )
return True, ""
diff --git a/pkg/suggestion/v1beta1/nas/darts/service.py b/pkg/suggestion/v1beta1/nas/darts/service.py
index 55e7751a8f3..835aaa58513 100644
--- a/pkg/suggestion/v1beta1/nas/darts/service.py
+++ b/pkg/suggestion/v1beta1/nas/darts/service.py
@@ -14,14 +14,11 @@
import json
import logging
-from logging import getLogger
-from logging import INFO
-from logging import StreamHandler
+from logging import INFO, StreamHandler, getLogger
import grpc
-from pkg.apis.manager.v1beta1.python import api_pb2
-from pkg.apis.manager.v1beta1.python import api_pb2_grpc
+from pkg.apis.manager.v1beta1.python import api_pb2, api_pb2_grpc
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
from pkg.suggestion.v1beta1.nas.common.validation import validate_operations
@@ -33,7 +30,7 @@ def __init__(self):
self.is_first_run = True
self.logger = getLogger(__name__)
- FORMAT = '%(asctime)-15s Experiment %(experiment_name)s %(message)s'
+ FORMAT = "%(asctime)-15s Experiment %(experiment_name)s %(message)s"
logging.basicConfig(format=FORMAT)
handler = StreamHandler()
handler.setLevel(INFO)
@@ -62,8 +59,8 @@ def GetSuggestions(self, request, context):
search_space_json = json.dumps(search_space)
algorithm_settings_json = json.dumps(algorithm_settings)
- search_space_str = str(search_space_json).replace('\"', '\'')
- algorithm_settings_str = str(algorithm_settings_json).replace('\"', '\'')
+ search_space_str = str(search_space_json).replace('"', "'")
+ algorithm_settings_str = str(algorithm_settings_json).replace('"', "'")
self.is_first_run = False
@@ -84,17 +81,14 @@ def GetSuggestions(self, request, context):
api_pb2.GetSuggestionsReply.ParameterAssignments(
assignments=[
api_pb2.ParameterAssignment(
- name="algorithm-settings",
- value=algorithm_settings_str
+ name="algorithm-settings", value=algorithm_settings_str
),
api_pb2.ParameterAssignment(
- name="search-space",
- value=search_space_str
+ name="search-space", value=search_space_str
),
api_pb2.ParameterAssignment(
- name="num-layers",
- value=num_layers
- )
+ name="num-layers", value=num_layers
+ ),
]
)
)
@@ -114,27 +108,29 @@ def get_search_space(operations):
# Currently support only one Categorical parameter - filter size
opt_spec = list(operation.parameter_specs.parameters)[0]
for filter_size in list(opt_spec.feasible_space.list):
- search_space.append(opt_type+"_{}x{}".format(filter_size, filter_size))
+ search_space.append(
+ opt_type + "_{}x{}".format(filter_size, filter_size)
+ )
return search_space
def get_algorithm_settings(settings_raw):
algorithm_settings_default = {
- "num_epochs": 50,
- "w_lr": 0.025,
- "w_lr_min": 0.001,
- "w_momentum": 0.9,
- "w_weight_decay": 3e-4,
- "w_grad_clip": 5.,
- "alpha_lr": 3e-4,
- "alpha_weight_decay": 1e-3,
- "batch_size": 128,
- "num_workers": 4,
- "init_channels": 16,
- "print_step": 50,
- "num_nodes": 4,
- "stem_multiplier": 3,
+ "num_epochs": 50,
+ "w_lr": 0.025,
+ "w_lr_min": 0.001,
+ "w_momentum": 0.9,
+ "w_weight_decay": 3e-4,
+ "w_grad_clip": 5.0,
+ "alpha_lr": 3e-4,
+ "alpha_weight_decay": 1e-3,
+ "batch_size": 128,
+ "num_workers": 4,
+ "init_channels": 16,
+ "print_step": 50,
+ "num_nodes": 4,
+ "stem_multiplier": 3,
}
for setting in settings_raw:
@@ -162,7 +158,9 @@ def validate_algorithm_spec(spec: api_pb2.ExperimentSpec) -> (bool, str):
# validate_algorithm_settings is implemented based on quark0/darts and pt.darts.
# quark0/darts: https://github.com/quark0/darts
# pt.darts: https://github.com/khanrc/pt.darts
-def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSetting]) -> (bool, str):
+def validate_algorithm_settings(
+ algorithm_settings: list[api_pb2.AlgorithmSetting],
+) -> (bool, str):
for s in algorithm_settings:
try:
if s.name == "num_epochs":
@@ -172,17 +170,23 @@ def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSettin
# Validate learning rate
if s.name in {"w_lr", "w_lr_min", "alpha_lr"}:
if not float(s.value) >= 0.0:
- return False, "{} should be greater than or equal to zero".format(s.name)
+ return False, "{} should be greater than or equal to zero".format(
+ s.name
+ )
# Validate weight decay
if s.name in {"w_weight_decay", "alpha_weight_decay"}:
if not float(s.value) >= 0.0:
- return False, "{} should be greater than or equal to zero".format(s.name)
+ return False, "{} should be greater than or equal to zero".format(
+ s.name
+ )
# Validate w_momentum and w_grad_clip
if s.name in {"w_momentum", "w_grad_clip"}:
if not float(s.value) >= 0.0:
- return False, "{} should be greater than or equal to zero".format(s.name)
+ return False, "{} should be greater than or equal to zero".format(
+ s.name
+ )
if s.name == "batch_size":
if s.value != "None" and not int(s.value) >= 1:
@@ -193,12 +197,20 @@ def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSettin
return False, "num_workers should be greater than or equal to zero"
# Validate "init_channels", "print_step", "num_nodes" and "stem_multiplier"
- if s.name in {"init_channels", "print_step", "num_nodes", "stem_multiplier"}:
+ if s.name in {
+ "init_channels",
+ "print_step",
+ "num_nodes",
+ "stem_multiplier",
+ }:
if not int(s.value) >= 1:
- return False, "{} should be greater than or equal to one".format(s.name)
+ return False, "{} should be greater than or equal to one".format(
+ s.name
+ )
except Exception as e:
- return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
- exception=e)
+ return False, "failed to validate {name}({value}): {exception}".format(
+ name=s.name, value=s.value, exception=e
+ )
return True, ""
diff --git a/pkg/suggestion/v1beta1/nas/enas/AlgorithmSettings.py b/pkg/suggestion/v1beta1/nas/enas/AlgorithmSettings.py
index abf3ac39e03..685ff4cf25d 100644
--- a/pkg/suggestion/v1beta1/nas/enas/AlgorithmSettings.py
+++ b/pkg/suggestion/v1beta1/nas/enas/AlgorithmSettings.py
@@ -14,34 +14,38 @@
algorithmSettingsValidator = {
- "controller_hidden_size": [int, [1, 'inf']],
- "controller_temperature": [float, [0, 'inf']],
- "controller_tanh_const": [float, [0, 'inf']],
- "controller_entropy_weight": [float, [0.0, 'inf']],
- "controller_baseline_decay": [float, [0.0, 1.0]],
- "controller_learning_rate": [float, [0.0, 1.0]],
- "controller_skip_target": [float, [0.0, 1.0]],
- "controller_skip_weight": [float, [0.0, 'inf']],
- "controller_train_steps": [int, [1, 'inf']],
- "controller_log_every_steps": [int, [1, 'inf']],
+ "controller_hidden_size": [int, [1, "inf"]],
+ "controller_temperature": [float, [0, "inf"]],
+ "controller_tanh_const": [float, [0, "inf"]],
+ "controller_entropy_weight": [float, [0.0, "inf"]],
+ "controller_baseline_decay": [float, [0.0, 1.0]],
+ "controller_learning_rate": [float, [0.0, 1.0]],
+ "controller_skip_target": [float, [0.0, 1.0]],
+ "controller_skip_weight": [float, [0.0, "inf"]],
+ "controller_train_steps": [int, [1, "inf"]],
+ "controller_log_every_steps": [int, [1, "inf"]],
}
enableNoneSettingsList = [
- "controller_temperature", "controller_tanh_const", "controller_entropy_weight", "controller_skip_weight"]
+ "controller_temperature",
+ "controller_tanh_const",
+ "controller_entropy_weight",
+ "controller_skip_weight",
+]
def parseAlgorithmSettings(settings_raw):
algorithm_settings_default = {
- "controller_hidden_size": 64,
- "controller_temperature": 5.,
- "controller_tanh_const": 2.25,
- "controller_entropy_weight": 1e-5,
- "controller_baseline_decay": 0.999,
- "controller_learning_rate": 5e-5,
- "controller_skip_target": 0.4,
- "controller_skip_weight": 0.8,
- "controller_train_steps": 50,
- "controller_log_every_steps": 10,
+ "controller_hidden_size": 64,
+ "controller_temperature": 5.0,
+ "controller_tanh_const": 2.25,
+ "controller_entropy_weight": 1e-5,
+ "controller_baseline_decay": 0.999,
+ "controller_learning_rate": 5e-5,
+ "controller_skip_target": 0.4,
+ "controller_skip_weight": 0.8,
+ "controller_train_steps": 50,
+ "controller_log_every_steps": 10,
}
for setting in settings_raw:
diff --git a/pkg/suggestion/v1beta1/nas/enas/Controller.py b/pkg/suggestion/v1beta1/nas/enas/Controller.py
index 12a43a32c56..42b4c84b5e5 100644
--- a/pkg/suggestion/v1beta1/nas/enas/Controller.py
+++ b/pkg/suggestion/v1beta1/nas/enas/Controller.py
@@ -17,19 +17,21 @@
class Controller(object):
- def __init__(self,
- num_layers=12,
- num_operations=16,
- controller_hidden_size=64,
- controller_temperature=5.,
- controller_tanh_const=2.25,
- controller_entropy_weight=1e-5,
- controller_baseline_decay=0.999,
- controller_learning_rate=5e-5,
- controller_skip_target=0.4,
- controller_skip_weight=0.8,
- controller_name="controller",
- logger=None):
+ def __init__(
+ self,
+ num_layers=12,
+ num_operations=16,
+ controller_hidden_size=64,
+ controller_temperature=5.0,
+ controller_tanh_const=2.25,
+ controller_entropy_weight=1e-5,
+ controller_baseline_decay=0.999,
+ controller_learning_rate=5e-5,
+ controller_skip_target=0.4,
+ controller_skip_weight=0.8,
+ controller_name="controller",
+ logger=None,
+ ):
self.logger = logger
self.logger.info(">>> Building Controller\n")
@@ -59,23 +61,38 @@ def _build_params(self):
with tf.compat.v1.variable_scope(self.controller_name, initializer=initializer):
with tf.compat.v1.variable_scope("lstm"):
- self.w_lstm = tf.compat.v1.get_variable("w", [2 * hidden_size, 4 * hidden_size])
+ self.w_lstm = tf.compat.v1.get_variable(
+ "w", [2 * hidden_size, 4 * hidden_size]
+ )
self.g_emb = tf.compat.v1.get_variable("g_emb", [1, hidden_size])
with tf.compat.v1.variable_scope("embedding"):
- self.w_emb = tf.compat.v1.get_variable("w", [self.num_operations, hidden_size])
+ self.w_emb = tf.compat.v1.get_variable(
+ "w", [self.num_operations, hidden_size]
+ )
with tf.compat.v1.variable_scope("softmax"):
- self.w_soft = tf.compat.v1.get_variable("w", [hidden_size, self.num_operations])
-
- with tf.compat.v1.variable_scope('attention'):
- self.attn_w_1 = tf.compat.v1.get_variable('w_1', [hidden_size, hidden_size])
- self.attn_w_2 = tf.compat.v1.get_variable("w_2", [hidden_size, hidden_size])
- self.attn_v = tf.compat.v1.get_variable('v', [hidden_size, 1])
-
- num_params = sum([np.prod(v.shape)
- for v in tf.compat.v1.trainable_variables() if v.name.startswith(self.controller_name)])
+ self.w_soft = tf.compat.v1.get_variable(
+ "w", [hidden_size, self.num_operations]
+ )
+
+ with tf.compat.v1.variable_scope("attention"):
+ self.attn_w_1 = tf.compat.v1.get_variable(
+ "w_1", [hidden_size, hidden_size]
+ )
+ self.attn_w_2 = tf.compat.v1.get_variable(
+ "w_2", [hidden_size, hidden_size]
+ )
+ self.attn_v = tf.compat.v1.get_variable("v", [hidden_size, 1])
+
+ num_params = sum(
+ [
+ np.prod(v.shape)
+ for v in tf.compat.v1.trainable_variables()
+ if v.name.startswith(self.controller_name)
+ ]
+ )
self.logger.info(">>> Controller has {} Trainable params\n".format(num_params))
def _build_sampler(self):
@@ -97,8 +114,10 @@ def _build_sampler(self):
prev_c = tf.zeros([1, hidden_size], tf.float32)
prev_h = tf.zeros([1, hidden_size], tf.float32)
- skip_targets = tf.constant([1.0 - self.controller_skip_target, self.controller_skip_target],
- dtype=tf.float32)
+ skip_targets = tf.constant(
+ [1.0 - self.controller_skip_target, self.controller_skip_target],
+ dtype=tf.float32,
+ )
inputs = self.g_emb
@@ -121,7 +140,8 @@ def _build_sampler(self):
arc_seq.append(func)
log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
- logits=logits, labels=func)
+ logits=logits, labels=func
+ )
sample_log_probs.append(log_prob)
entropy = log_prob * tf.exp(-log_prob)
@@ -153,17 +173,23 @@ def _build_sampler(self):
arc_seq.append(skip_index)
skip_prob = tf.sigmoid(logits)
- kl = skip_prob * tf.math.log(skip_prob/skip_targets)
+ kl = skip_prob * tf.math.log(skip_prob / skip_targets)
kl = tf.reduce_sum(input_tensor=kl)
skip_penalties.append(kl)
log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
- logits=logits, labels=skip_index)
+ logits=logits, labels=skip_index
+ )
- sample_log_probs.append(tf.reduce_sum(input_tensor=log_prob, keepdims=True))
+ sample_log_probs.append(
+ tf.reduce_sum(input_tensor=log_prob, keepdims=True)
+ )
entropy = tf.stop_gradient(
- tf.reduce_sum(input_tensor=log_prob * tf.exp(-log_prob), keepdims=True))
+ tf.reduce_sum(
+ input_tensor=log_prob * tf.exp(-log_prob), keepdims=True
+ )
+ )
sample_entropies.append(entropy)
skip_index = tf.dtypes.cast(skip_index, tf.float32)
@@ -173,7 +199,7 @@ def _build_sampler(self):
inputs = tf.matmul(skip_index, tf.concat(all_h, axis=0))
- inputs /= (1.0 + tf.reduce_sum(input_tensor=skip_index))
+ inputs /= 1.0 + tf.reduce_sum(input_tensor=skip_index)
else:
inputs = self.g_emb
@@ -201,7 +227,9 @@ def build_trainer(self):
self.reward = self.child_val_accuracy
- normalize = tf.dtypes.cast((self.num_layers * (self.num_layers - 1) / 2), tf.float32)
+ normalize = tf.dtypes.cast(
+ (self.num_layers * (self.num_layers - 1) / 2), tf.float32
+ )
self.skip_rate = tf.dtypes.cast((self.skip_count / normalize), tf.float32)
if self.controller_entropy_weight is not None:
@@ -210,7 +238,9 @@ def build_trainer(self):
self.sample_log_probs = tf.reduce_sum(input_tensor=self.sample_log_probs)
self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False)
baseline_update = tf.compat.v1.assign_sub(
- self.baseline, (1 - self.controller_baseline_decay) * (self.baseline - self.reward))
+ self.baseline,
+ (1 - self.controller_baseline_decay) * (self.baseline - self.reward),
+ )
with tf.control_dependencies([baseline_update]):
self.reward = tf.identity(self.reward)
@@ -221,16 +251,24 @@ def build_trainer(self):
self.loss += self.controller_skip_weight * self.skip_penalties
self.train_step = tf.Variable(
- 0, dtype=tf.int32, trainable=False, name=self.controller_name + '_train_step')
-
- tf_variables = [var for var in tf.compat.v1.trainable_variables()
- if var.name.startswith(self.controller_name)]
+ 0,
+ dtype=tf.int32,
+ trainable=False,
+ name=self.controller_name + "_train_step",
+ )
+
+ tf_variables = [
+ var
+ for var in tf.compat.v1.trainable_variables()
+ if var.name.startswith(self.controller_name)
+ ]
self.train_op, self.grad_norm = _build_train_op(
loss=self.loss,
tf_variables=tf_variables,
train_step=self.train_step,
- learning_rate=self.controller_learning_rate)
+ learning_rate=self.controller_learning_rate,
+ )
# TODO: will remove this function and use tf.nn.LSTMCell instead
@@ -252,6 +290,8 @@ def _build_train_op(loss, tf_variables, train_step, learning_rate):
grads = tf.gradients(ys=loss, xs=tf_variables)
grad_norm = tf.linalg.global_norm(grads)
- train_op = optimizer.apply_gradients(zip(grads, tf_variables), global_step=train_step)
+ train_op = optimizer.apply_gradients(
+ zip(grads, tf_variables), global_step=train_step
+ )
return train_op, grad_norm
diff --git a/pkg/suggestion/v1beta1/nas/enas/Operation.py b/pkg/suggestion/v1beta1/nas/enas/Operation.py
index 3fe94df792e..56b9b6f4780 100644
--- a/pkg/suggestion/v1beta1/nas/enas/Operation.py
+++ b/pkg/suggestion/v1beta1/nas/enas/Operation.py
@@ -27,9 +27,9 @@ def __init__(self, opt_id, opt_type, opt_params):
def get_dict(self):
opt_dict = dict()
- opt_dict['opt_id'] = self.opt_id
- opt_dict['opt_type'] = self.opt_type
- opt_dict['opt_params'] = self.opt_params
+ opt_dict["opt_id"] = self.opt_id
+ opt_dict["opt_type"] = self.opt_type
+ opt_dict["opt_params"] = self.opt_params
return opt_dict
def print_op(self, logger):
@@ -68,14 +68,12 @@ def _parse_operations(self):
spec_min = int(ispec.feasible_space.min)
spec_max = int(ispec.feasible_space.max)
spec_step = int(ispec.feasible_space.step)
- avail_space[spec_name] = range(
- spec_min, spec_max+1, spec_step)
+ avail_space[spec_name] = range(spec_min, spec_max + 1, spec_step)
elif ispec.parameter_type == api_pb2.DOUBLE:
spec_min = float(ispec.feasible_space.min)
spec_max = float(ispec.feasible_space.max)
spec_step = float(ispec.feasible_space.step)
- double_list = np.arange(
- spec_min, spec_max+spec_step, spec_step)
+ double_list = np.arange(spec_min, spec_max + spec_step, spec_step)
if double_list[-1] > spec_max:
del double_list[-1]
avail_space[spec_name] = double_list
diff --git a/pkg/suggestion/v1beta1/nas/enas/service.py b/pkg/suggestion/v1beta1/nas/enas/service.py
index c0e7970f127..94534d3f076 100644
--- a/pkg/suggestion/v1beta1/nas/enas/service.py
+++ b/pkg/suggestion/v1beta1/nas/enas/service.py
@@ -14,24 +14,20 @@
import json
import logging
-from logging import getLogger
-from logging import INFO
-from logging import StreamHandler
import os
+from logging import INFO, StreamHandler, getLogger
import grpc
import tensorflow as tf
-from pkg.apis.manager.v1beta1.python import api_pb2
-from pkg.apis.manager.v1beta1.python import api_pb2_grpc
+from pkg.apis.manager.v1beta1.python import api_pb2, api_pb2_grpc
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
from pkg.suggestion.v1beta1.nas.common.validation import validate_operations
-from pkg.suggestion.v1beta1.nas.enas.AlgorithmSettings import \
- algorithmSettingsValidator
-from pkg.suggestion.v1beta1.nas.enas.AlgorithmSettings import \
- enableNoneSettingsList
-from pkg.suggestion.v1beta1.nas.enas.AlgorithmSettings import \
- parseAlgorithmSettings
+from pkg.suggestion.v1beta1.nas.enas.AlgorithmSettings import (
+ algorithmSettingsValidator,
+ enableNoneSettingsList,
+ parseAlgorithmSettings,
+)
from pkg.suggestion.v1beta1.nas.enas.Controller import Controller
from pkg.suggestion.v1beta1.nas.enas.Operation import SearchSpace
@@ -43,8 +39,7 @@ def __init__(self, request, logger):
self.experiment = request.experiment
self.num_trials = 1
self.tf_graph = tf.Graph()
- self.ctrl_cache_file = "ctrl_cache/{}.ckpt".format(
- self.experiment_name)
+ self.ctrl_cache_file = "ctrl_cache/{}.ckpt".format(self.experiment_name)
self.suggestion_step = 0
self.algorithm_settings = None
self.controller = None
@@ -55,12 +50,18 @@ def __init__(self, request, logger):
self.search_space = None
self.opt_direction = None
self.objective_name = None
- self.logger.info("-" * 100 + "\nSetting Up Suggestion for Experiment {}\n".format(
- self.experiment_name) + "-" * 100)
+ self.logger.info(
+ "-" * 100
+ + "\nSetting Up Suggestion for Experiment {}\n".format(self.experiment_name)
+ + "-" * 100
+ )
self._get_experiment_param()
self._setup_controller()
- self.logger.info(">>> Suggestion for Experiment {} has been initialized.\n".format(
- self.experiment_name))
+ self.logger.info(
+ ">>> Suggestion for Experiment {} has been initialized.\n".format(
+ self.experiment_name
+ )
+ )
def _get_experiment_param(self):
# this function need to
@@ -96,53 +97,69 @@ def _get_experiment_param(self):
self.print_algorithm_settings()
def _setup_controller(self):
-
with self.tf_graph.as_default():
-
self.controller = Controller(
num_layers=self.num_layers,
num_operations=self.num_operations,
- controller_hidden_size=self.algorithm_settings['controller_hidden_size'],
- controller_temperature=self.algorithm_settings['controller_temperature'],
- controller_tanh_const=self.algorithm_settings['controller_tanh_const'],
- controller_entropy_weight=self.algorithm_settings['controller_entropy_weight'],
- controller_baseline_decay=self.algorithm_settings['controller_baseline_decay'],
- controller_learning_rate=self.algorithm_settings["controller_learning_rate"],
- controller_skip_target=self.algorithm_settings['controller_skip_target'],
- controller_skip_weight=self.algorithm_settings['controller_skip_weight'],
+ controller_hidden_size=self.algorithm_settings[
+ "controller_hidden_size"
+ ],
+ controller_temperature=self.algorithm_settings[
+ "controller_temperature"
+ ],
+ controller_tanh_const=self.algorithm_settings["controller_tanh_const"],
+ controller_entropy_weight=self.algorithm_settings[
+ "controller_entropy_weight"
+ ],
+ controller_baseline_decay=self.algorithm_settings[
+ "controller_baseline_decay"
+ ],
+ controller_learning_rate=self.algorithm_settings[
+ "controller_learning_rate"
+ ],
+ controller_skip_target=self.algorithm_settings[
+ "controller_skip_target"
+ ],
+ controller_skip_weight=self.algorithm_settings[
+ "controller_skip_weight"
+ ],
controller_name="Ctrl_" + self.experiment_name,
- logger=self.logger)
+ logger=self.logger,
+ )
self.controller.build_trainer()
def print_search_space(self):
if self.search_space is None:
- self.logger.warning(
- "Error! The Suggestion has not yet been initialized!")
+ self.logger.warning("Error! The Suggestion has not yet been initialized!")
return
self.logger.info(
- ">>> Search Space for Experiment {}".format(self.experiment_name))
+ ">>> Search Space for Experiment {}".format(self.experiment_name)
+ )
for opt in self.search_space:
opt.print_op(self.logger)
self.logger.info(
- "There are {} operations in total.\n".format(self.num_operations))
+ "There are {} operations in total.\n".format(self.num_operations)
+ )
def print_algorithm_settings(self):
if self.algorithm_settings is None:
- self.logger.warning(
- "Error! The Suggestion has not yet been initialized!")
+ self.logger.warning("Error! The Suggestion has not yet been initialized!")
return
- self.logger.info(">>> Parameters of LSTM Controller for Experiment {}\n".format(
- self.experiment_name))
+ self.logger.info(
+ ">>> Parameters of LSTM Controller for Experiment {}\n".format(
+ self.experiment_name
+ )
+ )
for spec in self.algorithm_settings:
if len(spec) > 22:
- self.logger.info("{}:\t{}".format(
- spec, self.algorithm_settings[spec]))
+ self.logger.info("{}:\t{}".format(spec, self.algorithm_settings[spec]))
else:
- self.logger.info("{}:\t\t{}".format(
- spec, self.algorithm_settings[spec]))
+ self.logger.info(
+ "{}:\t\t{}".format(spec, self.algorithm_settings[spec])
+ )
self.logger.info("")
@@ -154,7 +171,7 @@ def __init__(self, logger=None):
self.experiment = None
if logger is None:
self.logger = getLogger(__name__)
- FORMAT = '%(asctime)-15s Experiment %(experiment_name)s %(message)s'
+ FORMAT = "%(asctime)-15s Experiment %(experiment_name)s %(message)s"
logging.basicConfig(format=FORMAT)
handler = StreamHandler()
handler.setLevel(INFO)
@@ -175,18 +192,21 @@ def ValidateAlgorithmSettings(self, request, context):
# Validate GraphConfig
# Check InputSize
if not graph_config.input_sizes:
- return self.set_validate_context_error(context,
- "Missing InputSizes in GraphConfig:\n{}".format(graph_config))
+ return self.set_validate_context_error(
+ context, "Missing InputSizes in GraphConfig:\n{}".format(graph_config)
+ )
# Check OutputSize
if not graph_config.output_sizes:
- return self.set_validate_context_error(context,
- "Missing OutputSizes in GraphConfig:\n{}".format(graph_config))
+ return self.set_validate_context_error(
+ context, "Missing OutputSizes in GraphConfig:\n{}".format(graph_config)
+ )
# Check NumLayers
if not graph_config.num_layers:
- return self.set_validate_context_error(context,
- "Missing NumLayers in GraphConfig:\n{}".format(graph_config))
+ return self.set_validate_context_error(
+ context, "Missing NumLayers in GraphConfig:\n{}".format(graph_config)
+ )
# Validate Operations
is_valid, message = validate_operations(nas_config.operations.operation)
@@ -204,34 +224,46 @@ def ValidateAlgorithmSettings(self, request, context):
try:
converted_value = setting_type(setting.value)
except Exception as e:
- return self.set_validate_context_error(context,
- "Algorithm Setting {} must be {} type: exception {}".format(
- setting.name, setting_type.__name__, e))
+ return self.set_validate_context_error(
+ context,
+ "Algorithm Setting {} must be {} type: exception {}".format(
+ setting.name, setting_type.__name__, e
+ ),
+ )
if setting_type == float:
- if (converted_value <= setting_range[0] or
- (setting_range[1] != 'inf' and converted_value > setting_range[1])):
+ if converted_value <= setting_range[0] or (
+ setting_range[1] != "inf" and converted_value > setting_range[1]
+ ):
return self.set_validate_context_error(
- context, "Algorithm Setting {}: {} with {} type must be in range ({}, {}]".format(
+ context,
+ (
+ "Algorithm Setting {}: {} with {} type must be in range "
+ "({}, {}]" # upper bound is inclusive, matching the check above
+ ).format(
setting.name,
converted_value,
setting_type.__name__,
setting_range[0],
- setting_range[1])
+ setting_range[1],
+ ),
)
elif converted_value < setting_range[0]:
return self.set_validate_context_error(
- context, "Algorithm Setting {}: {} with {} type must be in range [{}, {})".format(
+ context,
+ "Algorithm Setting {}: {} with {} type must be in range [{}, {})".format(
setting.name,
converted_value,
setting_type.__name__,
setting_range[0],
- setting_range[1])
+ setting_range[1],
+ ),
)
else:
- return self.set_validate_context_error(context,
- "Unknown Algorithm Setting name: {}".format(setting.name))
+ return self.set_validate_context_error(
+ context, "Unknown Algorithm Setting name: {}".format(setting.name)
+ )
self.logger.info("All Experiment Settings are Valid")
return api_pb2.ValidateAlgorithmSettingsReply()
@@ -248,11 +280,18 @@ def GetSuggestions(self, request, context):
experiment = self.experiment
if request.current_request_number > 0:
experiment.num_trials = request.current_request_number
- self.logger.info("-" * 100 + "\nSuggestion Step {} for Experiment {}\n".format(
- experiment.suggestion_step, experiment.experiment_name) + "-" * 100)
+ self.logger.info(
+ "-" * 100
+ + "\nSuggestion Step {} for Experiment {}\n".format(
+ experiment.suggestion_step, experiment.experiment_name
+ )
+ + "-" * 100
+ )
self.logger.info("")
- self.logger.info(">>> Current Request Number:\t\t{}".format(experiment.num_trials))
+ self.logger.info(
+ ">>> Current Request Number:\t\t{}".format(experiment.num_trials)
+ )
self.logger.info("")
with experiment.tf_graph.as_default():
@@ -272,16 +311,20 @@ def GetSuggestions(self, request, context):
}
if self.is_first_run:
- self.logger.info(">>> First time running suggestion for {}. Random architecture will be given.".format(
- experiment.experiment_name))
+ self.logger.info(
+ ">>> First time running suggestion for {}. "
+ "Random architecture will be given.".format(
+ experiment.experiment_name
+ )
+ )
with tf.compat.v1.Session() as sess:
sess.run(tf.compat.v1.global_variables_initializer())
candidates = list()
for _ in range(experiment.num_trials):
- candidates.append(
- sess.run(controller_ops["sample_arc"]))
+ candidates.append(sess.run(controller_ops["sample_arc"]))
- # TODO: will use PVC to store the checkpoint to protect against unexpected suggestion pod restart
+ # TODO: will use PVC to store the checkpoint to protect
+ # against unexpected suggestion pod restart
saver.save(sess, experiment.ctrl_cache_file)
self.is_first_run = False
@@ -293,46 +336,66 @@ def GetSuggestions(self, request, context):
result = self.GetEvaluationResult(request.trials)
# TODO: (andreyvelich) I deleted this part, should it be handle by controller?
- # Sometimes training container may fail and GetEvaluationResult() will return None
+ # Sometimes training container may fail and GetEvaluationResult()
+ # will return None
# In this case, the Suggestion will:
- # 1. Firstly try to respawn the previous trials after waiting for RESPAWN_SLEEP seconds
- # 2. If respawning the trials for RESPAWN_LIMIT times still cannot collect valid results,
- # then fail the task because it may indicate that the training container has errors.
+ # 1. Firstly try to respawn the previous trials after waiting for
+ # RESPAWN_SLEEP seconds
+ # 2. If respawning the trials for RESPAWN_LIMIT times still cannot
+ # collect valid results,
+ # then fail the task because it may indicate that the training
+ # container has errors.
if result is None:
self.logger.warning(
- ">>> Suggestion has spawned trials, but they all failed.")
+ ">>> Suggestion has spawned trials, but they all failed."
+ )
self.logger.warning(
- ">>> Please check whether the training container is correctly implemented")
- self.logger.info(">>> Experiment {} failed".format(
- experiment.experiment_name))
+ ">>> Please check whether the training container "
+ "is correctly implemented"
+ )
+ self.logger.info(
+ ">>> Experiment {} failed".format(
+ experiment.experiment_name
+ )
+ )
return []
# This LSTM network is designed to maximize the metrics
- # However, if the user wants to minimize the metrics, we can take the negative of the result
+ # However, if the user wants to minimize the metrics,
+ # we can take the negative of the result
if experiment.opt_direction == api_pb2.MINIMIZE:
result = -result
- self.logger.info(">>> Suggestion updated. LSTM Controller Training\n")
- log_every = experiment.algorithm_settings["controller_log_every_steps"]
- for ctrl_step in range(1, experiment.algorithm_settings["controller_train_steps"]+1):
+ self.logger.info(
+ ">>> Suggestion updated. LSTM Controller Training\n"
+ )
+ log_every = experiment.algorithm_settings[
+ "controller_log_every_steps"
+ ]
+ for ctrl_step in range(
+ 1, experiment.algorithm_settings["controller_train_steps"] + 1
+ ):
run_ops = [
controller_ops["loss"],
controller_ops["entropy"],
controller_ops["grad_norm"],
controller_ops["baseline"],
controller_ops["skip_rate"],
- controller_ops["train_op"]
+ controller_ops["train_op"],
]
loss, entropy, grad_norm, baseline, skip_rate, _ = sess.run(
fetches=run_ops,
- feed_dict={controller_ops["child_val_accuracy"]: result})
+ feed_dict={controller_ops["child_val_accuracy"]: result},
+ )
controller_step = sess.run(controller_ops["train_step"])
if ctrl_step % log_every == 0:
log_string = ""
- log_string += "Controller Step: {} - ".format(controller_step)
+ log_string += "Controller Step: {} - ".format(
+ controller_step
+ )
log_string += "Loss: {:.4f} - ".format(loss)
log_string += "Entropy: {:.9} - ".format(entropy)
log_string += "Gradient Norm: {:.7f} - ".format(grad_norm)
@@ -342,8 +405,7 @@ def GetSuggestions(self, request, context):
candidates = list()
for _ in range(experiment.num_trials):
- candidates.append(
- sess.run(controller_ops["sample_arc"]))
+ candidates.append(sess.run(controller_ops["sample_arc"]))
saver.save(sess, experiment.ctrl_cache_file)
@@ -355,27 +417,29 @@ def GetSuggestions(self, request, context):
organized_arc = [0 for _ in range(experiment.num_layers)]
record = 0
for layer in range(experiment.num_layers):
- organized_arc[layer] = arc[record: record + layer + 1]
+ organized_arc[layer] = arc[record : record + layer + 1]
record += layer + 1
organized_candidates.append(organized_arc)
nn_config = dict()
- nn_config['num_layers'] = experiment.num_layers
- nn_config['input_sizes'] = experiment.input_sizes
- nn_config['output_sizes'] = experiment.output_sizes
- nn_config['embedding'] = dict()
+ nn_config["num_layers"] = experiment.num_layers
+ nn_config["input_sizes"] = experiment.input_sizes
+ nn_config["output_sizes"] = experiment.output_sizes
+ nn_config["embedding"] = dict()
for layer in range(experiment.num_layers):
opt = organized_arc[layer][0]
- nn_config['embedding'][opt] = experiment.search_space[opt].get_dict()
+ nn_config["embedding"][opt] = experiment.search_space[opt].get_dict()
organized_arc_json = json.dumps(organized_arc)
nn_config_json = json.dumps(nn_config)
- organized_arc_str = str(organized_arc_json).replace('\"', '\'')
- nn_config_str = str(nn_config_json).replace('\"', '\'')
+ organized_arc_str = str(organized_arc_json).replace('"', "'")
+ nn_config_str = str(nn_config_json).replace('"', "'")
self.logger.info(
- "\n>>> New Neural Network Architecture Candidate #{} (internal representation):".format(i))
+ "\n>>> New Neural Network Architecture Candidate #{} "
+ "(internal representation):".format(i)
+ )
self.logger.info(organized_arc_json)
self.logger.info("\n>>> Corresponding Seach Space Description:")
self.logger.info(nn_config_str)
@@ -384,20 +448,21 @@ def GetSuggestions(self, request, context):
api_pb2.GetSuggestionsReply.ParameterAssignments(
assignments=[
api_pb2.ParameterAssignment(
- name="architecture",
- value=organized_arc_str
+ name="architecture", value=organized_arc_str
),
api_pb2.ParameterAssignment(
- name="nn_config",
- value=nn_config_str
- )
+ name="nn_config", value=nn_config_str
+ ),
]
)
)
self.logger.info("")
- self.logger.info(">>> {} Trials were created for Experiment {}".format(
- experiment.num_trials, experiment.experiment_name))
+ self.logger.info(
+ ">>> {} Trials were created for Experiment {}".format(
+ experiment.num_trials, experiment.experiment_name
+ )
+ )
self.logger.info("")
experiment.suggestion_step += 1
@@ -423,11 +488,15 @@ def GetEvaluationResult(self, trials_list):
failed_trials.append(t.name)
n_completed = len(completed_trials)
- self.logger.info(">>> By now: {} Trials succeeded, {} Trials failed".format(
- n_completed, len(failed_trials)))
+ self.logger.info(
+ ">>> By now: {} Trials succeeded, {} Trials failed".format(
+ n_completed, len(failed_trials)
+ )
+ )
for tname in completed_trials:
- self.logger.info("Trial: {}, Value: {}".format(
- tname, completed_trials[tname]))
+ self.logger.info(
+ "Trial: {}, Value: {}".format(tname, completed_trials[tname])
+ )
for tname in failed_trials:
self.logger.info("Trial: {} was failed".format(tname))
diff --git a/pkg/suggestion/v1beta1/optuna/base_service.py b/pkg/suggestion/v1beta1/optuna/base_service.py
index ce790173024..0a18f395771 100644
--- a/pkg/suggestion/v1beta1/optuna/base_service.py
+++ b/pkg/suggestion/v1beta1/optuna/base_service.py
@@ -16,21 +16,19 @@
import optuna
-from pkg.suggestion.v1beta1.internal.constant import CATEGORICAL
-from pkg.suggestion.v1beta1.internal.constant import DISCRETE
-from pkg.suggestion.v1beta1.internal.constant import DOUBLE
-from pkg.suggestion.v1beta1.internal.constant import INTEGER
-from pkg.suggestion.v1beta1.internal.constant import MAX_GOAL
-from pkg.suggestion.v1beta1.internal.search_space import \
- HyperParameterSearchSpace
+from pkg.suggestion.v1beta1.internal.constant import (
+ CATEGORICAL,
+ DISCRETE,
+ DOUBLE,
+ INTEGER,
+ MAX_GOAL,
+)
+from pkg.suggestion.v1beta1.internal.search_space import HyperParameterSearchSpace
from pkg.suggestion.v1beta1.internal.trial import Assignment
class BaseOptunaService(object):
- def __init__(self,
- algorithm_name="",
- algorithm_config=None,
- search_space=None):
+ def __init__(self, algorithm_name="", algorithm_config=None, search_space=None):
self.algorithm_name = algorithm_name
self.algorithm_config = algorithm_config
self.search_space = search_space
@@ -56,7 +54,9 @@ def _create_sampler(self):
return optuna.samplers.RandomSampler(**self.algorithm_config)
elif self.algorithm_name == "grid":
- combinations = HyperParameterSearchSpace.convert_to_combinations(self.search_space)
+ combinations = HyperParameterSearchSpace.convert_to_combinations(
+ self.search_space
+ )
return optuna.samplers.GridSampler(combinations, **self.algorithm_config)
def get_suggestions(self, trials, current_request_number):
@@ -67,13 +67,17 @@ def get_suggestions(self, trials, current_request_number):
def _ask(self, current_request_number):
list_of_assignments = []
for _ in range(current_request_number):
- optuna_trial = self.study.ask(fixed_distributions=self._get_optuna_search_space())
+ optuna_trial = self.study.ask(
+ fixed_distributions=self._get_optuna_search_space()
+ )
assignments = [Assignment(k, v) for k, v in optuna_trial.params.items()]
list_of_assignments.append(assignments)
assignments_key = self._get_assignments_key(assignments)
- self.assignments_to_optuna_number[assignments_key].append(optuna_trial.number)
+ self.assignments_to_optuna_number[assignments_key].append(
+ optuna_trial.number
+ )
return list_of_assignments
@@ -84,13 +88,17 @@ def _tell(self, trials):
value = float(trial.target_metric.value)
assignments_key = self._get_assignments_key(trial.assignments)
- optuna_trial_numbers = self.assignments_to_optuna_number[assignments_key]
+ optuna_trial_numbers = self.assignments_to_optuna_number[
+ assignments_key
+ ]
if len(optuna_trial_numbers) != 0:
trial_number = optuna_trial_numbers.pop(0)
self.study.tell(trial_number, value)
else:
- raise ValueError("An unknown trial has been passed in the GetSuggestion request.")
+ raise ValueError(
+ "An unknown trial has been passed in the GetSuggestion request."
+ )
@staticmethod
def _get_assignments_key(assignments):
@@ -102,9 +110,15 @@ def _get_optuna_search_space(self):
search_space = {}
for param in self.search_space.params:
if param.type == INTEGER:
- search_space[param.name] = optuna.distributions.IntDistribution(int(param.min), int(param.max))
+ search_space[param.name] = optuna.distributions.IntDistribution(
+ int(param.min), int(param.max)
+ )
elif param.type == DOUBLE:
- search_space[param.name] = optuna.distributions.FloatDistribution(float(param.min), float(param.max))
+ search_space[param.name] = optuna.distributions.FloatDistribution(
+ float(param.min), float(param.max)
+ )
elif param.type == CATEGORICAL or param.type == DISCRETE:
- search_space[param.name] = optuna.distributions.CategoricalDistribution(param.list)
+ search_space[param.name] = optuna.distributions.CategoricalDistribution(
+ param.list
+ )
return search_space
diff --git a/pkg/suggestion/v1beta1/optuna/service.py b/pkg/suggestion/v1beta1/optuna/service.py
index a99d0043676..c793e615727 100644
--- a/pkg/suggestion/v1beta1/optuna/service.py
+++ b/pkg/suggestion/v1beta1/optuna/service.py
@@ -18,13 +18,10 @@
import grpc
-from pkg.apis.manager.v1beta1.python import api_pb2
-from pkg.apis.manager.v1beta1.python import api_pb2_grpc
+from pkg.apis.manager.v1beta1.python import api_pb2, api_pb2_grpc
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
-from pkg.suggestion.v1beta1.internal.search_space import \
- HyperParameterSearchSpace
-from pkg.suggestion.v1beta1.internal.trial import Assignment
-from pkg.suggestion.v1beta1.internal.trial import Trial
+from pkg.suggestion.v1beta1.internal.search_space import HyperParameterSearchSpace
+from pkg.suggestion.v1beta1.internal.trial import Assignment, Trial
from pkg.suggestion.v1beta1.optuna.base_service import BaseOptunaService
logger = logging.getLogger(__name__)
@@ -247,7 +244,8 @@ def _validate_grid_setting(cls, experiment):
if max_trial_count > num_combinations:
return (
False,
- "Max Trial Count: {max_trial} > all possible search combinations: {combinations}".format(
+ "Max Trial Count: {max_trial} > all possible search combinations: "
+ "{combinations}".format(
max_trial=max_trial_count, combinations=num_combinations
),
)
diff --git a/pkg/suggestion/v1beta1/pbt/service.py b/pkg/suggestion/v1beta1/pbt/service.py
index 38087786e7a..7791390eb59 100644
--- a/pkg/suggestion/v1beta1/pbt/service.py
+++ b/pkg/suggestion/v1beta1/pbt/service.py
@@ -20,15 +20,14 @@
import grpc
import numpy as np
-from pkg.apis.manager.v1beta1.python import api_pb2
-from pkg.apis.manager.v1beta1.python import api_pb2_grpc
-from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
import pkg.suggestion.v1beta1.internal.constant as constant
-from pkg.suggestion.v1beta1.internal.search_space import HyperParameter
-from pkg.suggestion.v1beta1.internal.search_space import \
- HyperParameterSearchSpace
+from pkg.apis.manager.v1beta1.python import api_pb2, api_pb2_grpc
+from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
+from pkg.suggestion.v1beta1.internal.search_space import (
+ HyperParameter,
+ HyperParameterSearchSpace,
+)
from pkg.suggestion.v1beta1.internal.trial import Assignment
-from pkg.suggestion.v1beta1.internal.trial import Trial
logger = logging.getLogger(__name__)
@@ -71,7 +70,8 @@ def ValidateAlgorithmSettings(self, request, context):
):
return self._set_validate_context_error(
context,
- "Param(resample_probability) should be null to perturb at 0.8 or 1.2, or be between 0 and 1, inclusive, to resample",
+ "Param(resample_probability) should be null to perturb at 0.8 or 1.2, "
+ "or be between 0 and 1, inclusive, to resample",
)
return api_pb2.ValidateAlgorithmSettingsReply()
@@ -97,9 +97,11 @@ def GetSuggestions(self, request, context):
request.experiment.name,
int(settings["n_population"]),
float(settings["truncation_threshold"]),
- None
- if not "resample_probability" in settings
- else float(settings["resample_probability"]),
+ (
+ None
+ if "resample_probability" not in settings
+ else float(settings["resample_probability"])
+ ),
search_space,
objective_metric,
objective_scale,
@@ -183,7 +185,7 @@ def get(self):
labels = {
"pbt.suggestion.katib.kubeflow.org/generation": self.generation,
}
- if not self.parent is None:
+ if self.parent is not None:
labels["pbt.suggestion.katib.kubeflow.org/parent"] = self.parent
return assignments, labels, self.uid
@@ -283,9 +285,7 @@ def get(self):
return obj.get()
def update(self, trial):
- trial_labels = trial.spec.labels
uid = trial.name
- generation = trial_labels["pbt.suggestion.katib.kubeflow.org/generation"]
# Do not update active/pending trials
if trial.status.condition in (
diff --git a/pkg/suggestion/v1beta1/skopt/base_service.py b/pkg/suggestion/v1beta1/skopt/base_service.py
index 434c3a8b908..4ac56e30729 100644
--- a/pkg/suggestion/v1beta1/skopt/base_service.py
+++ b/pkg/suggestion/v1beta1/skopt/base_service.py
@@ -17,11 +17,13 @@
import skopt
-from pkg.suggestion.v1beta1.internal.constant import CATEGORICAL
-from pkg.suggestion.v1beta1.internal.constant import DISCRETE
-from pkg.suggestion.v1beta1.internal.constant import DOUBLE
-from pkg.suggestion.v1beta1.internal.constant import INTEGER
-from pkg.suggestion.v1beta1.internal.constant import MAX_GOAL
+from pkg.suggestion.v1beta1.internal.constant import (
+ CATEGORICAL,
+ DISCRETE,
+ DOUBLE,
+ INTEGER,
+ MAX_GOAL,
+)
from pkg.suggestion.v1beta1.internal.trial import Assignment
logger = logging.getLogger(__name__)
@@ -32,13 +34,15 @@ class BaseSkoptService(object):
Refer to https://github.com/scikit-optimize/scikit-optimize .
"""
- def __init__(self,
- base_estimator="GP",
- n_initial_points=10,
- acq_func="gp_hedge",
- acq_optimizer="auto",
- random_state=None,
- search_space=None):
+ def __init__(
+ self,
+ base_estimator="GP",
+ n_initial_points=10,
+ acq_func="gp_hedge",
+ acq_optimizer="auto",
+ random_state=None,
+ search_space=None,
+ ):
self.base_estimator = base_estimator
self.n_initial_points = n_initial_points
self.acq_func = acq_func
@@ -56,14 +60,22 @@ def create_optimizer(self):
for param in self.search_space.params:
if param.type == INTEGER:
- skopt_search_space.append(skopt.space.Integer(
- int(param.min), int(param.max), name=param.name))
+ skopt_search_space.append(
+ skopt.space.Integer(int(param.min), int(param.max), name=param.name)
+ )
elif param.type == DOUBLE:
- skopt_search_space.append(skopt.space.Real(
- float(param.min), float(param.max), "log-uniform", name=param.name))
+ skopt_search_space.append(
+ skopt.space.Real(
+ float(param.min),
+ float(param.max),
+ "log-uniform",
+ name=param.name,
+ )
+ )
elif param.type == CATEGORICAL or param.type == DISCRETE:
skopt_search_space.append(
- skopt.space.Categorical(param.list, name=param.name))
+ skopt.space.Categorical(param.list, name=param.name)
+ )
self.skopt_optimizer = skopt.Optimizer(
skopt_search_space,
@@ -71,20 +83,27 @@ def create_optimizer(self):
n_initial_points=self.n_initial_points,
acq_func=self.acq_func,
acq_optimizer=self.acq_optimizer,
- random_state=self.random_state)
+ random_state=self.random_state,
+ )
def getSuggestions(self, trials, current_request_number):
"""
Get the new suggested trials with skopt algorithm.
"""
logger.info("-" * 100 + "\n")
- logger.info("New GetSuggestions call with current request number: {}\n".format(current_request_number))
+ logger.info(
+ "New GetSuggestions call with current request number: {}\n".format(
+ current_request_number
+ )
+ )
skopt_suggested = []
loss_for_skopt = []
if len(trials) > self.succeeded_trials or self.succeeded_trials == 0:
self.succeeded_trials = len(trials)
if self.succeeded_trials != 0:
- logger.info("Succeeded Trials changed: {}\n".format(self.succeeded_trials))
+ logger.info(
+ "Succeeded Trials changed: {}\n".format(self.succeeded_trials)
+ )
for trial in trials:
if trial.name not in self.recorded_trials_names:
self.recorded_trials_names.append(trial.name)
@@ -113,11 +132,21 @@ def getSuggestions(self, trials, current_request_number):
logger.info("Objective values: {}\n".format(loss_for_skopt))
t1 = datetime.datetime.now()
self.skopt_optimizer.tell(skopt_suggested, loss_for_skopt)
- logger.info("Optimizer tell method takes {} seconds".format((datetime.datetime.now()-t1).seconds))
- logger.info("List of recorded Trials names: {}\n".format(self.recorded_trials_names))
+ logger.info(
+ "Optimizer tell method takes {} seconds".format(
+ (datetime.datetime.now() - t1).seconds
+ )
+ )
+ logger.info(
+ "List of recorded Trials names: {}\n".format(
+ self.recorded_trials_names
+ )
+ )
else:
- logger.error("Succeeded Trials didn't change: {}\n".format(self.succeeded_trials))
+ logger.error(
+ "Succeeded Trials didn't change: {}\n".format(self.succeeded_trials)
+ )
logger.info("Running Optimizer ask to query new parameters for Trials\n")
@@ -127,9 +156,12 @@ def getSuggestions(self, trials, current_request_number):
for suggestion in skopt_suggested:
logger.info("New suggested parameters for Trial: {}".format(suggestion))
return_trial_list.append(
- BaseSkoptService.convert(self.search_space, suggestion))
+ BaseSkoptService.convert(self.search_space, suggestion)
+ )
- logger.info("GetSuggestions returns {} new Trials\n\n".format(len(return_trial_list)))
+ logger.info(
+ "GetSuggestions returns {} new Trials\n\n".format(len(return_trial_list))
+ )
return return_trial_list
@staticmethod
diff --git a/pkg/suggestion/v1beta1/skopt/service.py b/pkg/suggestion/v1beta1/skopt/service.py
index 4ac0041866f..f970a74fabf 100644
--- a/pkg/suggestion/v1beta1/skopt/service.py
+++ b/pkg/suggestion/v1beta1/skopt/service.py
@@ -16,13 +16,10 @@
import grpc
-from pkg.apis.manager.v1beta1.python import api_pb2
-from pkg.apis.manager.v1beta1.python import api_pb2_grpc
+from pkg.apis.manager.v1beta1.python import api_pb2, api_pb2_grpc
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
-from pkg.suggestion.v1beta1.internal.search_space import \
- HyperParameterSearchSpace
-from pkg.suggestion.v1beta1.internal.trial import Assignment
-from pkg.suggestion.v1beta1.internal.trial import Trial
+from pkg.suggestion.v1beta1.internal.search_space import HyperParameterSearchSpace
+from pkg.suggestion.v1beta1.internal.trial import Assignment, Trial
from pkg.suggestion.v1beta1.skopt.base_service import BaseSkoptService
logger = logging.getLogger(__name__)
@@ -40,7 +37,8 @@ def GetSuggestions(self, request, context):
Main function to provide suggestion.
"""
algorithm_name, config = OptimizerConfiguration.convert_algorithm_spec(
- request.experiment.spec.algorithm)
+ request.experiment.spec.algorithm
+ )
if self.is_first_run:
search_space = HyperParameterSearchSpace.convert(request.experiment)
@@ -50,18 +48,22 @@ def GetSuggestions(self, request, context):
acq_func=config.acq_func,
acq_optimizer=config.acq_optimizer,
random_state=config.random_state,
- search_space=search_space)
+ search_space=search_space,
+ )
self.is_first_run = False
trials = Trial.convert(request.trials)
- new_trials = self.base_service.getSuggestions(trials, request.current_request_number)
+ new_trials = self.base_service.getSuggestions(
+ trials, request.current_request_number
+ )
return api_pb2.GetSuggestionsReply(
parameter_assignments=Assignment.generate(new_trials)
)
def ValidateAlgorithmSettings(self, request, context):
is_valid, message = OptimizerConfiguration.validate_algorithm_spec(
- request.experiment.spec.algorithm)
+ request.experiment.spec.algorithm
+ )
if not is_valid:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details(message)
@@ -70,11 +72,14 @@ def ValidateAlgorithmSettings(self, request, context):
class OptimizerConfiguration(object):
- def __init__(self, base_estimator="GP",
- n_initial_points=10,
- acq_func="gp_hedge",
- acq_optimizer="auto",
- random_state=None):
+ def __init__(
+ self,
+ base_estimator="GP",
+ n_initial_points=10,
+ acq_func="gp_hedge",
+ acq_optimizer="auto",
+ random_state=None,
+ ):
self.base_estimator = base_estimator
self.n_initial_points = n_initial_points
self.acq_func = acq_func
@@ -102,7 +107,9 @@ def validate_algorithm_spec(cls, algorithm_spec):
algo_name = algorithm_spec.algorithm_name
if algo_name == "bayesianoptimization":
- return cls._validate_bayesianoptimization_setting(algorithm_spec.algorithm_settings)
+ return cls._validate_bayesianoptimization_setting(
+ algorithm_spec.algorithm_settings
+ )
else:
return False, "unknown algorithm name {}".format(algo_name)
@@ -112,23 +119,47 @@ def _validate_bayesianoptimization_setting(cls, algorithm_settings):
try:
if s.name == "base_estimator":
if s.value not in ["GP", "RF", "ET", "GBRT"]:
- return False, "base_estimator {} is not supported in Bayesian optimization".format(s.value)
+ return (
+ False,
+ "base_estimator {} is not supported in Bayesian optimization".format(
+ s.value
+ ),
+ )
elif s.name == "n_initial_points":
if not (int(s.value) >= 0):
- return False, "n_initial_points should be great or equal than zero"
+ return (
+ False,
+ "n_initial_points should be great or equal than zero",
+ )
elif s.name == "acq_func":
if s.value not in ["gp_hedge", "LCB", "EI", "PI", "EIps", "PIps"]:
- return False, "acq_func {} is not supported in Bayesian optimization".format(s.value)
+ return (
+ False,
+ "acq_func {} is not supported in Bayesian optimization".format(
+ s.value
+ ),
+ )
elif s.name == "acq_optimizer":
if s.value not in ["auto", "sampling", "lbfgs"]:
- return False, "acq_optimizer {} is not supported in Bayesian optimization".format(s.value)
+ return (
+ False,
+ "acq_optimizer {} is not supported in Bayesian optimization".format(
+ s.value
+ ),
+ )
elif s.name == "random_state":
if not (int(s.value) >= 0):
return False, "random_state should be great or equal than zero"
else:
- return False, "unknown setting {} for algorithm bayesianoptimization".format(s.name)
+ return (
+ False,
+ "unknown setting {} for algorithm bayesianoptimization".format(
+ s.name
+ ),
+ )
except Exception as e:
- return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
- exception=e)
+ return False, "failed to validate {name}({value}): {exception}".format(
+ name=s.name, value=s.value, exception=e
+ )
return True, ""
diff --git a/pkg/webhook/v1beta1/experiment/validator/validator.go b/pkg/webhook/v1beta1/experiment/validator/validator.go
index 56bde4b0621..9f0f1c3ff5b 100644
--- a/pkg/webhook/v1beta1/experiment/validator/validator.go
+++ b/pkg/webhook/v1beta1/experiment/validator/validator.go
@@ -19,7 +19,6 @@ package validator
import (
"encoding/json"
"fmt"
- "k8s.io/apimachinery/pkg/util/validation/field"
"path/filepath"
"regexp"
"strconv"
@@ -30,6 +29,7 @@ import (
"k8s.io/apimachinery/pkg/api/equality"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
+ "k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
logf "sigs.k8s.io/controller-runtime/pkg/log"
@@ -264,6 +264,17 @@ func (g *DefaultValidator) validateParameters(parameters []experimentsv1beta1.Pa
param.ParameterType, fmt.Sprintf("parameterType: %v is not supported", param.ParameterType)))
}
+ if param.FeasibleSpace.Distribution != "" {
+ if param.FeasibleSpace.Distribution != experimentsv1beta1.DistributionUniform &&
+ param.FeasibleSpace.Distribution != experimentsv1beta1.DistributionLogUniform &&
+ param.FeasibleSpace.Distribution != experimentsv1beta1.DistributionNormal &&
+ param.FeasibleSpace.Distribution != experimentsv1beta1.DistributionLogNormal &&
+ param.FeasibleSpace.Distribution != experimentsv1beta1.DistributionUnknown {
+ allErrs = append(allErrs, field.Invalid(parametersPath.Index(i).Child("feasibleSpace").Child("distribution"),
+ param.FeasibleSpace.Distribution, fmt.Sprintf("distribution: %v is not supported", param.FeasibleSpace.Distribution)))
+ }
+ }
+
if equality.Semantic.DeepEqual(param.FeasibleSpace, experimentsv1beta1.FeasibleSpace{}) {
allErrs = append(allErrs, field.Required(parametersPath.Index(i).Child("feasibleSpace"),
"feasibleSpace must be specified"))
diff --git a/pkg/webhook/v1beta1/experiment/validator/validator_test.go b/pkg/webhook/v1beta1/experiment/validator/validator_test.go
index 9b60cc6d4f6..c815b799d74 100644
--- a/pkg/webhook/v1beta1/experiment/validator/validator_test.go
+++ b/pkg/webhook/v1beta1/experiment/validator/validator_test.go
@@ -18,12 +18,13 @@ package validator
import (
"errors"
+ "testing"
+
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
experimentutil "github.com/kubeflow/katib/pkg/controller.v1beta1/experiment/util"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/apimachinery/pkg/util/validation/field"
- "testing"
"go.uber.org/mock/gomock"
batchv1 "k8s.io/api/batch/v1"
@@ -454,6 +455,22 @@ func TestValidateParameters(t *testing.T) {
},
testDescription: "Not empty max for categorical parameter type",
},
+ {
+ parameters: func() []experimentsv1beta1.ParameterSpec {
+ ps := newFakeInstance().Spec.Parameters
+ ps[0].FeasibleSpace.Distribution = "invalid-distribution"
+ return ps
+ }(),
+ wantErr: field.ErrorList{
+ field.Invalid(field.NewPath("spec").Child("parameters").Index(0).Child("feasibleSpace").Child("distribution"), "", ""),
+ },
+ testDescription: "Invalid distribution type",
+ },
+ {
+ parameters: newFakeInstance().Spec.Parameters,
+ wantErr: nil,
+ testDescription: "Valid parameters case",
+ },
}
for _, tc := range tcs {
@@ -1374,8 +1391,9 @@ func newFakeInstance() *experimentsv1beta1.Experiment {
Name: "lr",
ParameterType: experimentsv1beta1.ParameterTypeInt,
FeasibleSpace: experimentsv1beta1.FeasibleSpace{
- Max: "5",
- Min: "1",
+ Max: "5",
+ Min: "1",
+ Distribution: experimentsv1beta1.DistributionUniform,
},
},
{
diff --git a/pkg/webhook/v1beta1/pod/inject_webhook.go b/pkg/webhook/v1beta1/pod/inject_webhook.go
index 3932a6bbfd3..96deaf23c1d 100644
--- a/pkg/webhook/v1beta1/pod/inject_webhook.go
+++ b/pkg/webhook/v1beta1/pod/inject_webhook.go
@@ -20,6 +20,7 @@ import (
"context"
"encoding/json"
"errors"
+ "fmt"
"net/http"
"path/filepath"
"strconv"
@@ -47,6 +48,13 @@ import (
var log = logf.Log.WithName("injector-webhook")
+var (
+ errInvalidOwnerAPIVersion = errors.New("invalid owner API version")
+ errInvalidSuggestionName = errors.New("invalid suggestion name")
+ errPodNotBelongToKatibJob = errors.New("pod does not belong to Katib Job")
+ errFailedToGetTrialTemplateJob = errors.New("unable to get Job in the trialTemplate")
+)
+
// SidecarInjector injects metrics collect sidecar to the primary pod.
type SidecarInjector struct {
client client.Client
@@ -266,7 +274,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa
// Get group and version from owner API version
gv, err := schema.ParseGroupVersion(owners[i].APIVersion)
if err != nil {
- return "", "", err
+ return "", "", fmt.Errorf("%w: %w", errInvalidOwnerAPIVersion, err)
}
gvk := schema.GroupVersionKind{
Group: gv.Group,
@@ -279,7 +287,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa
// Nested object namespace must be equal to object namespace
err = s.client.Get(context.TODO(), apitypes.NamespacedName{Name: owners[i].Name, Namespace: namespace}, nestedJob)
if err != nil {
- return "", "", err
+ return "", "", fmt.Errorf("%w: %w", errFailedToGetTrialTemplateJob, err)
}
// Recursively search for Trial ownership in nested object
jobKind, jobName, err = s.getKatibJob(nestedJob, namespace)
@@ -292,7 +300,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa
// If jobKind is empty after the loop, Trial doesn't own the object
if jobKind == "" {
- return "", "", errors.New("The Pod doesn't belong to Katib Job")
+ return "", "", errPodNotBelongToKatibJob
}
return jobKind, jobName, nil
@@ -329,7 +337,7 @@ func (s *SidecarInjector) getMetricsCollectorArgs(trial *trialsv1beta1.Trial, me
suggestion := &suggestionsv1beta1.Suggestion{}
err := s.client.Get(context.TODO(), apitypes.NamespacedName{Name: suggestionName, Namespace: trial.Namespace}, suggestion)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("%w: %w", errInvalidSuggestionName, err)
}
args = append(args, "-s-earlystop", util.GetEarlyStoppingEndpoint(suggestion))
}
diff --git a/pkg/webhook/v1beta1/pod/inject_webhook_test.go b/pkg/webhook/v1beta1/pod/inject_webhook_test.go
index 4436c10e7f1..8350264cfaa 100644
--- a/pkg/webhook/v1beta1/pod/inject_webhook_test.go
+++ b/pkg/webhook/v1beta1/pod/inject_webhook_test.go
@@ -20,7 +20,6 @@ import (
"context"
"fmt"
"path/filepath"
- "reflect"
"sync"
"testing"
"time"
@@ -32,7 +31,6 @@ import (
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
- "k8s.io/apimachinery/pkg/api/equality"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
@@ -78,16 +76,15 @@ func TestWrapWorkerContainer(t *testing.T) {
metricsFile := "metric.log"
- testCases := []struct {
- trial *trialsv1beta1.Trial
- pod *v1.Pod
- metricsFile string
- pathKind common.FileSystemKind
- expectedPod *v1.Pod
- err bool
- testDescription string
+ cases := map[string]struct {
+ trial *trialsv1beta1.Trial
+ pod *v1.Pod
+ metricsFile string
+ pathKind common.FileSystemKind
+ wantPod *v1.Pod
+ wantError error
}{
- {
+ "Tensorflow container without sh -c": {
trial: trial,
pod: &v1.Pod{
Spec: v1.PodSpec{
@@ -103,7 +100,7 @@ func TestWrapWorkerContainer(t *testing.T) {
},
metricsFile: metricsFile,
pathKind: common.FileKind,
- expectedPod: &v1.Pod{
+ wantPod: &v1.Pod{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
@@ -118,10 +115,8 @@ func TestWrapWorkerContainer(t *testing.T) {
},
},
},
- err: false,
- testDescription: "Tensorflow container without sh -c",
},
- {
+ "Tensorflow container with sh -c": {
trial: trial,
pod: &v1.Pod{
Spec: v1.PodSpec{
@@ -138,7 +133,7 @@ func TestWrapWorkerContainer(t *testing.T) {
},
metricsFile: metricsFile,
pathKind: common.FileKind,
- expectedPod: &v1.Pod{
+ wantPod: &v1.Pod{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
@@ -153,10 +148,8 @@ func TestWrapWorkerContainer(t *testing.T) {
},
},
},
- err: false,
- testDescription: "Tensorflow container with sh -c",
},
- {
+ "Training pod doesn't have primary container": {
trial: trial,
pod: &v1.Pod{
Spec: v1.PodSpec{
@@ -167,11 +160,19 @@ func TestWrapWorkerContainer(t *testing.T) {
},
},
},
- pathKind: common.FileKind,
- err: true,
- testDescription: "Training pod doesn't have primary container",
+ pathKind: common.FileKind,
+ wantPod: &v1.Pod{
+ Spec: v1.PodSpec{
+ Containers: []v1.Container{
+ {
+ Name: "not-primary-container",
+ },
+ },
+ },
+ },
+ wantError: errPrimaryContainerNotFound,
},
- {
+ "Container with early stopping command": {
trial: func() *trialsv1beta1.Trial {
t := trial.DeepCopy()
t.Spec.EarlyStoppingRules = []common.EarlyStoppingRule{
@@ -197,7 +198,7 @@ func TestWrapWorkerContainer(t *testing.T) {
},
metricsFile: metricsFile,
pathKind: common.FileKind,
- expectedPod: &v1.Pod{
+ wantPod: &v1.Pod{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
@@ -216,23 +217,19 @@ func TestWrapWorkerContainer(t *testing.T) {
},
},
},
- err: false,
- testDescription: "Container with early stopping command",
},
}
- for _, c := range testCases {
- err := wrapWorkerContainer(c.trial, c.pod, c.trial.Namespace, c.metricsFile, c.pathKind)
- if c.err && err == nil {
- t.Errorf("Case %s failed. Expected error, got nil", c.testDescription)
- } else if !c.err {
- if err != nil {
- t.Errorf("Case %s failed. Expected nil, got error: %v", c.testDescription, err)
- } else if !equality.Semantic.DeepEqual(c.pod.Spec.Containers, c.expectedPod.Spec.Containers) {
- t.Errorf("Case %s failed. Expected pod: %v, got: %v",
- c.testDescription, c.expectedPod.Spec.Containers, c.pod.Spec.Containers)
+ for name, tc := range cases {
+ t.Run(name, func(t *testing.T) {
+ err := wrapWorkerContainer(tc.trial, tc.pod, tc.trial.Namespace, tc.metricsFile, tc.pathKind)
+ if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
+ t.Errorf("Unexpected error from wrapWorkerContainer (-want,+got):\n%s", diff)
}
- }
+ if diff := cmp.Diff(tc.wantPod.Spec.Containers, tc.pod.Spec.Containers); len(diff) != 0 {
+ t.Errorf("Unexpected pod from wrapWorkerContainer (-want,+got):\n%s", diff)
+ }
+ })
}
}
@@ -320,17 +317,16 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
},
}
- testCases := []struct {
+ cases := map[string]struct {
trial *trialsv1beta1.Trial
metricNames string
mCSpec common.MetricsCollectorSpec
earlyStoppingRules []string
katibConfig configv1beta1.MetricsCollectorConfig
- expectedArgs []string
- name string
- err bool
+ wantArgs []string
+ wantError error
}{
- {
+ "StdOut MC": {
trial: testTrial,
metricNames: testMetricName,
mCSpec: common.MetricsCollectorSpec{
@@ -341,7 +337,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
katibConfig: configv1beta1.MetricsCollectorConfig{
WaitAllProcesses: &waitAllProcessesValue,
},
- expectedArgs: []string{
+ wantArgs: []string{
"-t", testTrialName,
"-m", testMetricName,
"-o-type", string(testObjective),
@@ -350,9 +346,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
"-format", string(common.TextFormat),
"-w", "false",
},
- name: "StdOut MC",
},
- {
+ "File MC with Filter": {
trial: testTrial,
metricNames: testMetricName,
mCSpec: common.MetricsCollectorSpec{
@@ -373,7 +368,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
},
},
katibConfig: configv1beta1.MetricsCollectorConfig{},
- expectedArgs: []string{
+ wantArgs: []string{
"-t", testTrialName,
"-m", testMetricName,
"-o-type", string(testObjective),
@@ -382,9 +377,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
"-f", "{mn1: ([a-b]), mv1: [0-9]};{mn2: ([a-b]), mv2: ([0-9])}",
"-format", string(common.TextFormat),
},
- name: "File MC with Filter",
},
- {
+ "File MC with Json Format": {
trial: testTrial,
metricNames: testMetricName,
mCSpec: common.MetricsCollectorSpec{
@@ -399,7 +393,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
},
},
katibConfig: configv1beta1.MetricsCollectorConfig{},
- expectedArgs: []string{
+ wantArgs: []string{
"-t", testTrialName,
"-m", testMetricName,
"-o-type", string(testObjective),
@@ -407,9 +401,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
"-path", testPath,
"-format", string(common.JsonFormat),
},
- name: "File MC with Json Format",
},
- {
+ "Tf Event MC": {
trial: testTrial,
metricNames: testMetricName,
mCSpec: common.MetricsCollectorSpec{
@@ -423,16 +416,15 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
},
},
katibConfig: configv1beta1.MetricsCollectorConfig{},
- expectedArgs: []string{
+ wantArgs: []string{
"-t", testTrialName,
"-m", testMetricName,
"-o-type", string(testObjective),
"-s-db", katibDBAddress,
"-path", testPath,
},
- name: "Tf Event MC",
},
- {
+ "Custom MC without Path": {
trial: testTrial,
metricNames: testMetricName,
mCSpec: common.MetricsCollectorSpec{
@@ -441,15 +433,14 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
},
},
katibConfig: configv1beta1.MetricsCollectorConfig{},
- expectedArgs: []string{
+ wantArgs: []string{
"-t", testTrialName,
"-m", testMetricName,
"-o-type", string(testObjective),
"-s-db", katibDBAddress,
},
- name: "Custom MC without Path",
},
- {
+ "Custom MC with Path": {
trial: testTrial,
metricNames: testMetricName,
mCSpec: common.MetricsCollectorSpec{
@@ -463,16 +454,15 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
},
},
katibConfig: configv1beta1.MetricsCollectorConfig{},
- expectedArgs: []string{
+ wantArgs: []string{
"-t", testTrialName,
"-m", testMetricName,
"-o-type", string(testObjective),
"-s-db", katibDBAddress,
"-path", testPath,
},
- name: "Custom MC with Path",
},
- {
+ "Prometheus MC without Path": {
trial: testTrial,
metricNames: testMetricName,
mCSpec: common.MetricsCollectorSpec{
@@ -481,15 +471,14 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
},
},
katibConfig: configv1beta1.MetricsCollectorConfig{},
- expectedArgs: []string{
+ wantArgs: []string{
"-t", testTrialName,
"-m", testMetricName,
"-o-type", string(testObjective),
"-s-db", katibDBAddress,
},
- name: "Prometheus MC without Path",
},
- {
+ "Trial with EarlyStopping rules": {
trial: testTrial,
metricNames: testMetricName,
mCSpec: common.MetricsCollectorSpec{
@@ -499,7 +488,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
},
earlyStoppingRules: earlyStoppingRules,
katibConfig: configv1beta1.MetricsCollectorConfig{},
- expectedArgs: []string{
+ wantArgs: []string{
"-t", testTrialName,
"-m", testMetricName,
"-o-type", string(testObjective),
@@ -510,9 +499,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
"-stop-rule", earlyStoppingRules[1],
"-s-earlystop", katibEarlyStopAddress,
},
- name: "Trial with EarlyStopping rules",
},
- {
+ "Trial with invalid Experiment label name. Suggestion is not created": {
trial: func() *trialsv1beta1.Trial {
trial := testTrial.DeepCopy()
trial.ObjectMeta.Labels[consts.LabelExperimentName] = "invalid-name"
@@ -525,8 +513,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
},
earlyStoppingRules: earlyStoppingRules,
katibConfig: configv1beta1.MetricsCollectorConfig{},
- name: "Trial with invalid Experiment label name. Suggestion is not created",
- err: true,
+ wantError: errInvalidSuggestionName,
},
}
@@ -537,25 +524,25 @@ func TestGetMetricsCollectorArgs(t *testing.T) {
return c.Get(context.TODO(), types.NamespacedName{Namespace: testNamespace, Name: testSuggestionName}, testSuggestion)
}, timeout).ShouldNot(gomega.HaveOccurred())
- for _, tc := range testCases {
- args, err := si.getMetricsCollectorArgs(tc.trial, tc.metricNames, tc.mCSpec, tc.katibConfig, tc.earlyStoppingRules)
-
- if !tc.err && err != nil {
- t.Errorf("Case: %v failed. Expected nil, got %v", tc.name, err)
- } else if tc.err && err == nil {
- t.Errorf("Case: %v failed. Expected err, got nil", tc.name)
- } else if !tc.err && !reflect.DeepEqual(tc.expectedArgs, args) {
- t.Errorf("Case %v failed. ExpectedArgs: %v, got %v", tc.name, tc.expectedArgs, args)
- }
+ for name, tc := range cases {
+ t.Run(name, func(t *testing.T) {
+ got, err := si.getMetricsCollectorArgs(tc.trial, tc.metricNames, tc.mCSpec, tc.katibConfig, tc.earlyStoppingRules)
+ if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
+ t.Errorf("Unexpected error from getMetricsCollectorArgs (-want,+got):\n%s", diff)
+ }
+ if diff := cmp.Diff(tc.wantArgs, got); len(diff) != 0 {
+ t.Errorf("Unexpected args from getMetricsCollectorArgs (-want,+got):\n%s", diff)
+ }
+ })
}
}
func TestNeedWrapWorkerContainer(t *testing.T) {
- testCases := []struct {
+ testCases := map[string]struct {
mCSpec common.MetricsCollectorSpec
needWrap bool
}{
- {
+ "Valid case with needWrap true": {
mCSpec: common.MetricsCollectorSpec{
Collector: &common.CollectorSpec{
Kind: common.StdOutCollector,
@@ -563,7 +550,7 @@ func TestNeedWrapWorkerContainer(t *testing.T) {
},
needWrap: true,
},
- {
+ "Valid case with needWrap false": {
mCSpec: common.MetricsCollectorSpec{
Collector: &common.CollectorSpec{
Kind: common.CustomCollector,
@@ -573,114 +560,125 @@ func TestNeedWrapWorkerContainer(t *testing.T) {
},
}
- for _, tc := range testCases {
- needWrap := needWrapWorkerContainer(tc.mCSpec)
- if needWrap != tc.needWrap {
- t.Errorf("Expected needWrap %v, got %v", tc.needWrap, needWrap)
- }
+ for name, tc := range testCases {
+ t.Run(name, func(t *testing.T) {
+ needWrap := needWrapWorkerContainer(tc.mCSpec)
+ if needWrap != tc.needWrap {
+ t.Errorf("Expected needWrap %v, got %v", tc.needWrap, needWrap)
+ }
+ })
}
}
func TestMutateMetricsCollectorVolume(t *testing.T) {
- tc := struct {
+ testCases := map[string]struct {
pod v1.Pod
- expectedPod v1.Pod
- JobKind string
- MountPath string
- SidecarContainerName string
- PrimaryContainerName string
+ wantPod v1.Pod
+ jobKind string
+ mountPath string
+ sidecarContainerName string
+ primaryContainerName string
pathKind common.FileSystemKind
- err bool
+ wantError error
}{
- pod: v1.Pod{
- Spec: v1.PodSpec{
- Containers: []v1.Container{
- {
- Name: "train-job",
- },
- {
- Name: "init-container",
- },
- {
- Name: "metrics-collector",
+ "Valid case": {
+ pod: v1.Pod{
+ Spec: v1.PodSpec{
+ Containers: []v1.Container{
+ {
+ Name: "train-job",
+ },
+ {
+ Name: "init-container",
+ },
+ {
+ Name: "metrics-collector",
+ },
},
},
},
- },
- expectedPod: v1.Pod{
- Spec: v1.PodSpec{
- Containers: []v1.Container{
- {
- Name: "train-job",
- VolumeMounts: []v1.VolumeMount{
- {
- Name: common.MetricsVolume,
- MountPath: filepath.Dir(common.DefaultFilePath),
+ wantPod: v1.Pod{
+ Spec: v1.PodSpec{
+ Containers: []v1.Container{
+ {
+ Name: "train-job",
+ VolumeMounts: []v1.VolumeMount{
+ {
+ Name: common.MetricsVolume,
+ MountPath: filepath.Dir(common.DefaultFilePath),
+ },
},
},
- },
- {
- Name: "init-container",
- },
- {
- Name: "metrics-collector",
- VolumeMounts: []v1.VolumeMount{
- {
- Name: common.MetricsVolume,
- MountPath: filepath.Dir(common.DefaultFilePath),
+ {
+ Name: "init-container",
+ },
+ {
+ Name: "metrics-collector",
+ VolumeMounts: []v1.VolumeMount{
+ {
+ Name: common.MetricsVolume,
+ MountPath: filepath.Dir(common.DefaultFilePath),
+ },
},
},
},
- },
- Volumes: []v1.Volume{
- {
- Name: common.MetricsVolume,
- VolumeSource: v1.VolumeSource{
- EmptyDir: &v1.EmptyDirVolumeSource{},
+ Volumes: []v1.Volume{
+ {
+ Name: common.MetricsVolume,
+ VolumeSource: v1.VolumeSource{
+ EmptyDir: &v1.EmptyDirVolumeSource{},
+ },
},
},
},
},
+ mountPath: common.DefaultFilePath,
+ sidecarContainerName: "metrics-collector",
+ primaryContainerName: "train-job",
+ pathKind: common.FileKind,
},
- MountPath: common.DefaultFilePath,
- SidecarContainerName: "metrics-collector",
- PrimaryContainerName: "train-job",
- pathKind: common.FileKind,
}
- err := mutateMetricsCollectorVolume(
- &tc.pod,
- tc.MountPath,
- tc.SidecarContainerName,
- tc.PrimaryContainerName,
- tc.pathKind)
- if err != nil {
- t.Errorf("mutateMetricsCollectorVolume failed: %v", err)
- } else if !equality.Semantic.DeepEqual(tc.pod, tc.expectedPod) {
- t.Errorf("Expected pod %v, got %v", tc.expectedPod, tc.pod)
+ for name, tc := range testCases {
+ t.Run(name, func(t *testing.T) {
+ err := mutateMetricsCollectorVolume(
+ &tc.pod,
+ tc.mountPath,
+ tc.sidecarContainerName,
+ tc.primaryContainerName,
+ tc.pathKind)
+ if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
+ t.Errorf("Unexpected error from mutateMetricsCollectorVolume (-want,+got):\n%s", diff)
+ }
+ if diff := cmp.Diff(tc.wantPod, tc.pod); len(diff) != 0 {
+ t.Errorf("Unexpected pod from mutateMetricsCollectorVolume (-want,+got):\n%s", diff)
+ }
+ })
}
}
func TestGetSidecarContainerName(t *testing.T) {
- testCases := []struct {
+ testCases := map[string]struct {
collectorKind common.CollectorKind
expectedCollectorKind string
}{
- {
+ "Expected kind is metrics-logger-and-collector": {
collectorKind: common.StdOutCollector,
expectedCollectorKind: mccommon.MetricLoggerCollectorContainerName,
},
- {
+ "Expected kind is metrics-collector": {
collectorKind: common.TfEventCollector,
expectedCollectorKind: mccommon.MetricCollectorContainerName,
},
}
- for _, tc := range testCases {
- collectorKind := getSidecarContainerName(tc.collectorKind)
- if collectorKind != tc.expectedCollectorKind {
- t.Errorf("Expected Collector Kind: %v, got %v", tc.expectedCollectorKind, collectorKind)
- }
+ for name, tc := range testCases {
+ t.Run(name, func(t *testing.T) {
+ collectorKind := getSidecarContainerName(tc.collectorKind)
+ if collectorKind != tc.expectedCollectorKind {
+ t.Errorf("Expected Collector Kind: %v, got %v", tc.expectedCollectorKind, collectorKind)
+ }
+ })
}
}
@@ -722,16 +720,15 @@ func TestGetKatibJob(t *testing.T) {
deployName := "deploy-name"
jobName := "job-name"
- testCases := []struct {
- pod *v1.Pod
- job *batchv1.Job
- deployment *appsv1.Deployment
- expectedJobKind string
- expectedJobName string
- err bool
- testDescription string
+ cases := map[string]struct {
+ pod *v1.Pod
+ job *batchv1.Job
+ deployment *appsv1.Deployment
+ wantJobKind string
+ wantJobName string
+ wantError error
}{
- {
+ "Valid run with ownership sequence: Trial -> Job -> Pod": {
pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: podName,
@@ -772,12 +769,10 @@ func TestGetKatibJob(t *testing.T) {
},
},
},
- expectedJobKind: "Job",
- expectedJobName: jobName + "-1",
- err: false,
- testDescription: "Valid run with ownership sequence: Trial -> Job -> Pod",
+ wantJobKind: "Job",
+ wantJobName: jobName + "-1",
},
- {
+ "Valid run with ownership sequence: Trial -> Deployment -> Pod, Job -> Pod": {
pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: podName,
@@ -851,12 +846,10 @@ func TestGetKatibJob(t *testing.T) {
},
},
},
- expectedJobKind: "Deployment",
- expectedJobName: deployName + "-2",
- err: false,
- testDescription: "Valid run with ownership sequence: Trial -> Deployment -> Pod, Job -> Pod",
+ wantJobKind: "Deployment",
+ wantJobName: deployName + "-2",
},
- {
+ "Run for not Trial's pod with ownership sequence: Job -> Pod": {
pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: podName,
@@ -889,10 +882,9 @@ func TestGetKatibJob(t *testing.T) {
},
},
},
- err: true,
- testDescription: "Run for not Trial's pod with ownership sequence: Job -> Pod",
+ wantError: errPodNotBelongToKatibJob,
},
- {
+ "Run when Pod owns Job that doesn't exists": {
pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: podName,
@@ -906,10 +898,9 @@ func TestGetKatibJob(t *testing.T) {
},
},
},
- err: true,
- testDescription: "Run when Pod owns Job that doesn't exists",
+ wantError: errFailedToGetTrialTemplateJob,
},
- {
+ "Run when Pod owns Job with invalid API version": {
pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: podName,
@@ -923,64 +914,63 @@ func TestGetKatibJob(t *testing.T) {
},
},
},
- err: true,
- testDescription: "Run when Pod owns Job with invalid API version",
+ wantError: errInvalidOwnerAPIVersion,
},
}
- for _, tc := range testCases {
- // Create Job if it is needed
- if tc.job != nil {
- jobUnstr, err := util.ConvertObjectToUnstructured(tc.job)
- gvk := schema.GroupVersionKind{
- Group: "batch",
- Version: "v1",
- Kind: "Job",
- }
- jobUnstr.SetGroupVersionKind(gvk)
- if err != nil {
- t.Errorf("ConvertObjectToUnstructured error %v", err)
- }
+ for name, tc := range cases {
+ t.Run(name, func(t *testing.T) {
+ // Create Job if it is needed
+ if tc.job != nil {
+ jobUnstr, err := util.ConvertObjectToUnstructured(tc.job)
+ if err != nil {
+ t.Fatalf("ConvertObjectToUnstructured error %v", err)
+ }
+ gvk := schema.GroupVersionKind{
+ Group: "batch",
+ Version: "v1",
+ Kind: "Job",
+ }
+ jobUnstr.SetGroupVersionKind(gvk)
- g.Expect(c.Create(context.TODO(), jobUnstr)).NotTo(gomega.HaveOccurred())
+ g.Expect(c.Create(context.TODO(), jobUnstr)).NotTo(gomega.HaveOccurred())
- // Wait that Job is created
- g.Eventually(func() error {
- return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.job.Name}, jobUnstr)
- }, timeout).ShouldNot(gomega.HaveOccurred())
- }
+ // Wait that Job is created
+ g.Eventually(func() error {
+ return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.job.Name}, jobUnstr)
+ }, timeout).ShouldNot(gomega.HaveOccurred())
+ }
- // Create Deployment if it is needed
- if tc.deployment != nil {
- g.Expect(c.Create(context.TODO(), tc.deployment)).NotTo(gomega.HaveOccurred())
+ // Create Deployment if it is needed
+ if tc.deployment != nil {
+ g.Expect(c.Create(context.TODO(), tc.deployment)).NotTo(gomega.HaveOccurred())
- // Wait that Deployment is created
- g.Eventually(func() error {
- return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.deployment.Name}, tc.deployment)
- }, timeout).ShouldNot(gomega.HaveOccurred())
- }
+ // Wait that Deployment is created
+ g.Eventually(func() error {
+ return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.deployment.Name}, tc.deployment)
+ }, timeout).ShouldNot(gomega.HaveOccurred())
+ }
- object, _ := util.ConvertObjectToUnstructured(tc.pod)
- jobKind, jobName, err := si.getKatibJob(object, namespace)
- if !tc.err && err != nil {
- t.Errorf("Case %v failed. Error %v", tc.testDescription, err)
- } else if !tc.err && (tc.expectedJobKind != jobKind || tc.expectedJobName != jobName) {
- t.Errorf("Case %v failed. Expected jobKind %v, got %v, Expected jobName %v, got %v",
- tc.testDescription, tc.expectedJobKind, jobKind, tc.expectedJobName, jobName)
- } else if tc.err && err == nil {
- t.Errorf("Expected error got nil")
- }
+ object, _ := util.ConvertObjectToUnstructured(tc.pod)
+ jobKind, jobName, err := si.getKatibJob(object, namespace)
+ if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
+ t.Errorf("Unexpected error from getKatibJob (-want,+got):\n%s", diff)
+ }
+ if tc.wantError == nil && (tc.wantJobKind != jobKind || tc.wantJobName != jobName) {
+ t.Errorf("Unexpected result from getKatibJob, expected jobKind %v, got %v, expected jobName %v, got %v",
+ tc.wantJobKind, jobKind, tc.wantJobName, jobName)
+ }
+ })
}
}
func TestIsPrimaryPod(t *testing.T) {
- testCases := []struct {
+ testCases := map[string]struct {
podLabels map[string]string
primaryPodLabels map[string]string
isPrimary bool
- testDescription string
}{
- {
+ "Pod contains all labels from primary pod labels": {
podLabels: map[string]string{
"test-key-1": "test-value-1",
"test-key-2": "test-value-2",
@@ -990,10 +980,9 @@ func TestIsPrimaryPod(t *testing.T) {
"test-key-1": "test-value-1",
"test-key-2": "test-value-2",
},
- isPrimary: true,
- testDescription: "Pod contains all labels from primary pod labels",
+ isPrimary: true,
},
- {
+ "Pod doesn't contain primary label": {
podLabels: map[string]string{
"test-key-1": "test-value-1",
},
@@ -1001,26 +990,26 @@ func TestIsPrimaryPod(t *testing.T) {
"test-key-1": "test-value-1",
"test-key-2": "test-value-2",
},
- isPrimary: false,
- testDescription: "Pod doesn't contain primary label",
+ isPrimary: false,
},
- {
+ "Pod contains label with incorrect value": {
podLabels: map[string]string{
"test-key-1": "invalid",
},
primaryPodLabels: map[string]string{
"test-key-1": "test-value-1",
},
- isPrimary: false,
- testDescription: "Pod contains label with incorrect value",
+ isPrimary: false,
},
}
- for _, tc := range testCases {
- isPrimary := isPrimaryPod(tc.podLabels, tc.primaryPodLabels)
- if isPrimary != tc.isPrimary {
- t.Errorf("Case %v. Expected isPrimary %v, got %v", tc.testDescription, tc.isPrimary, isPrimary)
- }
+ for name, tc := range testCases {
+ t.Run(name, func(t *testing.T) {
+ isPrimary := isPrimaryPod(tc.podLabels, tc.primaryPodLabels)
+ if diff := cmp.Diff(tc.isPrimary, isPrimary); len(diff) != 0 {
+ t.Errorf("Unexpected result (-want,+got):\n%s", diff)
+ }
+ })
}
}
@@ -1030,14 +1019,12 @@ func TestMutatePodMetadata(t *testing.T) {
"katib-experiment": "katib-value",
consts.LabelTrialName: "test-trial",
}
-
- testCases := []struct {
- pod *v1.Pod
- trial *trialsv1beta1.Trial
- mutatedPod *v1.Pod
- testDescription string
+ testCases := map[string]struct {
+ pod *v1.Pod
+ trial *trialsv1beta1.Trial
+ mutatedPod *v1.Pod
}{
- {
+ "Mutated Pod should contain label from the origin Pod and Trial": {
pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
@@ -1058,20 +1045,21 @@ func TestMutatePodMetadata(t *testing.T) {
Labels: mutatedPodLabels,
},
},
- testDescription: "Mutated Pod should contain label from the origin Pod and Trial",
},
}
- for _, tc := range testCases {
- mutatePodMetadata(tc.pod, tc.trial)
- if !reflect.DeepEqual(tc.mutatedPod, tc.pod) {
- t.Errorf("Case %v. Expected Pod %v, got %v", tc.testDescription, tc.mutatedPod, tc.pod)
- }
+ for name, tc := range testCases {
+ t.Run(name, func(t *testing.T) {
+ mutatePodMetadata(tc.pod, tc.trial)
+ if diff := cmp.Diff(tc.mutatedPod, tc.pod); len(diff) != 0 {
+ t.Errorf("Unexpected mutated result (-want,+got):\n%s", diff)
+ }
+ })
}
}
func TestMutatePodEnv(t *testing.T) {
- testcases := map[string]struct {
+ testCases := map[string]struct {
pod *v1.Pod
trial *trialsv1beta1.Trial
mutatedPod *v1.Pod
@@ -1148,22 +1136,22 @@ func TestMutatePodEnv(t *testing.T) {
},
}
- for name, testcase := range testcases {
+ for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
- err := mutatePodEnv(testcase.pod, testcase.trial)
+ err := mutatePodEnv(tc.pod, tc.trial)
// Compare error with expected error
- if testcase.wantError != nil && err != nil {
- if diff := cmp.Diff(testcase.wantError.Error(), err.Error()); len(diff) != 0 {
+ if tc.wantError != nil && err != nil {
+ if diff := cmp.Diff(tc.wantError.Error(), err.Error()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
- } else if testcase.wantError != nil || err != nil {
+ } else if tc.wantError != nil || err != nil {
t.Errorf(
"Unexpected error (-want,+got):\n%s",
- cmp.Diff(testcase.wantError, err, cmpopts.EquateErrors()),
+ cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()),
)
}
// Compare Pod with expected pod after mutation
- if diff := cmp.Diff(testcase.mutatedPod, testcase.pod); len(diff) != 0 {
+ if diff := cmp.Diff(tc.mutatedPod, tc.pod); len(diff) != 0 {
t.Errorf("Unexpected mutated result (-want,+got):\n%s", diff)
}
})
diff --git a/pkg/webhook/v1beta1/pod/utils.go b/pkg/webhook/v1beta1/pod/utils.go
index 7dad82553cf..a3dc66e1cc8 100644
--- a/pkg/webhook/v1beta1/pod/utils.go
+++ b/pkg/webhook/v1beta1/pod/utils.go
@@ -18,6 +18,7 @@ package pod
import (
"context"
+ "errors"
"fmt"
"path/filepath"
"strings"
@@ -35,6 +36,8 @@ import (
mccommon "github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common"
)
+var errPrimaryContainerNotFound = errors.New("unable to find primary container in mutated pod containers")
+
func isPrimaryPod(podLabels, primaryLabels map[string]string) bool {
for primaryKey, primaryValue := range primaryLabels {
@@ -190,8 +193,7 @@ func wrapWorkerContainer(trial *trialsv1beta1.Trial, pod *v1.Pod, namespace,
c.Command = command
c.Args = []string{argsStr}
} else {
- return fmt.Errorf("Unable to find primary container %v in mutated pod containers %v",
- trial.Spec.PrimaryContainerName, pod.Spec.Containers)
+ return fmt.Errorf("%w: primary container: %v, mutated pod containers: %v", errPrimaryContainerNotFound, trial.Spec.PrimaryContainerName, pod.Spec.Containers)
}
return nil
}
diff --git a/sdk/python/v1beta1/docs/V1beta1FeasibleSpace.md b/sdk/python/v1beta1/docs/V1beta1FeasibleSpace.md
index e30d6292d7f..2a313ee5dd3 100644
--- a/sdk/python/v1beta1/docs/V1beta1FeasibleSpace.md
+++ b/sdk/python/v1beta1/docs/V1beta1FeasibleSpace.md
@@ -3,6 +3,7 @@
## Properties
Name | Type | Description | Notes
------------ | ------------- | ------------- | -------------
+**distribution** | **str** | | [optional]
**list** | **list[str]** | | [optional]
**max** | **str** | | [optional]
**min** | **str** | | [optional]
diff --git a/sdk/python/v1beta1/kubeflow/__init__.py b/sdk/python/v1beta1/kubeflow/__init__.py
index 69e3be50dac..8db66d3d0f0 100644
--- a/sdk/python/v1beta1/kubeflow/__init__.py
+++ b/sdk/python/v1beta1/kubeflow/__init__.py
@@ -1 +1 @@
-__path__ = __import__('pkgutil').extend_path(__path__, __name__)
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
index b7c64fb6080..5cb535cd40b 100644
--- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
+++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
@@ -20,14 +20,13 @@
from typing import Any, Callable, Dict, List, Optional, Union
import grpc
from kubeflow.katib import models
from kubeflow.katib.api_client import ApiClient
from kubeflow.katib.constants import constants
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import kubeflow.katib.katib_api_pb2_grpc as katib_api_pb2_grpc
from kubeflow.katib.utils import utils
-from kubernetes import client
-from kubernetes import config
+from kubernetes import client, config
logger = logging.getLogger(__name__)
@@ -107,7 +107,7 @@ def create_experiment(
namespace = namespace or self.namespace
experiment_name = None
- if type(experiment) == models.V1beta1Experiment:
+ if type(experiment) is models.V1beta1Experiment:
if experiment.metadata.name is not None:
experiment_name = experiment.metadata.name
elif experiment.metadata.generate_name is not None:
@@ -138,7 +138,8 @@ def create_experiment(
except Exception as e:
if hasattr(e, "status") and e.status == 409:
raise Exception(
- f"A Katib Experiment with the name {namespace}/{experiment_name} already exists."
+ f"A Katib Experiment with the name "
+ f"{namespace}/{experiment_name} already exists."
)
raise RuntimeError(
f"Failed to create Katib Experiment: {namespace}/{experiment_name}"
@@ -153,7 +154,8 @@ def create_experiment(
IPython.display.display(
IPython.display.HTML(
"Katib Experiment {} "
- 'link here'.format(
+ 'link here'.format(
experiment_name,
namespace,
experiment_name,
@@ -216,7 +218,8 @@ def tune(
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
algorithm_name: Search algorithm for the HyperParameter tuning.
algorithm_settings: Settings for the search algorithm given.
- For available fields, check this doc: https://www.kubeflow.org/docs/components/katib/experiment/#search-algorithms-in-detail.
+ For available fields, check this doc:
+ https://www.kubeflow.org/docs/components/katib/experiment/#search-algorithms-in-detail.
objective_metric_name: Objective metric that Katib optimizes.
additional_metric_names: List of metrics that Katib collects from the
objective function in addition to objective metric.
@@ -224,7 +227,8 @@ def tune(
Must be one of `minimize` or `maximize`.
objective_goal: Objective goal that Experiment should reach to be Succeeded.
max_trial_count: Maximum number of Trials to run. For the default
- values check this doc: https://www.kubeflow.org/docs/components/katib/experiment/#configuration-spec.
+ values check this doc:
+ https://www.kubeflow.org/docs/components/katib/experiment/#configuration-spec.
parallel_trial_count: Number of Trials that Experiment runs in parallel.
max_failed_trial_count: Maximum number of Trials allowed to fail.
resources_per_trial: A parameter that lets you specify how much
@@ -250,7 +254,7 @@ def tune(
to the base image packages. These packages are installed before
executing the objective function.
pip_index_url: The PyPI url from which to install Python packages.
- metrics_collector_config: Specify the config of metrics collector,
+ metrics_collector_config: Specify the config of metrics collector,
for example, `metrics_collector_config = {"kind": "Push"}`.
Currently, we only support `StdOut` and `Push` metrics collector.
@@ -331,10 +335,14 @@ def tune(
# Otherwise, add value to the function input.
input_params[p_name] = p_value
- # Wrap objective function to execute it from the file. For example
+ # Wrap objective function to execute it from the file. For example:
# def objective(parameters):
# print(f'Parameters are {parameters}')
- # objective({'lr': '${trialParameters.lr}', 'epochs': '${trialParameters.epochs}', 'is_dist': False})
+ # objective({
+ # 'lr': '${trialParameters.lr}',
+ # 'epochs': '${trialParameters.epochs}',
+ # 'is_dist': False
+ # })
objective_code = f"{objective_code}\n{objective.__name__}({input_params})\n"
# Prepare execute script template.
@@ -386,7 +394,8 @@ def tune(
)
# Add metrics collector to the Katib Experiment.
- # Up to now, We only support parameter `kind`, of which default value is `StdOut`, to specify the kind of metrics collector.
+ # Up to now, we only support parameter `kind`, of which default value
+ # is `StdOut`, to specify the kind of metrics collector.
experiment.spec.metrics_collector_spec = models.V1beta1MetricsCollectorSpec(
collector=models.V1beta1CollectorSpec(kind=metrics_collector_config["kind"])
)
@@ -765,7 +774,9 @@ def wait_for_experiment_condition(
)
):
utils.print_experiment_status(experiment)
- logger.debug(f"Experiment: {namespace}/{name} is {expected_condition}\n\n\n")
+ logger.debug(
+ f"Experiment: {namespace}/{name} is {expected_condition}\n\n\n"
+ )
return experiment
# Raise exception if Experiment is Failed.
@@ -785,7 +796,9 @@ def wait_for_experiment_condition(
)
):
utils.print_experiment_status(experiment)
- logger.debug(f"Experiment: {namespace}/{name} is {expected_condition}\n\n\n")
+ logger.debug(
+ f"Experiment: {namespace}/{name} is {expected_condition}\n\n\n"
+ )
return experiment
# Check if Experiment reaches Running condition.
@@ -796,7 +809,9 @@ def wait_for_experiment_condition(
)
):
utils.print_experiment_status(experiment)
- logger.debug(f"Experiment: {namespace}/{name} is {expected_condition}\n\n\n")
+ logger.debug(
+ f"Experiment: {namespace}/{name} is {expected_condition}\n\n\n"
+ )
return experiment
# Check if Experiment reaches Restarting condition.
@@ -807,7 +822,9 @@ def wait_for_experiment_condition(
)
):
utils.print_experiment_status(experiment)
- logger.debug(f"Experiment: {namespace}/{name} is {expected_condition}\n\n\n")
+ logger.debug(
+ f"Experiment: {namespace}/{name} is {expected_condition}\n\n\n"
+ )
return experiment
# Check if Experiment reaches Succeeded condition.
@@ -818,18 +835,24 @@ def wait_for_experiment_condition(
)
):
utils.print_experiment_status(experiment)
- logger.debug(f"Experiment: {namespace}/{name} is {expected_condition}\n\n\n")
+
+ logger.debug(
+ f"Experiment: {namespace}/{name} "
+ f"is {expected_condition}\n\n\n"
+ )
return experiment
# Otherwise, print the current Experiment results and sleep for the pooling interval.
utils.print_experiment_status(experiment)
logger.debug(
- f"Waiting for Experiment: {namespace}/{name} to reach {expected_condition} condition\n\n\n"
+ f"Waiting for Experiment: {namespace}/{name} "
+ f"to reach {expected_condition} condition\n\n\n"
)
time.sleep(polling_interval)
raise TimeoutError(
- f"Timeout waiting for Experiment: {namespace}/{name} to reach {expected_condition} state"
+ f"Timeout waiting for Experiment: {namespace}/{name} "
+ f"to reach {expected_condition} state"
)
def edit_experiment_budget(
@@ -845,7 +868,8 @@ def edit_experiment_budget(
budget to resume Succeeded Experiments with `LongRunning` and `FromVolume`
resume policies.
- Learn about resuming Experiments here: https://www.kubeflow.org/docs/components/katib/resume-experiment/
+ Learn about resuming Experiments here:
+ https://www.kubeflow.org/docs/components/katib/resume-experiment/
Args:
name: Name for the Experiment.
@@ -1258,10 +1282,12 @@ def get_trial_metrics(
use the default Katib DB Manager address: `katib-db-manager.kubeflow:6789`.
If you run this API outside the cluster, you have to port-forward the
- Katib DB Manager before getting the Trial metrics: `kubectl port-forward svc/katib-db-manager -n kubeflow 6789`.
+ Katib DB Manager before getting the Trial metrics:
+ `kubectl port-forward svc/katib-db-manager -n kubeflow 6789`.
In that case, you can use this Katib DB Manager address: `localhost:6789`.
- You can use `curl` to verify that Katib DB Manager is reachable: `curl `.
+ You can use `curl` to verify that Katib DB Manager is reachable:
+ `curl `.
Args:
name: Name for the Trial.
diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
index ead5ca90a9a..e7a8663af18 100644
--- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
+++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
@@ -1,21 +1,22 @@
import multiprocessing
from typing import List, Optional
-from unittest.mock import Mock
-from unittest.mock import patch
+from unittest.mock import Mock, patch
-from kubeflow.katib import KatibClient
-from kubeflow.katib import V1beta1AlgorithmSpec
-from kubeflow.katib import V1beta1Experiment
-from kubeflow.katib import V1beta1ExperimentSpec
-from kubeflow.katib import V1beta1FeasibleSpace
-from kubeflow.katib import V1beta1ObjectiveSpec
-from kubeflow.katib import V1beta1ParameterSpec
-from kubeflow.katib import V1beta1TrialParameterSpec
-from kubeflow.katib import V1beta1TrialTemplate
+import pytest
+from kubeflow.katib import (
+ KatibClient,
+ V1beta1AlgorithmSpec,
+ V1beta1Experiment,
+ V1beta1ExperimentSpec,
+ V1beta1FeasibleSpace,
+ V1beta1ObjectiveSpec,
+ V1beta1ParameterSpec,
+ V1beta1TrialParameterSpec,
+ V1beta1TrialTemplate,
+)
from kubeflow.katib.constants import constants
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
from kubernetes.client import V1ObjectMeta
-import pytest
TEST_RESULT_SUCCESS = "success"
@@ -57,16 +58,12 @@ def get_observation_log_response(*args, **kwargs):
def generate_trial_template() -> V1beta1TrialTemplate:
- trial_spec={
+ trial_spec = {
"apiVersion": "batch/v1",
"kind": "Job",
"spec": {
"template": {
- "metadata": {
- "annotations": {
- "sidecar.istio.io/inject": "false"
- }
- },
+ "metadata": {"annotations": {"sidecar.istio.io/inject": "false"}},
"spec": {
"containers": [
{
@@ -79,13 +76,13 @@ def generate_trial_template() -> V1beta1TrialTemplate:
"--batch-size=64",
"--lr=${trialParameters.learningRate}",
"--momentum=${trialParameters.momentum}",
- ]
+ ],
}
],
- "restartPolicy": "Never"
- }
+ "restartPolicy": "Never",
+ },
}
- }
+ },
}
return V1beta1TrialTemplate(
@@ -94,15 +91,15 @@ def generate_trial_template() -> V1beta1TrialTemplate:
V1beta1TrialParameterSpec(
name="learningRate",
description="Learning rate for the training model",
- reference="lr"
+ reference="lr",
),
V1beta1TrialParameterSpec(
name="momentum",
description="Momentum for the training model",
- reference="momentum"
+ reference="momentum",
),
],
- trial_spec=trial_spec
+ trial_spec=trial_spec,
)
@@ -125,60 +122,49 @@ def generate_experiment(
objective=objective_spec,
parameters=parameters,
trial_template=trial_template,
- )
+ ),
)
def create_experiment(
- name: Optional[str] = None,
- generate_name: Optional[str] = None
+ name: Optional[str] = None, generate_name: Optional[str] = None
) -> V1beta1Experiment:
experiment_namespace = "test"
if name is not None:
metadata = V1ObjectMeta(name=name, namespace=experiment_namespace)
elif generate_name is not None:
- metadata = V1ObjectMeta(generate_name=generate_name, namespace=experiment_namespace)
+ metadata = V1ObjectMeta(
+ generate_name=generate_name, namespace=experiment_namespace
+ )
else:
metadata = V1ObjectMeta(namespace=experiment_namespace)
- algorithm_spec=V1beta1AlgorithmSpec(
- algorithm_name="random"
- )
+ algorithm_spec = V1beta1AlgorithmSpec(algorithm_name="random")
- objective_spec=V1beta1ObjectiveSpec(
+ objective_spec = V1beta1ObjectiveSpec(
type="minimize",
- goal= 0.001,
+ goal=0.001,
objective_metric_name="loss",
)
- parameters=[
+ parameters = [
V1beta1ParameterSpec(
name="lr",
parameter_type="double",
- feasible_space=V1beta1FeasibleSpace(
- min="0.01",
- max="0.06"
- ),
+ feasible_space=V1beta1FeasibleSpace(min="0.01", max="0.06"),
),
V1beta1ParameterSpec(
name="momentum",
parameter_type="double",
- feasible_space=V1beta1FeasibleSpace(
- min="0.5",
- max="0.9"
- ),
+ feasible_space=V1beta1FeasibleSpace(min="0.5", max="0.9"),
),
]
trial_template = generate_trial_template()
experiment = generate_experiment(
- metadata,
- algorithm_spec,
- objective_spec,
- parameters,
- trial_template
+ metadata, algorithm_spec, objective_spec, parameters, trial_template
)
return experiment
@@ -316,7 +302,9 @@ def katib_client():
yield client
-@pytest.mark.parametrize("test_name,kwargs,expected_output", test_create_experiment_data)
+@pytest.mark.parametrize(
+ "test_name,kwargs,expected_output", test_create_experiment_data
+)
def test_create_experiment(katib_client, test_name, kwargs, expected_output):
"""
test create_experiment function of katib client
diff --git a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py
index 6250052310a..b8c513e9b78 100644
--- a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py
+++ b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from datetime import datetime
-from datetime import timezone
import os
+from datetime import datetime, timezone
from typing import Any, Dict
import grpc
@@ -38,7 +37,7 @@ def report_metrics(
For examle, `metrics = {"loss": 0.01, "accuracy": 0.99}`.
db-manager-address: Address for the Katib DB Manager in this format: `ip-address:port`.
timeout: Optional, gRPC API Server timeout in seconds to report metrics.
-
+
Raises:
ValueError: The Trial name is not passed to environment variables or
metrics value has incorrect format (cannot be converted to type `float`).
@@ -49,9 +48,7 @@ def report_metrics(
namespace = utils.get_current_k8s_namespace()
name = os.getenv("KATIB_TRIAL_NAME")
if name is None:
- raise ValueError(
- "The Trial name is not passed to environment variables"
- )
+ raise ValueError("The Trial name is not passed to environment variables")
# Get channel for grpc call to db manager
channel = grpc.insecure_channel(db_manager_address)
@@ -59,7 +56,7 @@ def report_metrics(
# Validate metrics value in dict
for value in metrics.values():
utils.validate_metrics_value(value)
-
+
# Dial katib db manager to report metrics
client = katib_api_pb2_grpc.DBManagerStub(channel)
try:
diff --git a/sdk/python/v1beta1/kubeflow/katib/api/search.py b/sdk/python/v1beta1/kubeflow/katib/api/search.py
index 12a9ffcddcd..84d56998070 100644
--- a/sdk/python/v1beta1/kubeflow/katib/api/search.py
+++ b/sdk/python/v1beta1/kubeflow/katib/api/search.py
@@ -61,5 +61,6 @@ def categorical(list: List):
"""
return models.V1beta1ParameterSpec(
- parameter_type="categorical", feasible_space=models.V1beta1FeasibleSpace(list),
+ parameter_type="categorical",
+ feasible_space=models.V1beta1FeasibleSpace(list),
)
diff --git a/sdk/python/v1beta1/kubeflow/katib/models/v1beta1_feasible_space.py b/sdk/python/v1beta1/kubeflow/katib/models/v1beta1_feasible_space.py
index 248318e80ed..439e9f65ff5 100644
--- a/sdk/python/v1beta1/kubeflow/katib/models/v1beta1_feasible_space.py
+++ b/sdk/python/v1beta1/kubeflow/katib/models/v1beta1_feasible_space.py
@@ -33,6 +33,7 @@ class V1beta1FeasibleSpace(object):
and the value is json key in definition.
"""
openapi_types = {
+ 'distribution': 'str',
'list': 'list[str]',
'max': 'str',
'min': 'str',
@@ -40,24 +41,28 @@ class V1beta1FeasibleSpace(object):
}
attribute_map = {
+ 'distribution': 'distribution',
'list': 'list',
'max': 'max',
'min': 'min',
'step': 'step'
}
- def __init__(self, list=None, max=None, min=None, step=None, local_vars_configuration=None): # noqa: E501
+ def __init__(self, distribution=None, list=None, max=None, min=None, step=None, local_vars_configuration=None): # noqa: E501
"""V1beta1FeasibleSpace - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
self.local_vars_configuration = local_vars_configuration
+ self._distribution = None
self._list = None
self._max = None
self._min = None
self._step = None
self.discriminator = None
+ if distribution is not None:
+ self.distribution = distribution
if list is not None:
self.list = list
if max is not None:
@@ -67,6 +72,27 @@ def __init__(self, list=None, max=None, min=None, step=None, local_vars_configur
if step is not None:
self.step = step
+ @property
+ def distribution(self):
+ """Gets the distribution of this V1beta1FeasibleSpace. # noqa: E501
+
+
+ :return: The distribution of this V1beta1FeasibleSpace. # noqa: E501
+ :rtype: str
+ """
+ return self._distribution
+
+ @distribution.setter
+ def distribution(self, distribution):
+ """Sets the distribution of this V1beta1FeasibleSpace.
+
+
+ :param distribution: The distribution of this V1beta1FeasibleSpace. # noqa: E501
+ :type: str
+ """
+
+ self._distribution = distribution
+
@property
def list(self):
"""Gets the list of this V1beta1FeasibleSpace. # noqa: E501
diff --git a/sdk/python/v1beta1/kubeflow/katib/utils/utils.py b/sdk/python/v1beta1/kubeflow/katib/utils/utils.py
index 696390df5ac..c6e0734438f 100644
--- a/sdk/python/v1beta1/kubeflow/katib/utils/utils.py
+++ b/sdk/python/v1beta1/kubeflow/katib/utils/utils.py
@@ -72,6 +72,7 @@ def print_experiment_status(experiment: models.V1beta1Experiment):
print(f"Current Optimal Trial:\n {experiment.status.current_optimal_trial}")
print(f"Experiment conditions:\n {experiment.status.conditions}")
+
def validate_metrics_value(value: Any):
"""Validate if the metrics value can be converted to type `float`."""
try:
diff --git a/test/e2e/v1beta1/scripts/gh-actions/build-load.sh b/test/e2e/v1beta1/scripts/gh-actions/build-load.sh
index 2ce492da79a..cb0ea03cd5a 100755
--- a/test/e2e/v1beta1/scripts/gh-actions/build-load.sh
+++ b/test/e2e/v1beta1/scripts/gh-actions/build-load.sh
@@ -25,9 +25,10 @@ pushd .
cd "$(dirname "$0")/../../../../.."
trap popd EXIT
-TRIAL_IMAGES=${1:-""}
-EXPERIMENTS=${2:-""}
-DEPLOY_KATIB_UI=${3:-false}
+DEPLOY_KATIB_UI=${1:-false}
+TUNE_API=${2:-false}
+TRIAL_IMAGES=${3:-""}
+EXPERIMENTS=${4:-""}
REGISTRY="docker.io/kubeflowkatib"
TAG="e2e-test"
@@ -162,6 +163,12 @@ for name in "${TRIAL_IMAGE_ARRAY[@]}"; do
run "$name" "examples/$VERSION/trial-images/$name/Dockerfile"
done
+# Testing image for tune function
+if "$TUNE_API"; then
+ echo -e "\nPulling and building testing image for tune function..."
+ _build_containers "suggestion-hyperopt" "$CMD_PREFIX/suggestion/hyperopt/$VERSION/Dockerfile"
+fi
+
echo -e "\nCleanup Build Cache...\n"
docker buildx prune -f
diff --git a/test/e2e/v1beta1/scripts/gh-actions/run-e2e-experiment.py b/test/e2e/v1beta1/scripts/gh-actions/run-e2e-experiment.py
index 26ef2e9f6e2..a7c70b47c38 100644
--- a/test/e2e/v1beta1/scripts/gh-actions/run-e2e-experiment.py
+++ b/test/e2e/v1beta1/scripts/gh-actions/run-e2e-experiment.py
@@ -1,14 +1,12 @@
import argparse
import logging
-import time
-from kubeflow.katib import ApiClient
-from kubeflow.katib import KatibClient
-from kubeflow.katib import models
+import yaml
+from kubeflow.katib import ApiClient, KatibClient, models
from kubeflow.katib.constants import constants
from kubeflow.katib.utils.utils import FakeResponse
from kubernetes import client
-import yaml
+from verify import verify_experiment_results
# Experiment timeout is 40 min.
EXPERIMENT_TIMEOUT = 60 * 40
@@ -17,143 +15,6 @@
logging.basicConfig(level=logging.INFO)
-def verify_experiment_results(
- katib_client: KatibClient,
- experiment: models.V1beta1Experiment,
- exp_name: str,
- exp_namespace: str,
-):
-
- # Get the best objective metric.
- best_objective_metric = None
- for metric in experiment.status.current_optimal_trial.observation.metrics:
- if metric.name == experiment.spec.objective.objective_metric_name:
- best_objective_metric = metric
- break
-
- if best_objective_metric is None:
- raise Exception(
- "Unable to get the best metrics for objective: {}. Current Optimal Trial: {}".format(
- experiment.spec.objective.objective_metric_name,
- experiment.status.current_optimal_trial,
- )
- )
-
- # Get Experiment Succeeded reason.
- for c in experiment.status.conditions:
- if (
- c.type == constants.EXPERIMENT_CONDITION_SUCCEEDED
- and c.status == constants.CONDITION_STATUS_TRUE
- ):
- succeeded_reason = c.reason
- break
-
- trials_completed = experiment.status.trials_succeeded or 0
- trials_completed += experiment.status.trials_early_stopped or 0
- max_trial_count = experiment.spec.max_trial_count
-
- # If Experiment is Succeeded because of Max Trial Reached, all Trials must be completed.
- if (
- succeeded_reason == "ExperimentMaxTrialsReached"
- and trials_completed != max_trial_count
- ):
- raise Exception(
- "All Trials must be Completed. Max Trial count: {}, Experiment status: {}".format(
- max_trial_count, experiment.status
- )
- )
-
- # If Experiment is Succeeded because of Goal reached, the metrics must be correct.
- if succeeded_reason == "ExperimentGoalReached" and (
- (
- experiment.spec.objective.type == "minimize"
- and float(best_objective_metric.min) > float(experiment.spec.objective.goal)
- )
- or (
- experiment.spec.objective.type == "maximize"
- and float(best_objective_metric.max) < float(experiment.spec.objective.goal)
- )
- ):
- raise Exception(
- "Experiment goal is reached, but metrics are incorrect. "
- f"Experiment objective: {experiment.spec.objective}. "
- f"Experiment best objective metric: {best_objective_metric}"
- )
-
- # Verify Suggestion's resources. Suggestion name = Experiment name.
- suggestion = katib_client.get_suggestion(exp_name, exp_namespace)
-
- # For the Never or FromVolume resume policies Suggestion must be Succeeded.
- # For the LongRunning resume policy Suggestion must be always Running.
- for c in suggestion.status.conditions:
- if (
- c.type == constants.EXPERIMENT_CONDITION_SUCCEEDED
- and c.status == constants.CONDITION_STATUS_TRUE
- and experiment.spec.resume_policy == "LongRunning"
- ):
- raise Exception(
- f"Suggestion is Succeeded while Resume Policy is {experiment.spec.resume_policy}."
- f"Suggestion conditions: {suggestion.status.conditions}"
- )
- elif (
- c.type == constants.EXPERIMENT_CONDITION_RUNNING
- and c.status == constants.CONDITION_STATUS_TRUE
- and experiment.spec.resume_policy != "LongRunning"
- ):
- raise Exception(
- f"Suggestion is Running while Resume Policy is {experiment.spec.resume_policy}."
- f"Suggestion conditions: {suggestion.status.conditions}"
- )
-
- # For Never and FromVolume resume policies verify Suggestion's resources.
- if (
- experiment.spec.resume_policy == "Never"
- or experiment.spec.resume_policy == "FromVolume"
- ):
- resource_name = exp_name + "-" + experiment.spec.algorithm.algorithm_name
-
- # Suggestion's Service and Deployment should be deleted.
- for i in range(10):
- try:
- client.AppsV1Api().read_namespaced_deployment(
- resource_name, exp_namespace
- )
- except client.ApiException as e:
- if e.status == 404:
- break
- else:
- raise e
- # Deployment deletion might take some time.
- time.sleep(1)
- if i == 10:
- raise Exception(
- "Suggestion Deployment is still alive for Resume Policy: {}".format(
- experiment.spec.resume_policy
- )
- )
-
- try:
- client.CoreV1Api().read_namespaced_service(resource_name, exp_namespace)
- except client.ApiException as e:
- if e.status != 404:
- raise e
- else:
- raise Exception(
- "Suggestion Service is still alive for Resume Policy: {}".format(
- experiment.spec.resume_policy
- )
- )
-
- # For FromVolume resume policy PVC should not be deleted.
- if experiment.spec.resume_policy == "FromVolume":
- try:
- client.CoreV1Api().read_namespaced_persistent_volume_claim(
- resource_name, exp_namespace
- )
- except client.ApiException:
- raise Exception("PVC is deleted for FromVolume Resume Policy")
-
-
def run_e2e_experiment(
katib_client: KatibClient,
experiment: models.V1beta1Experiment,
diff --git a/test/e2e/v1beta1/scripts/gh-actions/run-e2e-tune-api.py b/test/e2e/v1beta1/scripts/gh-actions/run-e2e-tune-api.py
new file mode 100644
index 00000000000..c9d1cb2ee43
--- /dev/null
+++ b/test/e2e/v1beta1/scripts/gh-actions/run-e2e-tune-api.py
@@ -0,0 +1,108 @@
+import argparse
+import logging
+
+from kubeflow.katib import KatibClient, search
+from kubernetes import client
+from verify import verify_experiment_results
+
+# Experiment timeout is 40 min.
+EXPERIMENT_TIMEOUT = 60 * 40
+
+# The default logging config.
+logging.basicConfig(level=logging.INFO)
+
+
+def run_e2e_experiment_create_by_tune(
+    katib_client: KatibClient,
+    exp_name: str,
+    exp_namespace: str,
+):
+    """Create an Experiment with the `tune` API, wait for it, and verify its results.
+
+    Args:
+        katib_client: Client used to create and query the Experiment.
+        exp_name: Name of the Experiment to create.
+        exp_namespace: Namespace of the Experiment to create.
+    """
+    # Create Katib Experiment and wait until it is finished.
+    logging.debug("Creating Experiment: {}/{}".format(exp_namespace, exp_name))
+
+    # Use the test case from get-started tutorial.
+    # https://www.kubeflow.org/docs/components/katib/getting-started/#getting-started-with-katib-python-sdk
+    # [1] Create an objective function.
+    def objective(parameters):
+        import time
+        time.sleep(5)
+        result = 4 * int(parameters["a"]) - float(parameters["b"]) ** 2
+        print(f"result={result}")
+
+    # [2] Create hyperparameter search space.
+    parameters = {
+        "a": search.int(min=10, max=20),
+        "b": search.double(min=0.1, max=0.2)
+    }
+
+    # [3] Create Katib Experiment with 4 Trials and 2 CPUs per Trial.
+    # And Wait until Experiment reaches Succeeded condition.
+    katib_client.tune(
+        name=exp_name,
+        namespace=exp_namespace,
+        objective=objective,
+        parameters=parameters,
+        objective_metric_name="result",
+        max_trial_count=4,
+        resources_per_trial={"cpu": "2"},
+    )
+    experiment = katib_client.wait_for_experiment_condition(
+        exp_name, exp_namespace, timeout=EXPERIMENT_TIMEOUT
+    )
+
+    # Verify the Experiment results.
+    verify_experiment_results(katib_client, experiment, exp_name, exp_namespace)
+
+    # Print the Experiment and Suggestion.
+    logging.debug(katib_client.get_experiment(exp_name, exp_namespace))
+    logging.debug(katib_client.get_suggestion(exp_name, exp_namespace))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--namespace", type=str, required=True, help="Namespace for the Katib E2E test",
+    )
+    parser.add_argument(
+        "--verbose", action="store_true", help="Verbose output for the Katib E2E test",
+    )
+    args = parser.parse_args()
+
+    if args.verbose:
+        logging.getLogger().setLevel(logging.DEBUG)
+
+    katib_client = KatibClient()
+
+    # Label the namespace for metrics collector injection if the label is missing.
+    injection_label = "katib.kubeflow.org/metrics-collector-injection"
+    # `metadata.labels` may be None for a Namespace without labels, so fall back to an empty dict.
+    namespace_labels = client.CoreV1Api().read_namespace(args.namespace).metadata.labels or {}
+    if injection_label not in namespace_labels:
+        namespace_labels[injection_label] = "enabled"
+        client.CoreV1Api().patch_namespace(
+            args.namespace, {"metadata": {"labels": namespace_labels}}
+        )
+
+    # Test with run_e2e_experiment_create_by_tune
+    exp_name = "tune-example"
+    exp_namespace = args.namespace
+    try:
+        run_e2e_experiment_create_by_tune(katib_client, exp_name, exp_namespace)
+        logging.info("---------------------------------------------------------------")
+        logging.info(f"E2E is succeeded for Experiment created by tune: {exp_namespace}/{exp_name}")
+    except Exception as e:
+        logging.info("---------------------------------------------------------------")
+        logging.info(f"E2E is failed for Experiment created by tune: {exp_namespace}/{exp_name}")
+        raise e
+    finally:
+        # Delete the Experiment.
+        logging.info("---------------------------------------------------------------")
+        logging.info("---------------------------------------------------------------")
+        katib_client.delete_experiment(exp_name, exp_namespace)
diff --git a/test/e2e/v1beta1/scripts/gh-actions/run-e2e-tune-api.sh b/test/e2e/v1beta1/scripts/gh-actions/run-e2e-tune-api.sh
new file mode 100755
index 00000000000..1520d301439
--- /dev/null
+++ b/test/e2e/v1beta1/scripts/gh-actions/run-e2e-tune-api.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+
+# Copyright 2024 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This shell script is used to run the Katib E2E test for the tune API.
+# It does not take any input parameters.
+
+set -o errexit
+set -o nounset
+set -o pipefail
+
+cd "$(dirname "$0")"
+
+echo "Katib deployments"
+kubectl -n kubeflow get deploy
+echo "Katib services"
+kubectl -n kubeflow get svc
+echo "Katib pods"
+kubectl -n kubeflow get pod
+echo "Katib persistent volume claims"
+kubectl get pvc -n kubeflow
+echo "Available CRDs"
+kubectl get crd
+
+python run-e2e-tune-api.py --namespace default \
+--verbose || (kubectl get pods -n kubeflow && exit 1)
diff --git a/test/e2e/v1beta1/scripts/gh-actions/setup-minikube.sh b/test/e2e/v1beta1/scripts/gh-actions/setup-minikube.sh
index a24131bbb7d..b890a40d41b 100755
--- a/test/e2e/v1beta1/scripts/gh-actions/setup-minikube.sh
+++ b/test/e2e/v1beta1/scripts/gh-actions/setup-minikube.sh
@@ -22,8 +22,9 @@ set -o nounset
cd "$(dirname "$0")"
DEPLOY_KATIB_UI=${1:-false}
-TRIAL_IMAGES=${2:-""}
-EXPERIMENTS=${3:-""}
+TUNE_API=${2:-false}
+TRIAL_IMAGES=${3:-""}
+EXPERIMENTS=${4:-""}
echo "Start to setup Minikube Kubernetes Cluster"
kubectl version
@@ -31,4 +32,4 @@ kubectl cluster-info
kubectl get nodes
echo "Build and Load container images"
-./build-load.sh "$TRIAL_IMAGES" "$EXPERIMENTS" "$DEPLOY_KATIB_UI"
+./build-load.sh "$DEPLOY_KATIB_UI" "$TUNE_API" "$TRIAL_IMAGES" "$EXPERIMENTS"
diff --git a/test/e2e/v1beta1/scripts/gh-actions/verify.py b/test/e2e/v1beta1/scripts/gh-actions/verify.py
new file mode 100644
index 00000000000..c1514f6da12
--- /dev/null
+++ b/test/e2e/v1beta1/scripts/gh-actions/verify.py
@@ -0,0 +1,161 @@
+import time
+
+from kubeflow.katib import KatibClient, models
+from kubeflow.katib.constants import constants
+from kubernetes import client
+
+
+def verify_experiment_results(
+    katib_client: KatibClient,
+    experiment: models.V1beta1Experiment,
+    exp_name: str,
+    exp_namespace: str,
+):
+    """Verify the results of a Succeeded Experiment and the state of its Suggestion.
+
+    Args:
+        katib_client: Client used to get the Suggestion of the Experiment.
+        experiment: Experiment object to verify.
+        exp_name: Name of the Experiment. The Suggestion has the same name.
+        exp_namespace: Namespace of the Experiment.
+
+    Raises:
+        Exception: The Experiment results or the Suggestion resources are incorrect.
+        kubernetes.client.ApiException: Unexpected (non-404) error from the Kubernetes API.
+    """
+
+    # Get the best objective metric.
+    best_objective_metric = None
+    for metric in experiment.status.current_optimal_trial.observation.metrics:
+        if metric.name == experiment.spec.objective.objective_metric_name:
+            best_objective_metric = metric
+            break
+
+    if best_objective_metric is None:
+        raise Exception(
+            "Unable to get the best metrics for objective: {}. Current Optimal Trial: {}".format(
+                experiment.spec.objective.objective_metric_name,
+                experiment.status.current_optimal_trial,
+            )
+        )
+
+    # Get Experiment Succeeded reason.
+    for c in experiment.status.conditions:
+        if (
+            c.type == constants.EXPERIMENT_CONDITION_SUCCEEDED
+            and c.status == constants.CONDITION_STATUS_TRUE
+        ):
+            succeeded_reason = c.reason
+            break
+    else:
+        # Without this check `succeeded_reason` would be unbound below.
+        raise Exception(
+            "Experiment doesn't have Succeeded condition. "
+            f"Experiment conditions: {experiment.status.conditions}"
+        )
+
+    trials_completed = experiment.status.trials_succeeded or 0
+    trials_completed += experiment.status.trials_early_stopped or 0
+    max_trial_count = experiment.spec.max_trial_count
+
+    # If Experiment is Succeeded because of Max Trial Reached, all Trials must be completed.
+    if (
+        succeeded_reason == "ExperimentMaxTrialsReached"
+        and trials_completed != max_trial_count
+    ):
+        raise Exception(
+            "All Trials must be Completed. Max Trial count: {}, Experiment status: {}".format(
+                max_trial_count, experiment.status
+            )
+        )
+
+    # If Experiment is Succeeded because of Goal reached, the metrics must be correct.
+    if succeeded_reason == "ExperimentGoalReached" and (
+        (
+            experiment.spec.objective.type == "minimize"
+            and float(best_objective_metric.min) > float(experiment.spec.objective.goal)
+        )
+        or (
+            experiment.spec.objective.type == "maximize"
+            and float(best_objective_metric.max) < float(experiment.spec.objective.goal)
+        )
+    ):
+        raise Exception(
+            "Experiment goal is reached, but metrics are incorrect. "
+            f"Experiment objective: {experiment.spec.objective}. "
+            f"Experiment best objective metric: {best_objective_metric}"
+        )
+
+    # Verify Suggestion's resources. Suggestion name = Experiment name.
+    suggestion = katib_client.get_suggestion(exp_name, exp_namespace)
+
+    # For the Never or FromVolume resume policies Suggestion must be Succeeded.
+    # For the LongRunning resume policy Suggestion must be always Running.
+    for c in suggestion.status.conditions:
+        if (
+            c.type == constants.EXPERIMENT_CONDITION_SUCCEEDED
+            and c.status == constants.CONDITION_STATUS_TRUE
+            and experiment.spec.resume_policy == "LongRunning"
+        ):
+            raise Exception(
+                f"Suggestion is Succeeded while Resume Policy is {experiment.spec.resume_policy}."
+                f"Suggestion conditions: {suggestion.status.conditions}"
+            )
+        elif (
+            c.type == constants.EXPERIMENT_CONDITION_RUNNING
+            and c.status == constants.CONDITION_STATUS_TRUE
+            and experiment.spec.resume_policy != "LongRunning"
+        ):
+            raise Exception(
+                f"Suggestion is Running while Resume Policy is {experiment.spec.resume_policy}."
+                f"Suggestion conditions: {suggestion.status.conditions}"
+            )
+
+    # For Never and FromVolume resume policies verify Suggestion's resources.
+    if (
+        experiment.spec.resume_policy == "Never"
+        or experiment.spec.resume_policy == "FromVolume"
+    ):
+        resource_name = exp_name + "-" + experiment.spec.algorithm.algorithm_name
+
+        # Suggestion's Service and Deployment should be deleted.
+        # Poll up to 10 times; the `else` branch runs only if the Deployment never disappeared.
+        for _ in range(10):
+            try:
+                client.AppsV1Api().read_namespaced_deployment(
+                    resource_name, exp_namespace
+                )
+            except client.ApiException as e:
+                if e.status == 404:
+                    break
+                else:
+                    raise e
+            # Deployment deletion might take some time.
+            time.sleep(1)
+        else:
+            raise Exception(
+                "Suggestion Deployment is still alive for Resume Policy: {}".format(
+                    experiment.spec.resume_policy
+                )
+            )
+
+        try:
+            client.CoreV1Api().read_namespaced_service(resource_name, exp_namespace)
+        except client.ApiException as e:
+            if e.status != 404:
+                raise e
+        else:
+            raise Exception(
+                "Suggestion Service is still alive for Resume Policy: {}".format(
+                    experiment.spec.resume_policy
+                )
+            )
+
+        # For FromVolume resume policy PVC should not be deleted.
+        if experiment.spec.resume_policy == "FromVolume":
+            try:
+                client.CoreV1Api().read_namespaced_persistent_volume_claim(
+                    resource_name, exp_namespace
+                )
+            except client.ApiException:
+                raise Exception("PVC is deleted for FromVolume Resume Policy")
diff --git a/test/unit/v1beta1/suggestion/test_darts_service.py b/test/unit/v1beta1/suggestion/test_darts_service.py
index c3cc792ba76..9355364ff2b 100644
--- a/test/unit/v1beta1/suggestion/test_darts_service.py
+++ b/test/unit/v1beta1/suggestion/test_darts_service.py
@@ -19,9 +19,10 @@
import grpc_testing
from pkg.apis.manager.v1beta1.python import api_pb2
-from pkg.suggestion.v1beta1.nas.darts.service import DartsService
-from pkg.suggestion.v1beta1.nas.darts.service import \
- validate_algorithm_settings
+from pkg.suggestion.v1beta1.nas.darts.service import (
+ DartsService,
+ validate_algorithm_settings,
+)
class TestDarts(unittest.TestCase):