
Commit

Hparams: Support excluding metric information in HTTP requests. (tensorflow#6556)

There are some clients of the Hparams HTTP API that do not require metric
information. This includes the metric_infos usually returned in the
/experiments request and the metric_values usually returned in the
/session_groups request. Since these can be expensive to calculate, we want
an option to skip calculating them and to omit them from the response.

Add an `include_metrics` option to both GetExperimentRequest and
ListSessionGroupsRequest. If unspecified, `include_metrics` is treated as
True for backward compatibility. Honor the `include_metrics` property in all
three major cases: when experiment metadata is defined by Experiment tags, by
Session tags, or by the DataProvider.
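
For illustration, a minimal client-side sketch of opting out of metric
information over the HTTP API (the host, endpoint paths, and JSON field
spellings are assumptions based on the plugin's JSON-over-HTTP convention,
not part of this change):

```python
# Hypothetical client sketch: fetch hparams data without metric information.
# The base URL and endpoint paths are assumptions for illustration only.
import json
import urllib.request

BASE = "http://localhost:6006/data/plugin/hparams"


def post(route, payload):
    """POSTs a JSON payload to an hparams route and returns the parsed reply."""
    request = urllib.request.Request(
        BASE + route,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())


# GetExperimentRequest: skip metric_infos in the response.
experiment = post(
    "/experiment", {"experimentName": "123", "includeMetrics": False}
)

# ListSessionGroupsRequest: skip metric_values in the response.
session_groups = post(
    "/session_groups",
    {
        "experimentName": "123",
        "startIndex": 0,
        "sliceSize": 10,
        "includeMetrics": False,
    },
)
```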
bmd3k authored and yatbear committed Aug 25, 2023
1 parent 3565328 commit b3edbf3
Showing 7 changed files with 274 additions and 41 deletions.
10 changes: 8 additions & 2 deletions tensorboard/plugins/hparams/api.proto
@@ -254,17 +254,20 @@ enum Status {

// Parameters for a GetExperiment API call.
// Each experiment is scoped by a unique global id.
// NEXT_TAG: 2
// NEXT_TAG: 3
message GetExperimentRequest {
// REQUIRED
string experiment_name = 1;

// Whether to fetch metrics and include them in the results. Defaults to true.
optional bool include_metrics = 2;
}

// Parameters for a ListSessionGroups API call.
// Computes a list of the current session groups allowing for filtering and
// sorting by metrics and hyperparameter values. Returns a "slice" of
// that list specified by start_index and slice_size.
// NEXT_TAG: 8
// NEXT_TAG: 9
message ListSessionGroupsRequest {
string experiment_name = 6;

@@ -314,6 +317,9 @@ message ListSessionGroupsRequest {
// sorted and filtered by the parameters above (if start_index > total_size
// no session groups are returned).
int32 slice_size = 5;

// Whether to fetch metrics and include them in the results. Defaults to true.
optional bool include_metrics = 8;
}

// Defines parameters for a ListSessionGroupsRequest for a specific column.
44 changes: 34 additions & 10 deletions tensorboard/plugins/hparams/backend_context.py
@@ -57,6 +57,7 @@ def experiment_from_metadata(
self,
ctx,
experiment_id,
include_metrics,
hparams_run_to_tag_to_content,
data_provider_hparams,
):
@@ -76,6 +77,8 @@
Args:
experiment_id: String, from `plugin_util.experiment_id`.
include_metrics: Whether to determine metrics_infos and include them
in the result.
hparams_run_to_tag_to_content: The output from an hparams_metadata()
call. A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
@@ -87,19 +90,21 @@
The experiment proto. If no data is found for an experiment proto to
be built, returns an entirely empty experiment.
"""
experiment = self._find_experiment_tag(hparams_run_to_tag_to_content)
experiment = self._find_experiment_tag(
hparams_run_to_tag_to_content, include_metrics
)
if experiment:
return experiment

experiment_from_runs = self._compute_experiment_from_runs(
ctx, experiment_id, hparams_run_to_tag_to_content
ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
)
if experiment_from_runs:
return experiment_from_runs

experiment_from_data_provider_hparams = (
self._experiment_from_data_provider_hparams(
ctx, experiment_id, data_provider_hparams
ctx, experiment_id, include_metrics, data_provider_hparams
)
)
return (
@@ -202,7 +207,9 @@ def session_groups_from_data_provider(
ctx, experiment_ids=[experiment_id], filters=filters, sort=sort
)

def _find_experiment_tag(self, hparams_run_to_tag_to_content):
def _find_experiment_tag(
self, hparams_run_to_tag_to_content, include_metrics
):
"""Finds the experiment associcated with the metadata.EXPERIMENT_TAG
tag.
@@ -214,23 +221,34 @@ def _find_experiment_tag(self, hparams_run_to_tag_to_content):
for tags in hparams_run_to_tag_to_content.values():
maybe_content = tags.get(metadata.EXPERIMENT_TAG)
if maybe_content is not None:
return metadata.parse_experiment_plugin_data(maybe_content)
experiment = metadata.parse_experiment_plugin_data(
maybe_content
)
if not include_metrics:
# metric_infos haven't technically been "calculated" in this
# case. They have been read directly from the Experiment
# proto.
# Delete them from the result so that they are not returned
# to the client.
experiment.ClearField("metric_infos")
return experiment
return None

def _compute_experiment_from_runs(
self, ctx, experiment_id, hparams_run_to_tag_to_content
self, ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
):
"""Computes a minimal Experiment protocol buffer by scanning the runs.
Returns None if there are no hparam infos logged.
"""
hparam_infos = self._compute_hparam_infos(hparams_run_to_tag_to_content)
if hparam_infos:
metric_infos = self._compute_metric_infos_from_runs(
metric_infos = (
self._compute_metric_infos_from_runs(
ctx, experiment_id, hparams_run_to_tag_to_content
)
else:
metric_infos = []
if hparam_infos and include_metrics
else []
)
if not hparam_infos and not metric_infos:
return None

@@ -320,11 +338,15 @@ def _experiment_from_data_provider_hparams(
self,
ctx,
experiment_id,
include_metrics,
data_provider_hparams,
):
"""Returns an experiment protobuffer based on data provider hparams.
Args:
experiment_id: String, from `plugin_util.experiment_id`.
include_metrics: Whether to determine metrics_infos and include them
in the result.
data_provider_hparams: The output from an hparams_from_data_provider()
call, corresponding to DataProvider.list_hyperparameters().
A provider.ListHyperparametersResult.
Expand Down Expand Up @@ -352,6 +374,8 @@ def _experiment_from_data_provider_hparams(
self.compute_metric_infos_from_data_provider_session_groups(
ctx, experiment_id, session_groups
)
if include_metrics
else []
)
return api_pb2.Experiment(
hparam_infos=hparam_infos, metric_infos=metric_infos
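
As context for the `ClearField` call in `_find_experiment_tag` above, a small
sketch of how clearing the repeated `metric_infos` field behaves (assumes the
generated `api_pb2` module; illustrative, not part of the diff):

```python
# Sketch: clearing the repeated metric_infos field drops values that were
# already parsed from the Experiment tag, so they are not serialized back
# to the client. Assumes the generated api_pb2 module.
from tensorboard.plugins.hparams import api_pb2

experiment = api_pb2.Experiment(description="Test experiment")
metric_info = experiment.metric_infos.add()
metric_info.name.tag = "current_temp"
assert len(experiment.metric_infos) == 1

experiment.ClearField("metric_infos")
assert len(experiment.metric_infos) == 0  # nothing left to return
```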
87 changes: 81 additions & 6 deletions tensorboard/plugins/hparams/backend_context_test.py
@@ -152,13 +152,14 @@ def _mock_list_hyperparameters(
):
return self._hyperparameters

def _experiment_from_metadata(self):
def _experiment_from_metadata(self, *, include_metrics=True):
"""Calls the expected operations for generating an Experiment proto."""
ctxt = backend_context.Context(self._mock_tb_context)
request_ctx = context.RequestContext()
return ctxt.experiment_from_metadata(
request_ctx,
"123",
include_metrics,
ctxt.hparams_metadata(request_ctx, "123"),
ctxt.hparams_from_data_provider(request_ctx, "123"),
)
@@ -187,7 +188,39 @@ def test_experiment_with_experiment_tag(self):
}
self.assertProtoEquals(experiment, self._experiment_from_metadata())

def test_experiment_without_experiment_tag(self):
def test_experiment_with_experiment_tag_include_metrics(self):
experiment = """
description: 'Test experiment'
metric_infos: [
{ name: { tag: 'current_temp' } },
{ name: { tag: 'delta_temp' } }
]
"""
run = "exp"
tag = metadata.EXPERIMENT_TAG
t = provider.TensorTimeSeries(
max_step=0,
max_wall_time=0,
plugin_content=self._serialized_plugin_data(
DATA_TYPE_EXPERIMENT, experiment
),
description="",
display_name="",
)
self._mock_tb_context.data_provider.list_tensors.side_effect = None
self._mock_tb_context.data_provider.list_tensors.return_value = {
run: {tag: t}
}

with self.subTest("False"):
response = self._experiment_from_metadata(include_metrics=False)
self.assertEmpty(response.metric_infos)

with self.subTest("True"):
response = self._experiment_from_metadata(include_metrics=True)
self.assertLen(response.metric_infos, 2)

def test_experiment_with_session_tags(self):
self.session_1_start_info_ = """
hparams: [
{key: 'batch_size' value: {number_value: 100}},
@@ -243,7 +276,7 @@ def test_experiment_without_experiment_tag(self):
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_without_experiment_tag_different_hparam_types(self):
def test_experiment_with_session_tags_different_hparam_types(self):
self.session_1_start_info_ = """
hparams:[
{key: 'batch_size' value: {number_value: 100}},
@@ -304,7 +337,7 @@ def test_experiment_without_experiment_tag_different_hparam_types(self):
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_with_bool_types(self):
def test_experiment_with_session_tags_bool_types(self):
self.session_1_start_info_ = """
hparams:[
{key: 'batch_size' value: {bool_value: true}}
@@ -344,7 +377,9 @@ def test_experiment_with_bool_types(self):
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_with_string_domain_and_invalid_number_values(self):
def test_experiment_with_session_tags_string_domain_and_invalid_number_values(
self,
):
self.session_1_start_info_ = """
hparams:[
{key: 'maybe_invalid' value: {string_value: 'force_to_string_type'}}
@@ -371,8 +406,21 @@ def test_experiment_with_string_domain_and_invalid_number_values(self):
self.assertLen(actual_exp.hparam_infos, 1)
self.assertProtoEquals(expected_hparam_info, actual_exp.hparam_infos[0])

def test_experiment_with_session_tags_include_metrics(self):
self.session_1_start_info_ = """
hparams: [
{key: 'batch_size' value: {number_value: 100}}
]
"""
with self.subTest("False"):
response = self._experiment_from_metadata(include_metrics=False)
self.assertEmpty(response.metric_infos)

with self.subTest("True"):
response = self._experiment_from_metadata(include_metrics=True)
self.assertLen(response.metric_infos, 4)

def test_experiment_without_any_hparams(self):
request_ctx = context.RequestContext()
actual_exp = self._experiment_from_metadata()
self.assertIsInstance(actual_exp, api_pb2.Experiment)
self.assertProtoEquals("", actual_exp)
@@ -789,6 +837,33 @@ def test_experiment_from_data_provider_session_group_without_session_names(
"""
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_from_data_provider_include_metrics(self):
self._mock_tb_context.data_provider.list_tensors.side_effect = None
self._hyperparameters = provider.ListHyperparametersResult(
hyperparameters=[],
session_groups=[
provider.HyperparameterSessionGroup(
root=provider.HyperparameterSessionRun(
experiment_id="exp", run=""
),
sessions=[
provider.HyperparameterSessionRun(
experiment_id="exp", run="session_1"
),
],
hyperparameter_values=[],
),
],
)

with self.subTest("False"):
response = self._experiment_from_metadata(include_metrics=False)
self.assertEmpty(response.metric_infos)

with self.subTest("True"):
response = self._experiment_from_metadata(include_metrics=True)
self.assertLen(response.metric_infos, 4)

def test_experiment_from_data_provider_old_response_type(self):
self._hyperparameters = [
provider.Hyperparameter(
19 changes: 14 additions & 5 deletions tensorboard/plugins/hparams/get_experiment.py
@@ -18,32 +18,41 @@
class Handler:
"""Handles a GetExperiment request."""

def __init__(self, request_context, backend_context, experiment_id):
def __init__(
self, request_context, backend_context, experiment_id, request
):
"""Constructor.
Args:
request_context: A tensorboard.context.RequestContext.
backend_context: A backend_context.Context instance.
experiment_id: A string, as from `plugin_util.experiment_id`.
request: A api_pb2.GetExperimentRequest instance.
"""
self._request_context = request_context
self._backend_context = backend_context
self._experiment_id = experiment_id
self._include_metrics = (
# Metrics are included by default if include_metrics is not
# specified in the request.
not request.HasField("include_metrics")
or request.include_metrics
)

def run(self):
"""Handles the request specified on construction.
Returns:
An Experiment object.
"""
experiment_id = self._experiment_id
return self._backend_context.experiment_from_metadata(
self._request_context,
experiment_id,
self._experiment_id,
self._include_metrics,
self._backend_context.hparams_metadata(
self._request_context, experiment_id
self._request_context, self._experiment_id
),
self._backend_context.hparams_from_data_provider(
self._request_context, experiment_id
self._request_context, self._experiment_id
),
)
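
The default-to-true logic in the handler above relies on proto3 field
presence for `optional` fields; a minimal sketch of the distinction it
depends on (assumes the generated `api_pb2` module; illustrative only):

```python
# Sketch: with proto3 `optional`, an unset include_metrics is distinguishable
# from an explicit False, which is what lets the handler default the unset
# case to "include metrics". Assumes the generated api_pb2 module.
from tensorboard.plugins.hparams import api_pb2

unset = api_pb2.GetExperimentRequest(experiment_name="123")
include_metrics = not unset.HasField("include_metrics") or unset.include_metrics
assert include_metrics  # field absent -> metrics included

explicit = api_pb2.GetExperimentRequest(
    experiment_name="123", include_metrics=False
)
include_metrics = (
    not explicit.HasField("include_metrics") or explicit.include_metrics
)
assert not include_metrics  # explicitly disabled -> metrics omitted
```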
9 changes: 4 additions & 5 deletions tensorboard/plugins/hparams/hparams_plugin.py
@@ -113,15 +113,14 @@ def get_experiment_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
# This backend currently ignores the request parameters, but (for a POST)
# we must advance the input stream to skip them -- otherwise the next HTTP
# request will be parsed incorrectly.
_ = _parse_request_argument(request, api_pb2.GetExperimentRequest)
request_proto = _parse_request_argument(
request, api_pb2.GetExperimentRequest
)
return http_util.Respond(
request,
json_format.MessageToJson(
get_experiment.Handler(
ctx, self._context, experiment_id
ctx, self._context, experiment_id, request_proto
).run(),
including_default_value_fields=True,
),