Skip to content

Commit

Permalink
Hparams: Support excluding metric information in HTTP requests. (#6556)
Browse files Browse the repository at this point in the history
There are some clients of the Hparams HTTP API that do not require the
metric information. This includes the metric_infos usually returned in
the /experiments request and the metric_values usually returned in the
/session_groups request. Since these can be expensive to calculate, we
want the option to not calculate and return them in the response.

Add option `include_metrics` to both GetExperimentRequest and
ListSessionGroupsRequest. If unspecified we treat `include_metrics` as
True, for backward compatibility. Honor the `include_metrics` property
in all three major cases: When experiment metadata is defined by
Experiment tags, by Session tags, or by the DataProvider.
  • Loading branch information
bmd3k authored Aug 24, 2023
1 parent 91a637e commit 2a91acc
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 41 deletions.
10 changes: 8 additions & 2 deletions tensorboard/plugins/hparams/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -254,17 +254,20 @@ enum Status {

// Parameters for a GetExperiment API call.
// Each experiment is scoped by a unique global id.
// NEXT_TAG: 2
// NEXT_TAG: 3
message GetExperimentRequest {
// REQUIRED
string experiment_name = 1;

// Whether to fetch metrics and include them in the results. Defaults to true.
optional bool include_metrics = 2;
}

// Parameters for a ListSessionGroups API call.
// Computes a list of the current session groups allowing for filtering and
// sorting by metrics and hyperparameter values. Returns a "slice" of
// that list specified by start_index and slice_size.
// NEXT_TAG: 8
// NEXT_TAG: 9
message ListSessionGroupsRequest {
string experiment_name = 6;

Expand Down Expand Up @@ -314,6 +317,9 @@ message ListSessionGroupsRequest {
// sorted and filtered by the parameters above (if start_index > total_size
// no session groups are returned).
int32 slice_size = 5;

// Whether to fetch metrics and include them in the results. Defaults to true.
optional bool include_metrics = 8;
}

// Defines parmeters for a ListSessionGroupsRequest for a specific column.
Expand Down
44 changes: 34 additions & 10 deletions tensorboard/plugins/hparams/backend_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def experiment_from_metadata(
self,
ctx,
experiment_id,
include_metrics,
hparams_run_to_tag_to_content,
data_provider_hparams,
):
Expand All @@ -76,6 +77,8 @@ def experiment_from_metadata(
Args:
experiment_id: String, from `plugin_util.experiment_id`.
include_metrics: Whether to determine metrics_infos and include them
in the result.
hparams_run_to_tag_to_content: The output from an hparams_metadata()
call. A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
Expand All @@ -87,19 +90,21 @@ def experiment_from_metadata(
The experiment proto. If no data is found for an experiment proto to
be built, returns an entirely empty experiment.
"""
experiment = self._find_experiment_tag(hparams_run_to_tag_to_content)
experiment = self._find_experiment_tag(
hparams_run_to_tag_to_content, include_metrics
)
if experiment:
return experiment

experiment_from_runs = self._compute_experiment_from_runs(
ctx, experiment_id, hparams_run_to_tag_to_content
ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
)
if experiment_from_runs:
return experiment_from_runs

experiment_from_data_provider_hparams = (
self._experiment_from_data_provider_hparams(
ctx, experiment_id, data_provider_hparams
ctx, experiment_id, include_metrics, data_provider_hparams
)
)
return (
Expand Down Expand Up @@ -202,7 +207,9 @@ def session_groups_from_data_provider(
ctx, experiment_ids=[experiment_id], filters=filters, sort=sort
)

def _find_experiment_tag(self, hparams_run_to_tag_to_content):
def _find_experiment_tag(
self, hparams_run_to_tag_to_content, include_metrics
):
"""Finds the experiment associcated with the metadata.EXPERIMENT_TAG
tag.
Expand All @@ -214,23 +221,34 @@ def _find_experiment_tag(self, hparams_run_to_tag_to_content):
for tags in hparams_run_to_tag_to_content.values():
maybe_content = tags.get(metadata.EXPERIMENT_TAG)
if maybe_content is not None:
return metadata.parse_experiment_plugin_data(maybe_content)
experiment = metadata.parse_experiment_plugin_data(
maybe_content
)
if not include_metrics:
# metric_infos haven't technically been "calculated" in this
# case. They have been read directly from the Experiment
# proto.
# Delete them from the result so that they are not returned
# to the client.
experiment.ClearField("metric_infos")
return experiment
return None

def _compute_experiment_from_runs(
self, ctx, experiment_id, hparams_run_to_tag_to_content
self, ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
):
"""Computes a minimal Experiment protocol buffer by scanning the runs.
Returns None if there are no hparam infos logged.
"""
hparam_infos = self._compute_hparam_infos(hparams_run_to_tag_to_content)
if hparam_infos:
metric_infos = self._compute_metric_infos_from_runs(
metric_infos = (
self._compute_metric_infos_from_runs(
ctx, experiment_id, hparams_run_to_tag_to_content
)
else:
metric_infos = []
if hparam_infos and include_metrics
else []
)
if not hparam_infos and not metric_infos:
return None

Expand Down Expand Up @@ -320,11 +338,15 @@ def _experiment_from_data_provider_hparams(
self,
ctx,
experiment_id,
include_metrics,
data_provider_hparams,
):
"""Returns an experiment protobuffer based on data provider hparams.
Args:
experiment_id: String, from `plugin_util.experiment_id`.
include_metrics: Whether to determine metrics_infos and include them
in the result.
data_provider_hparams: The ouput from an hparams_from_data_provider()
call, corresponding to DataProvider.list_hyperparameters().
A provider.ListHyperparametersResult.
Expand Down Expand Up @@ -352,6 +374,8 @@ def _experiment_from_data_provider_hparams(
self.compute_metric_infos_from_data_provider_session_groups(
ctx, experiment_id, session_groups
)
if include_metrics
else []
)
return api_pb2.Experiment(
hparam_infos=hparam_infos, metric_infos=metric_infos
Expand Down
87 changes: 81 additions & 6 deletions tensorboard/plugins/hparams/backend_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,14 @@ def _mock_list_hyperparameters(
):
return self._hyperparameters

def _experiment_from_metadata(self):
def _experiment_from_metadata(self, *, include_metrics=True):
"""Calls the expected operations for generating an Experiment proto."""
ctxt = backend_context.Context(self._mock_tb_context)
request_ctx = context.RequestContext()
return ctxt.experiment_from_metadata(
request_ctx,
"123",
include_metrics,
ctxt.hparams_metadata(request_ctx, "123"),
ctxt.hparams_from_data_provider(request_ctx, "123"),
)
Expand Down Expand Up @@ -187,7 +188,39 @@ def test_experiment_with_experiment_tag(self):
}
self.assertProtoEquals(experiment, self._experiment_from_metadata())

def test_experiment_without_experiment_tag(self):
def test_experiment_with_experiment_tag_include_metrics(self):
experiment = """
description: 'Test experiment'
metric_infos: [
{ name: { tag: 'current_temp' } },
{ name: { tag: 'delta_temp' } }
]
"""
run = "exp"
tag = metadata.EXPERIMENT_TAG
t = provider.TensorTimeSeries(
max_step=0,
max_wall_time=0,
plugin_content=self._serialized_plugin_data(
DATA_TYPE_EXPERIMENT, experiment
),
description="",
display_name="",
)
self._mock_tb_context.data_provider.list_tensors.side_effect = None
self._mock_tb_context.data_provider.list_tensors.return_value = {
run: {tag: t}
}

with self.subTest("False"):
response = self._experiment_from_metadata(include_metrics=False)
self.assertEmpty(response.metric_infos)

with self.subTest("True"):
response = self._experiment_from_metadata(include_metrics=True)
self.assertLen(response.metric_infos, 2)

def test_experiment_with_session_tags(self):
self.session_1_start_info_ = """
hparams: [
{key: 'batch_size' value: {number_value: 100}},
Expand Down Expand Up @@ -243,7 +276,7 @@ def test_experiment_without_experiment_tag(self):
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_without_experiment_tag_different_hparam_types(self):
def test_experiment_with_session_tags_different_hparam_types(self):
self.session_1_start_info_ = """
hparams:[
{key: 'batch_size' value: {number_value: 100}},
Expand Down Expand Up @@ -304,7 +337,7 @@ def test_experiment_without_experiment_tag_different_hparam_types(self):
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_with_bool_types(self):
def test_experiment_with_session_tags_bool_types(self):
self.session_1_start_info_ = """
hparams:[
{key: 'batch_size' value: {bool_value: true}}
Expand Down Expand Up @@ -344,7 +377,9 @@ def test_experiment_with_bool_types(self):
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_with_string_domain_and_invalid_number_values(self):
def test_experiment_with_session_tags_string_domain_and_invalid_number_values(
self,
):
self.session_1_start_info_ = """
hparams:[
{key: 'maybe_invalid' value: {string_value: 'force_to_string_type'}}
Expand All @@ -371,8 +406,21 @@ def test_experiment_with_string_domain_and_invalid_number_values(self):
self.assertLen(actual_exp.hparam_infos, 1)
self.assertProtoEquals(expected_hparam_info, actual_exp.hparam_infos[0])

def test_experiment_with_session_tags_include_metrics(self):
self.session_1_start_info_ = """
hparams: [
{key: 'batch_size' value: {number_value: 100}}
]
"""
with self.subTest("False"):
response = self._experiment_from_metadata(include_metrics=False)
self.assertEmpty(response.metric_infos)

with self.subTest("True"):
response = self._experiment_from_metadata(include_metrics=True)
self.assertLen(response.metric_infos, 4)

def test_experiment_without_any_hparams(self):
request_ctx = context.RequestContext()
actual_exp = self._experiment_from_metadata()
self.assertIsInstance(actual_exp, api_pb2.Experiment)
self.assertProtoEquals("", actual_exp)
Expand Down Expand Up @@ -789,6 +837,33 @@ def test_experiment_from_data_provider_session_group_without_session_names(
"""
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_from_data_provider_include_metrics(self):
self._mock_tb_context.data_provider.list_tensors.side_effect = None
self._hyperparameters = provider.ListHyperparametersResult(
hyperparameters=[],
session_groups=[
provider.HyperparameterSessionGroup(
root=provider.HyperparameterSessionRun(
experiment_id="exp", run=""
),
sessions=[
provider.HyperparameterSessionRun(
experiment_id="exp", run="session_1"
),
],
hyperparameter_values=[],
),
],
)

with self.subTest("False"):
response = self._experiment_from_metadata(include_metrics=False)
self.assertEmpty(response.metric_infos)

with self.subTest("True"):
response = self._experiment_from_metadata(include_metrics=True)
self.assertLen(response.metric_infos, 4)

def test_experiment_from_data_provider_old_response_type(self):
self._hyperparameters = [
provider.Hyperparameter(
Expand Down
19 changes: 14 additions & 5 deletions tensorboard/plugins/hparams/get_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,41 @@
class Handler:
"""Handles a GetExperiment request."""

def __init__(self, request_context, backend_context, experiment_id):
def __init__(
self, request_context, backend_context, experiment_id, request
):
"""Constructor.
Args:
request_context: A tensorboard.context.RequestContext.
backend_context: A backend_context.Context instance.
experiment_id: A string, as from `plugin_util.experiment_id`.
request: A api_pb2.GetExperimentRequest instance.
"""
self._request_context = request_context
self._backend_context = backend_context
self._experiment_id = experiment_id
self._include_metrics = (
# Metrics are included by default if include_metrics is not
# specified in the request.
not request.HasField("include_metrics")
or request.include_metrics
)

def run(self):
"""Handles the request specified on construction.
Returns:
An Experiment object.
"""
experiment_id = self._experiment_id
return self._backend_context.experiment_from_metadata(
self._request_context,
experiment_id,
self._experiment_id,
self._include_metrics,
self._backend_context.hparams_metadata(
self._request_context, experiment_id
self._request_context, self._experiment_id
),
self._backend_context.hparams_from_data_provider(
self._request_context, experiment_id
self._request_context, self._experiment_id
),
)
9 changes: 4 additions & 5 deletions tensorboard/plugins/hparams/hparams_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,14 @@ def get_experiment_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
# This backend currently ignores the request parameters, but (for a POST)
# we must advance the input stream to skip them -- otherwise the next HTTP
# request will be parsed incorrectly.
_ = _parse_request_argument(request, api_pb2.GetExperimentRequest)
request_proto = _parse_request_argument(
request, api_pb2.GetExperimentRequest
)
return http_util.Respond(
request,
json_format.MessageToJson(
get_experiment.Handler(
ctx, self._context, experiment_id
ctx, self._context, experiment_id, request_proto
).run(),
including_default_value_fields=True,
),
Expand Down
Loading

0 comments on commit 2a91acc

Please sign in to comment.