From b3edbf3b3e6afd70c3cee86b014f06c7d142cbde Mon Sep 17 00:00:00 2001 From: Brian Dubois Date: Thu, 24 Aug 2023 09:45:53 -0400 Subject: [PATCH] Hparams: Support excluding metric information in HTTP requests. (#6556) There are some clients of the Hparams HTTP API that do not require the metric information. This includes the metric_infos usually returned in the /experiments request and the metric_values usually returned in the /session_groups request. Since these can be expensive to calculate, we want the option to not calculate and return them in the response. Add option `include_metrics` to both GetExperimentRequest and ListSessionGroupsRequest. If unspecified we treat `include_metrics` as True, for backward compatibility. Honor the `include_metrics` property in all three major cases: When experiment metadata is defined by Experiment tags, by Session tags, or by the DataProvider. --- tensorboard/plugins/hparams/api.proto | 10 +- .../plugins/hparams/backend_context.py | 44 ++++++-- .../plugins/hparams/backend_context_test.py | 87 ++++++++++++++- tensorboard/plugins/hparams/get_experiment.py | 19 +++- tensorboard/plugins/hparams/hparams_plugin.py | 9 +- .../plugins/hparams/list_session_groups.py | 43 ++++++-- .../hparams/list_session_groups_test.py | 103 +++++++++++++++++- 7 files changed, 274 insertions(+), 41 deletions(-) diff --git a/tensorboard/plugins/hparams/api.proto b/tensorboard/plugins/hparams/api.proto index 4e30431224b..44c0eb59a9f 100644 --- a/tensorboard/plugins/hparams/api.proto +++ b/tensorboard/plugins/hparams/api.proto @@ -254,17 +254,20 @@ enum Status { // Parameters for a GetExperiment API call. // Each experiment is scoped by a unique global id. -// NEXT_TAG: 2 +// NEXT_TAG: 3 message GetExperimentRequest { // REQUIRED string experiment_name = 1; + + // Whether to fetch metrics and include them in the results. Defaults to true. + optional bool include_metrics = 2; } // Parameters for a ListSessionGroups API call. 
// Computes a list of the current session groups allowing for filtering and // sorting by metrics and hyperparameter values. Returns a "slice" of // that list specified by start_index and slice_size. -// NEXT_TAG: 8 +// NEXT_TAG: 9 message ListSessionGroupsRequest { string experiment_name = 6; @@ -314,6 +317,9 @@ message ListSessionGroupsRequest { // sorted and filtered by the parameters above (if start_index > total_size // no session groups are returned). int32 slice_size = 5; + + // Whether to fetch metrics and include them in the results. Defaults to true. + optional bool include_metrics = 8; } // Defines parmeters for a ListSessionGroupsRequest for a specific column. diff --git a/tensorboard/plugins/hparams/backend_context.py b/tensorboard/plugins/hparams/backend_context.py index 6eab89a62e8..105442765bb 100644 --- a/tensorboard/plugins/hparams/backend_context.py +++ b/tensorboard/plugins/hparams/backend_context.py @@ -57,6 +57,7 @@ def experiment_from_metadata( self, ctx, experiment_id, + include_metrics, hparams_run_to_tag_to_content, data_provider_hparams, ): @@ -76,6 +77,8 @@ def experiment_from_metadata( Args: experiment_id: String, from `plugin_util.experiment_id`. + include_metrics: Whether to determine metric_infos and include them + in the result. hparams_run_to_tag_to_content: The output from an hparams_metadata() call. A dict `d` such that `d[run][tag]` is a `bytes` value with the summary metadata content for the keyed time series. @@ -87,19 +90,21 @@ def experiment_from_metadata( The experiment proto. If no data is found for an experiment proto to be built, returns an entirely empty experiment. 
""" - experiment = self._find_experiment_tag(hparams_run_to_tag_to_content) + experiment = self._find_experiment_tag( + hparams_run_to_tag_to_content, include_metrics + ) if experiment: return experiment experiment_from_runs = self._compute_experiment_from_runs( - ctx, experiment_id, hparams_run_to_tag_to_content + ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content ) if experiment_from_runs: return experiment_from_runs experiment_from_data_provider_hparams = ( self._experiment_from_data_provider_hparams( - ctx, experiment_id, data_provider_hparams + ctx, experiment_id, include_metrics, data_provider_hparams ) ) return ( @@ -202,7 +207,9 @@ def session_groups_from_data_provider( ctx, experiment_ids=[experiment_id], filters=filters, sort=sort ) - def _find_experiment_tag(self, hparams_run_to_tag_to_content): + def _find_experiment_tag( + self, hparams_run_to_tag_to_content, include_metrics + ): """Finds the experiment associcated with the metadata.EXPERIMENT_TAG tag. @@ -214,23 +221,34 @@ def _find_experiment_tag(self, hparams_run_to_tag_to_content): for tags in hparams_run_to_tag_to_content.values(): maybe_content = tags.get(metadata.EXPERIMENT_TAG) if maybe_content is not None: - return metadata.parse_experiment_plugin_data(maybe_content) + experiment = metadata.parse_experiment_plugin_data( + maybe_content + ) + if not include_metrics: + # metric_infos haven't technically been "calculated" in this + # case. They have been read directly from the Experiment + # proto. + # Delete them from the result so that they are not returned + # to the client. + experiment.ClearField("metric_infos") + return experiment return None def _compute_experiment_from_runs( - self, ctx, experiment_id, hparams_run_to_tag_to_content + self, ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content ): """Computes a minimal Experiment protocol buffer by scanning the runs. Returns None if there are no hparam infos logged. 
""" hparam_infos = self._compute_hparam_infos(hparams_run_to_tag_to_content) - if hparam_infos: - metric_infos = self._compute_metric_infos_from_runs( + metric_infos = ( + self._compute_metric_infos_from_runs( ctx, experiment_id, hparams_run_to_tag_to_content ) - else: - metric_infos = [] + if hparam_infos and include_metrics + else [] + ) if not hparam_infos and not metric_infos: return None @@ -320,11 +338,15 @@ def _experiment_from_data_provider_hparams( self, ctx, experiment_id, + include_metrics, data_provider_hparams, ): """Returns an experiment protobuffer based on data provider hparams. Args: + experiment_id: String, from `plugin_util.experiment_id`. + include_metrics: Whether to determine metrics_infos and include them + in the result. data_provider_hparams: The ouput from an hparams_from_data_provider() call, corresponding to DataProvider.list_hyperparameters(). A provider.ListHyperparametersResult. @@ -352,6 +374,8 @@ def _experiment_from_data_provider_hparams( self.compute_metric_infos_from_data_provider_session_groups( ctx, experiment_id, session_groups ) + if include_metrics + else [] ) return api_pb2.Experiment( hparam_infos=hparam_infos, metric_infos=metric_infos diff --git a/tensorboard/plugins/hparams/backend_context_test.py b/tensorboard/plugins/hparams/backend_context_test.py index 784ce9b1589..af7ebea467d 100644 --- a/tensorboard/plugins/hparams/backend_context_test.py +++ b/tensorboard/plugins/hparams/backend_context_test.py @@ -152,13 +152,14 @@ def _mock_list_hyperparameters( ): return self._hyperparameters - def _experiment_from_metadata(self): + def _experiment_from_metadata(self, *, include_metrics=True): """Calls the expected operations for generating an Experiment proto.""" ctxt = backend_context.Context(self._mock_tb_context) request_ctx = context.RequestContext() return ctxt.experiment_from_metadata( request_ctx, "123", + include_metrics, ctxt.hparams_metadata(request_ctx, "123"), ctxt.hparams_from_data_provider(request_ctx, "123"), ) 
@@ -187,7 +188,39 @@ def test_experiment_with_experiment_tag(self): } self.assertProtoEquals(experiment, self._experiment_from_metadata()) - def test_experiment_without_experiment_tag(self): + def test_experiment_with_experiment_tag_include_metrics(self): + experiment = """ + description: 'Test experiment' + metric_infos: [ + { name: { tag: 'current_temp' } }, + { name: { tag: 'delta_temp' } } + ] + """ + run = "exp" + tag = metadata.EXPERIMENT_TAG + t = provider.TensorTimeSeries( + max_step=0, + max_wall_time=0, + plugin_content=self._serialized_plugin_data( + DATA_TYPE_EXPERIMENT, experiment + ), + description="", + display_name="", + ) + self._mock_tb_context.data_provider.list_tensors.side_effect = None + self._mock_tb_context.data_provider.list_tensors.return_value = { + run: {tag: t} + } + + with self.subTest("False"): + response = self._experiment_from_metadata(include_metrics=False) + self.assertEmpty(response.metric_infos) + + with self.subTest("True"): + response = self._experiment_from_metadata(include_metrics=True) + self.assertLen(response.metric_infos, 2) + + def test_experiment_with_session_tags(self): self.session_1_start_info_ = """ hparams: [ {key: 'batch_size' value: {number_value: 100}}, @@ -243,7 +276,7 @@ def test_experiment_without_experiment_tag(self): _canonicalize_experiment(actual_exp) self.assertProtoEquals(expected_exp, actual_exp) - def test_experiment_without_experiment_tag_different_hparam_types(self): + def test_experiment_with_session_tags_different_hparam_types(self): self.session_1_start_info_ = """ hparams:[ {key: 'batch_size' value: {number_value: 100}}, @@ -304,7 +337,7 @@ def test_experiment_without_experiment_tag_different_hparam_types(self): _canonicalize_experiment(actual_exp) self.assertProtoEquals(expected_exp, actual_exp) - def test_experiment_with_bool_types(self): + def test_experiment_with_session_tags_bool_types(self): self.session_1_start_info_ = """ hparams:[ {key: 'batch_size' value: {bool_value: true}} @@ -344,7 
+377,9 @@ def test_experiment_with_bool_types(self): _canonicalize_experiment(actual_exp) self.assertProtoEquals(expected_exp, actual_exp) - def test_experiment_with_string_domain_and_invalid_number_values(self): + def test_experiment_with_session_tags_string_domain_and_invalid_number_values( + self, + ): self.session_1_start_info_ = """ hparams:[ {key: 'maybe_invalid' value: {string_value: 'force_to_string_type'}} @@ -371,8 +406,21 @@ def test_experiment_with_string_domain_and_invalid_number_values(self): self.assertLen(actual_exp.hparam_infos, 1) self.assertProtoEquals(expected_hparam_info, actual_exp.hparam_infos[0]) + def test_experiment_with_session_tags_include_metrics(self): + self.session_1_start_info_ = """ + hparams: [ + {key: 'batch_size' value: {number_value: 100}} + ] + """ + with self.subTest("False"): + response = self._experiment_from_metadata(include_metrics=False) + self.assertEmpty(response.metric_infos) + + with self.subTest("True"): + response = self._experiment_from_metadata(include_metrics=True) + self.assertLen(response.metric_infos, 4) + def test_experiment_without_any_hparams(self): - request_ctx = context.RequestContext() actual_exp = self._experiment_from_metadata() self.assertIsInstance(actual_exp, api_pb2.Experiment) self.assertProtoEquals("", actual_exp) @@ -789,6 +837,33 @@ def test_experiment_from_data_provider_session_group_without_session_names( """ self.assertProtoEquals(expected_exp, actual_exp) + def test_experiment_from_data_provider_include_metrics(self): + self._mock_tb_context.data_provider.list_tensors.side_effect = None + self._hyperparameters = provider.ListHyperparametersResult( + hyperparameters=[], + session_groups=[ + provider.HyperparameterSessionGroup( + root=provider.HyperparameterSessionRun( + experiment_id="exp", run="" + ), + sessions=[ + provider.HyperparameterSessionRun( + experiment_id="exp", run="session_1" + ), + ], + hyperparameter_values=[], + ), + ], + ) + + with self.subTest("False"): + response = 
self._experiment_from_metadata(include_metrics=False) + self.assertEmpty(response.metric_infos) + + with self.subTest("True"): + response = self._experiment_from_metadata(include_metrics=True) + self.assertLen(response.metric_infos, 4) + def test_experiment_from_data_provider_old_response_type(self): self._hyperparameters = [ provider.Hyperparameter( diff --git a/tensorboard/plugins/hparams/get_experiment.py b/tensorboard/plugins/hparams/get_experiment.py index 9499e61301b..5fa3987b110 100644 --- a/tensorboard/plugins/hparams/get_experiment.py +++ b/tensorboard/plugins/hparams/get_experiment.py @@ -18,17 +18,26 @@ class Handler: """Handles a GetExperiment request.""" - def __init__(self, request_context, backend_context, experiment_id): + def __init__( + self, request_context, backend_context, experiment_id, request + ): """Constructor. Args: request_context: A tensorboard.context.RequestContext. backend_context: A backend_context.Context instance. experiment_id: A string, as from `plugin_util.experiment_id`. + request: An api_pb2.GetExperimentRequest instance. """ self._request_context = request_context self._backend_context = backend_context self._experiment_id = experiment_id + self._include_metrics = ( + # Metrics are included by default if include_metrics is not + # specified in the request. + not request.HasField("include_metrics") + or request.include_metrics + ) def run(self): """Handles the request specified on construction. @@ -36,14 +45,14 @@ def run(self): Returns: An Experiment object. 
""" - experiment_id = self._experiment_id return self._backend_context.experiment_from_metadata( self._request_context, - experiment_id, + self._experiment_id, + self._include_metrics, self._backend_context.hparams_metadata( - self._request_context, experiment_id + self._request_context, self._experiment_id ), self._backend_context.hparams_from_data_provider( - self._request_context, experiment_id + self._request_context, self._experiment_id ), ) diff --git a/tensorboard/plugins/hparams/hparams_plugin.py b/tensorboard/plugins/hparams/hparams_plugin.py index a9f680e93d2..ffb111594f3 100644 --- a/tensorboard/plugins/hparams/hparams_plugin.py +++ b/tensorboard/plugins/hparams/hparams_plugin.py @@ -113,15 +113,14 @@ def get_experiment_route(self, request): ctx = plugin_util.context(request.environ) experiment_id = plugin_util.experiment_id(request.environ) try: - # This backend currently ignores the request parameters, but (for a POST) - # we must advance the input stream to skip them -- otherwise the next HTTP - # request will be parsed incorrectly. - _ = _parse_request_argument(request, api_pb2.GetExperimentRequest) + request_proto = _parse_request_argument( + request, api_pb2.GetExperimentRequest + ) return http_util.Respond( request, json_format.MessageToJson( get_experiment.Handler( - ctx, self._context, experiment_id + ctx, self._context, experiment_id, request_proto ).run(), including_default_value_fields=True, ), diff --git a/tensorboard/plugins/hparams/list_session_groups.py b/tensorboard/plugins/hparams/list_session_groups.py index c825f022159..908abfa5b8d 100644 --- a/tensorboard/plugins/hparams/list_session_groups.py +++ b/tensorboard/plugins/hparams/list_session_groups.py @@ -51,6 +51,12 @@ def __init__( self._backend_context = backend_context self._experiment_id = experiment_id self._request = request + self._include_metrics = ( + # Metrics are included by default if include_metrics is not + # specified in the request. 
+ not self._request.HasField("include_metrics") + or self._request.include_metrics + ) def run(self): """Handles the request specified on construction. @@ -92,6 +98,7 @@ def _session_groups_from_tags(self): experiment = self._backend_context.experiment_from_metadata( self._request_context, self._experiment_id, + self._include_metrics, hparams_run_to_tag_to_content, # Don't pass any information from the DataProvider since we are only # examining session groups based on tag metadata @@ -120,14 +127,22 @@ def _session_groups_from_data_provider(self): sort, ) - metric_infos = self._backend_context.compute_metric_infos_from_data_provider_session_groups( - self._request_context, self._experiment_id, response + metric_infos = ( + self._backend_context.compute_metric_infos_from_data_provider_session_groups( + self._request_context, self._experiment_id, response + ) + if self._include_metrics + else [] ) - all_metric_evals = self._backend_context.read_last_scalars( - self._request_context, - self._experiment_id, - run_tag_filter=None, + all_metric_evals = ( + self._backend_context.read_last_scalars( + self._request_context, + self._experiment_id, + run_tag_filter=None, + ) + if self._include_metrics + else {} ) session_groups = [] @@ -228,12 +243,16 @@ def _build_session_groups( ) metric_runs.add(run) metric_tags.add(tag) - all_metric_evals = self._backend_context.read_last_scalars( - self._request_context, - self._experiment_id, - run_tag_filter=provider.RunTagFilter( - runs=metric_runs, tags=metric_tags - ), + all_metric_evals = ( + self._backend_context.read_last_scalars( + self._request_context, + self._experiment_id, + run_tag_filter=provider.RunTagFilter( + runs=metric_runs, tags=metric_tags + ), + ) + if self._include_metrics + else {} ) for ( session_name, diff --git a/tensorboard/plugins/hparams/list_session_groups_test.py b/tensorboard/plugins/hparams/list_session_groups_test.py index 6437ad35255..8b41c5a431d 100644 --- 
a/tensorboard/plugins/hparams/list_session_groups_test.py +++ b/tensorboard/plugins/hparams/list_session_groups_test.py @@ -1329,7 +1329,7 @@ def test_filter_hparams_include_invalid_number_values(self): expected_total_size=3, ) - def test_filer_hparams_exclude_invalid_number_values(self): + def test_filter_hparams_exclude_invalid_number_values(self): self._mock_tb_context.data_provider.list_tensors.side_effect = ( self._mock_list_tensors_invalid_number_values ) @@ -1347,6 +1347,51 @@ def test_filer_hparams_exclude_invalid_number_values(self): expected_total_size=2, ) + def test_include_metrics(self): + with self.subTest("False"): + request = """ + start_index: 0 + slice_size: 1 + allowed_statuses: [ + STATUS_SUCCESS + ] + include_metrics: False + """ + response = self._run_handler(request) + self.assertEmpty(response.session_groups[0].metric_values) + self.assertEmpty( + response.session_groups[0].sessions[0].metric_values + ) + + with self.subTest("True"): + request = """ + start_index: 0 + slice_size: 1 + allowed_statuses: [ + STATUS_SUCCESS + ] + include_metrics: True + """ + response = self._run_handler(request) + self.assertLen(response.session_groups[0].metric_values, 3) + self.assertLen( + response.session_groups[0].sessions[0].metric_values, 3 + ) + + with self.subTest("unspecified"): + request = """ + start_index: 0 + slice_size: 1 + allowed_statuses: [ + STATUS_SUCCESS + ] + """ + response = self._run_handler(request) + self.assertLen(response.session_groups[0].metric_values, 3) + self.assertLen( + response.session_groups[0].sessions[0].metric_values, 3 + ) + def test_experiment_without_any_hparams(self): self._mock_tb_context.data_provider.list_tensors.side_effect = None self._hyperparameters = [] @@ -2304,6 +2349,62 @@ def test_experiment_from_data_provider_does_not_filter_by_hparam_values( self.assertLen(response.session_groups, 1) self.assertEqual("session_1", response.session_groups[0].name) + def test_experiment_from_data_provider_include_metrics( + 
self, + ): + self._mock_tb_context.data_provider.list_tensors.side_effect = None + self._hyperparameters = [ + provider.HyperparameterSessionGroup( + # The session names correspond to return values from + # _mock_list_scalars() and _mock_read_scalars() in order to + # generate metric infos and values. + root=provider.HyperparameterSessionRun( + experiment_id="session_2", run="" + ), + sessions=[ + provider.HyperparameterSessionRun( + experiment_id="session_2", run="" + ) + ], + hyperparameter_values=[], + ), + ] + + with self.subTest("False"): + request = """ + start_index: 0 + slice_size: 10 + include_metrics: False + """ + response = self._run_handler(request) + self.assertEmpty(response.session_groups[0].metric_values, 0) + self.assertEmpty( + response.session_groups[0].sessions[0].metric_values, 0 + ) + + with self.subTest("True"): + request = """ + start_index: 0 + slice_size: 10 + include_metrics: True + """ + response = self._run_handler(request) + self.assertLen(response.session_groups[0].metric_values, 2) + self.assertLen( + response.session_groups[0].sessions[0].metric_values, 2 + ) + + with self.subTest("unspecified"): + request = """ + start_index: 0 + slice_size: 10 + """ + response = self._run_handler(request) + self.assertLen(response.session_groups[0].metric_values, 2) + self.assertLen( + response.session_groups[0].sessions[0].metric_values, 2 + ) + def _run_handler(self, request): request_proto = api_pb2.ListSessionGroupsRequest() text_format.Merge(request, request_proto)