Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
Signed-off-by: forsaken628 <[email protected]>
  • Loading branch information
forsaken628 committed Jun 6, 2024
1 parent bf9f8cd commit ccd7a4c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 19 deletions.
32 changes: 20 additions & 12 deletions pkg/earlystopping/v1beta1/medianstop/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Iterable
from typing import Iterable, Optional
from kubernetes import client, config
import multiprocessing
from datetime import datetime
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(self):

self.api_instance = client.CustomObjectsApi()

def ValidateEarlyStoppingSettings(self, request, context):
def ValidateEarlyStoppingSettings(self, request: api_pb2.ValidateEarlyStoppingSettingsRequest, context: grpc.ServicerContext) -> api_pb2.ValidateEarlyStoppingSettingsReply:
is_valid, message = self.validate_early_stopping_spec(request.early_stopping)
if not is_valid:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
Expand Down Expand Up @@ -98,7 +98,7 @@ def validate_medianstop_setting(early_stopping_settings):

return True, ""

def GetEarlyStoppingRules(self, request: api_pb2.GetEarlyStoppingRulesRequest, context):
def GetEarlyStoppingRules(self, request: api_pb2.GetEarlyStoppingRulesRequest, context: grpc.ServicerContext) -> api_pb2.GetSuggestionsReply:
logger.info("Get new early stopping rules")

# Get required values for the first call.
Expand Down Expand Up @@ -145,17 +145,25 @@ def get_early_stopping_settings(self, early_stopping_settings: Iterable[api_pb2.
elif setting.name == "start_step":
self.start_step = int(setting.value)

def get_median_value(self, trials: Iterable[api_pb2.Trial]):
def get_median_value(self, trials: Iterable[api_pb2.Trial]) -> Optional[float]:
for trial in trials:
# Get metrics only for the new succeeded Trials.
if trial.name not in self.trials_avg_history and trial.status.condition == SUCCEEDED_TRIAL:
with grpc.beta.implementations.insecure_channel(
self.db_manager_address[0], int(self.db_manager_address[1])) as channel:
if (
trial.name not in self.trials_avg_history
and trial.status.condition == SUCCEEDED_TRIAL
):
with grpc.insecure_channel(
f"{self.db_manager_address[0]}:{self.db_manager_address[1]}"
) as channel:
stub = api_pb2_grpc.DBManagerStub(channel)
get_log_response = stub.GetObservationLog(api_pb2.GetObservationLogRequest(
trial_name=trial.name,
metric_name=self.objective_metric
), timeout=APISERVER_TIMEOUT)
get_log_response: api_pb2.GetObservationLogReply = (
stub.GetObservationLog(
api_pb2.GetObservationLogRequest(
trial_name=trial.name, metric_name=self.objective_metric
),
timeout=APISERVER_TIMEOUT,
)
)

# Get only first start_step metrics.
# Since metrics are collected consistently and ordered by time, we slice top start_step metrics.
Expand All @@ -182,7 +190,7 @@ def get_median_value(self, trials: Iterable[api_pb2.Trial]):
))
return None

def SetTrialStatus(self, request, context):
def SetTrialStatus(self, request: api_pb2.SetTrialStatusRequest, context: grpc.ServicerContext) -> api_pb2.SetTrialStatusReply:
trial_name = request.trial_name

logger.info("Update status for Trial: {}".format(trial_name))
Expand Down
16 changes: 9 additions & 7 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from sys import path

root = os.path.join(os.path.dirname(__file__),'..')
path.extend([
os.path.join(root,'pkg/apis/manager/v1beta1/python'),
os.path.join(root,'pkg/apis/manager/health/python'),
os.path.join(root,'pkg/metricscollector/v1beta1/common'),
os.path.join(root,'pkg/metricscollector/v1beta1/tfevent-metricscollector')
])
root = os.path.join(os.path.dirname(__file__), "..")
path.extend(
[
os.path.join(root, "pkg/apis/manager/v1beta1/python"),
os.path.join(root, "pkg/apis/manager/health/python"),
os.path.join(root, "pkg/metricscollector/v1beta1/common"),
os.path.join(root, "pkg/metricscollector/v1beta1/tfevent-metricscollector"),
]
)

0 comments on commit ccd7a4c

Please sign in to comment.