diff --git a/tensorboard/plugins/hparams/backend_context.py b/tensorboard/plugins/hparams/backend_context.py index 806e9b236d6..98959498d5e 100644 --- a/tensorboard/plugins/hparams/backend_context.py +++ b/tensorboard/plugins/hparams/backend_context.py @@ -403,7 +403,7 @@ def compute_metric_infos_from_data_provider_session_groups( self, ctx, experiment_id, session_groups ): session_runs = set( - f"{s.experiment_id}/{s.run}" + f"{s.experiment_id}/{s.run}" if s.run else s.experiment_id for sg in session_groups for s in sg.sessions ) diff --git a/tensorboard/plugins/hparams/backend_context_test.py b/tensorboard/plugins/hparams/backend_context_test.py index 0273eafe642..e11bfed7a6e 100644 --- a/tensorboard/plugins/hparams/backend_context_test.py +++ b/tensorboard/plugins/hparams/backend_context_test.py @@ -561,7 +561,7 @@ def test_experiment_from_data_provider_discrete_string_hparam(self): """ self.assertProtoEquals(expected_exp, actual_exp) - def test_experiment_from_data_provider_session_group(self): + def test_experiment_from_data_provider_session_groups(self): self._mock_tb_context.data_provider.list_tensors.side_effect = None # The sessions chosen here mimic those returned in the implementation # of _mock_list_tensors. These work nicely with the scalars returned @@ -614,6 +614,44 @@ def test_experiment_from_data_provider_session_group(self): """ self.assertProtoEquals(expected_exp, actual_exp) + def test_experiment_from_data_provider_session_group_without_run_name(self): + self._mock_tb_context.data_provider.list_tensors.side_effect = None + self._hyperparameters = provider.ListHyperparametersResult( + hyperparameters=[], + session_groups=[ + provider.HyperparameterSessionGroup( + root=provider.HyperparameterSessionRun( + experiment_id="exp/session_1", run="" + ), + # The entire path to the run is encoded in the experiment_id + # to allow us to test empty run name while still generating + # metric_infos. + sessions=[ + provider.HyperparameterSessionRun( + experiment_id="exp/session_1", run="" + ), + ], + hyperparameter_values=[], + ), + ], + ) + actual_exp = self._experiment_from_metadata() + expected_exp = """ + metric_infos: { + name: {group: '', tag: 'accuracy'} + } + metric_infos: { + name: {group: '', tag: 'loss'} + } + metric_infos: { + name: {group: 'eval', tag: 'loss'} + } + metric_infos: { + name: {group: 'train', tag: 'loss'} + } + """ + self.assertProtoEquals(expected_exp, actual_exp) + def test_experiment_from_data_provider_old_response_type(self): self._hyperparameters = [ provider.Hyperparameter(