Skip to content

Commit

Permalink
Filter metric values in /list_session_groups. (#6550)
Browse files Browse the repository at this point in the history
Handle metric-based filters for DataProvider-based hparam requests.

There are three parts:
* Stop sending the metric-based filters to the DataProvider. The
DataProvider does not have the metric data to apply the filtering.
* Generate local filters. The generated local filters are based on
metrics only. Local filters are not generated for hparams so as not to
repeat the work of the DataProvider.
* Apply the local filters after session group metrics have been
retrieved and aggregated.
  • Loading branch information
bmd3k authored Aug 18, 2023
1 parent 7e30969 commit fd3b82e
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 1 deletion.
21 changes: 20 additions & 1 deletion tensorboard/plugins/hparams/list_session_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,15 @@ def _session_groups_from_data_provider(self):
if group.sessions:
self._aggregate_metrics(group)

extractors = _create_extractors(self._request.col_params)
filters = _create_filters(
self._request.col_params,
extractors,
# We assume the DataProvider will apply hparam filters and we do not
# attempt to reapply them.
include_hparam_filters=False,
)
session_groups = self._filter(session_groups, filters)
return session_groups

def _build_session_groups(
Expand Down Expand Up @@ -552,20 +561,25 @@ def extractor_fn(session_group):
# True if it should be included in the result. Currently, Filters are functions
# of a single column value extracted from the session group with a given
# extractor specified in the construction of the filter.
def _create_filters(col_params, extractors):
def _create_filters(col_params, extractors, *, include_hparam_filters=True):
"""Creates filters for the given col_params.
Args:
col_params: List of ListSessionGroupsRequest.ColParam protobufs.
extractors: list of extractor functions of the same length as col_params.
Each element should extract the column described by the corresponding
element of col_params.
include_hparam_filters: bool that indicates whether hparam filters should
be generated. Defaults to True.
Returns:
A list of filter functions. Each corresponding to a single
col_params.filter oneof field of _request
"""
result = []
for col_param, extractor in zip(col_params, extractors):
if not include_hparam_filters and col_param.hparam:
continue

a_filter = _create_filter(col_param, extractor)
if a_filter:
result.append(a_filter)
Expand Down Expand Up @@ -860,6 +874,11 @@ def _build_data_provider_filters(col_params):
"""Builds HyperparameterFilters from ColParams."""
filters = []
for col_param in col_params:
if not col_param.hparam:
# We do not pass metric filters to the DataProvider as it does not
# have the metric data for filtering.
continue

fltr = _build_data_provider_filter(col_param)
if fltr is None:
continue
Expand Down
135 changes: 135 additions & 0 deletions tensorboard/plugins/hparams/list_session_groups_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,21 @@ def test_experiment_from_data_provider_sends_discrete_filter(self):
],
)

def test_experiment_from_data_provider_does_not_send_metric_filters(self):
self._mock_tb_context.data_provider.list_tensors.side_effect = None
request = """
col_params: {
metric: { tag: 'delta_temp' }
filter_interval: {
min_value: 0
max_value: 100
}
}
"""
self._run_handler(request)

self.assertEmpty(self._get_read_hyperparameters_call_filters())

def test_experiment_from_data_provider_sends_sort(self):
self._mock_tb_context.data_provider.list_tensors.side_effect = None
request = """
Expand Down Expand Up @@ -2169,6 +2184,126 @@ def test_experiment_from_data_provider_with_metric_values_aggregates(
response.session_groups[0].metric_values[2],
)

def test_experiment_from_data_provider_filters_by_metric_values(
self,
):
# Filters are tested in-depth elsewhere using the Tensor-based hparams.
# For DataProvider-based hparam tests we just test one filter to verify
# the filter logic is being applied.
self._mock_tb_context.data_provider.list_tensors.side_effect = None
self._hyperparameters = [
# The sessions names correspond to return values from
# _mock_list_scalars() and _mock_read_scalars() in order to
# generate metric infos and values.
provider.HyperparameterSessionGroup(
root=provider.HyperparameterSessionRun(
experiment_id="session_1", run=""
),
sessions=[
provider.HyperparameterSessionRun(
experiment_id="session_1", run=""
)
],
hyperparameter_values=[],
),
provider.HyperparameterSessionGroup(
root=provider.HyperparameterSessionRun(
experiment_id="session_2", run=""
),
sessions=[
provider.HyperparameterSessionRun(
experiment_id="session_2", run=""
)
],
hyperparameter_values=[],
),
provider.HyperparameterSessionGroup(
root=provider.HyperparameterSessionRun(
experiment_id="session_3", run=""
),
sessions=[
provider.HyperparameterSessionRun(
experiment_id="session_3", run=""
)
],
hyperparameter_values=[],
),
]
request = """
start_index: 0
slice_size: 10
"""
response = self._run_handler(request)
self.assertLen(response.session_groups, 3)
self.assertEqual("session_1", response.session_groups[0].name)
self.assertEqual("session_2", response.session_groups[1].name)
self.assertEqual("session_3", response.session_groups[2].name)

filtered_request = """
start_index: 0
slice_size: 10
col_params: {
metric: { tag: 'delta_temp' }
filter_interval: {
min_value: 0
max_value: 100
}
}
"""
filtered_response = self._run_handler(filtered_request)
# The delta_temp values for session_1, session_2, and session_3 are
# 10, 150, and 1.5, respectively. We expect session_2 to have been
# filtered out.
self.assertLen(filtered_response.session_groups, 2)
self.assertEqual("session_1", filtered_response.session_groups[0].name)
self.assertEqual("session_3", filtered_response.session_groups[1].name)

def test_experiment_from_data_provider_does_not_filter_by_hparam_values(
self,
):
# We assume the DataProvider will apply hparam filters and we do not
# attempt to reapply them.
self._mock_tb_context.data_provider.list_tensors.side_effect = None
self._hyperparameters = [
provider.HyperparameterSessionGroup(
root=provider.HyperparameterSessionRun(
experiment_id="session_1", run=""
),
sessions=[
provider.HyperparameterSessionRun(
experiment_id="session_1", run=""
)
],
hyperparameter_values=[
provider.HyperparameterValue(
hyperparameter_name="hparam1",
domain_type=provider.HyperparameterDomainType.INTERVAL,
value=-1.0,
),
],
),
]
request = """
start_index: 0
slice_size: 10
col_params: {
hparam: 'hparam1'
filter_interval: {
min_value: 0
max_value: 100
}
}
"""
response = self._run_handler(request)
# The one result from the DataProvider call is returned even though
# there is an hparam filter that it should not pass. This indicates we
# are purposefully not applying the hparam filters.
#
# Note: The scenario should not happen in practice as we'd expect
# the DataProvider to have successfully applied the filter.
self.assertLen(response.session_groups, 1)
self.assertEqual("session_1", response.session_groups[0].name)

def _run_handler(self, request):
request_proto = api_pb2.ListSessionGroupsRequest()
text_format.Merge(request, request_proto)
Expand Down

0 comments on commit fd3b82e

Please sign in to comment.