From d5fd18815ae51d3688ada347787977a5eb513e88 Mon Sep 17 00:00:00 2001 From: Riley Jones Date: Mon, 21 Aug 2023 22:44:57 +0000 Subject: [PATCH] add new hparams selectors to get default filters --- tensorboard/webapp/hparams/_redux/BUILD | 2 + .../webapp/hparams/_redux/hparams_actions.ts | 11 +++++ .../webapp/hparams/_redux/hparams_reducers.ts | 28 ++++++++++++ .../hparams/_redux/hparams_selectors.ts | 44 ++++++++++++++++++- tensorboard/webapp/hparams/_redux/types.ts | 18 +++++--- tensorboard/webapp/hparams/_redux/utils.ts | 34 ++++++++++++++ .../views/main_view/common_selectors.ts | 43 +++++++++--------- .../views/runs_table/runs_table_container.ts | 27 ++++-------- 8 files changed, 158 insertions(+), 49 deletions(-) diff --git a/tensorboard/webapp/hparams/_redux/BUILD b/tensorboard/webapp/hparams/_redux/BUILD index 6439918303f..ede27e0e464 100644 --- a/tensorboard/webapp/hparams/_redux/BUILD +++ b/tensorboard/webapp/hparams/_redux/BUILD @@ -10,6 +10,7 @@ tf_ts_library( "utils.ts", ], deps = [ + ":types", "//tensorboard/webapp/hparams:types", ], ) @@ -47,6 +48,7 @@ tf_ts_library( "hparams_actions.ts", ], deps = [ + ":types", "//tensorboard/webapp/hparams:types", "@npm//@ngrx/store", ], diff --git a/tensorboard/webapp/hparams/_redux/hparams_actions.ts b/tensorboard/webapp/hparams/_redux/hparams_actions.ts index 3dc2a315a0b..6780c711466 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_actions.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_actions.ts @@ -22,6 +22,7 @@ import { HparamAndMetricSpec, SessionGroup, } from '../types'; +import {HparamFilter, MetricFilter} from './types'; export const hparamsDiscreteHparamFilterChanged = createAction( '[Hparams] Hparams Discrete Hparam Filter Changed', @@ -62,3 +63,13 @@ export const hparamsFetchSessionGroupsSucceeded = createAction( sessionGroups: SessionGroup[]; }>() ); + +export const dashboardHparamFilterAdded = createAction( + '[Hparams] Dashboard Hparam Filter Added', + props<{name: string; filter: HparamFilter}>() +); + +export const dashboardMetricFilterAdded = createAction( + '[Hparams] Dashboard Metric Filter Added', + props<{name: string; filter: MetricFilter}>() +); diff --git a/tensorboard/webapp/hparams/_redux/hparams_reducers.ts b/tensorboard/webapp/hparams/_redux/hparams_reducers.ts index 54ff2ece3aa..c772acaaf93 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_reducers.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_reducers.ts @@ -37,6 +37,10 @@ const initialState: HparamsState = { metrics: [], }, dashboardSessionGroups: [], + dashboardFilters: { + hparams: new Map(), + metrics: new Map(), + }, }; const reducer: ActionReducer = createReducer( @@ -369,6 +373,30 @@ const reducer: ActionReducer = createReducer( dashboardSpecs: nextDashboardSpecs, dashboardSessionGroups: nextDashboardSessionGroups, }; + }), + on(actions.dashboardHparamFilterAdded, (state, action) => { + return { + ...state, + dashboardFilters: { + ...state.dashboardFilters, + hparams: { + ...state.dashboardFilters.hparams, + [action.name]: action.filter, + }, + }, + }; + }), + on(actions.dashboardMetricFilterAdded, (state, action) => { + return { + ...state, + dashboardFilters: { + ...state.dashboardFilters, + metrics: { + ...state.dashboardFilters.metrics, + [action.name]: action.filter, + }, + }, + }; }) ); diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts index d7dc87650bd..eeca25271fe 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts @@ -21,11 +21,17 @@ import { RunToHparamsAndMetrics, } from '../types'; import {combineHparamAndMetricSpecs} from './hparams_selectors_utils'; -import {HparamsState, HPARAMS_FEATURE_KEY} from './types'; +import { + HparamsState, + HPARAMS_FEATURE_KEY, + HparamsAndMetricsFilters, +} from './types'; import { combineDefaultHparamFilters, combineDefaultMetricFilters, getIdFromExperimentIds, + hparamSpecToDefaultFilter, + metricSpecToDefaultFilter, } from './utils'; const getHparamsState = @@ -232,3 +238,39 @@ export const getDashboardRunsToHparamsAndMetrics = createSelector( return runToHparamsAndMetrics; } ); + +export const getDashboardDefaultFilters = createSelector( + getDashboardHparamsAndMetricsSpecs, + (specs): HparamsAndMetricsFilters => { + const hparams = new Map( + specs.hparams.map((hparamSpec) => { + return [hparamSpec.name, hparamSpecToDefaultFilter(hparamSpec)]; + }) + ); + + const metrics = new Map( + specs.metrics.map((metricSpec) => { + return [metricSpec.name.tag, metricSpecToDefaultFilter(metricSpec)]; + }) + ); + + return { + hparams, + metrics, + }; + } +); + +export const getDashboardHparamFilterMap = createSelector( + getHparamsState, + (state) => { + return state.dashboardFilters.hparams; + } +); + +export const getDashboardMetricsFilterMap = createSelector( + getHparamsState, + (state) => { + return state.dashboardFilters.metrics; + } +); diff --git a/tensorboard/webapp/hparams/_redux/types.ts b/tensorboard/webapp/hparams/_redux/types.ts index 5ec4fa10a15..a147ea849c7 100644 --- a/tensorboard/webapp/hparams/_redux/types.ts +++ b/tensorboard/webapp/hparams/_redux/types.ts @@ -21,14 +21,22 @@ import { SessionGroup, } from '../_types'; +export type HparamFilter = DiscreteFilter | IntervalFilter; +export type MetricFilter = IntervalFilter; + +export interface HparamsAndMetricsFilters { + hparams: Map; + metrics: Map; +} + export interface HparamsMetricsAndFilters { hparam: { specs: HparamSpec[]; - defaultFilters: Map; + defaultFilters: Map; }; metric: { specs: MetricSpec[]; - defaultFilters: Map; + defaultFilters: Map; }; } @@ -47,6 +55,7 @@ export interface HparamsState { specs: ExperimentToHparams; dashboardSpecs: HparamAndMetricSpec; dashboardSessionGroups: SessionGroup[]; + dashboardFilters: HparamsAndMetricsFilters; /** * RATIONALE: we do not use the NamespaceContextedState because of the following reasons. * - RunsTable which uses the state renders both on the dashboard view and the @@ -61,10 +70,7 @@ export interface HparamsState { * separate for the list and the dashboard views, but them shared is not too bad. */ filters: { - [id: string]: { - hparams: Map; - metrics: Map; - }; + [id: string]: HparamsAndMetricsFilters; }; } diff --git a/tensorboard/webapp/hparams/_redux/utils.ts b/tensorboard/webapp/hparams/_redux/utils.ts index 06b5b59f989..3d1f7edba93 100644 --- a/tensorboard/webapp/hparams/_redux/utils.ts +++ b/tensorboard/webapp/hparams/_redux/utils.ts @@ -18,8 +18,11 @@ import { DiscreteHparamValue, DiscreteHparamValues, DomainType, + HparamSpec, IntervalFilter, + MetricSpec, } from '../types'; +import {HparamFilter, MetricFilter} from './types'; export function getIdFromExperimentIds(experimentIds: string[]): string { return JSON.stringify([...experimentIds].sort()); @@ -147,3 +150,34 @@ export function combineDefaultMetricFilters( return intervalMetrics; } + +export function hparamSpecToDefaultFilter(spec: HparamSpec): HparamFilter { + if (spec.domain.type === DomainType.DISCRETE) { + return { + type: DomainType.DISCRETE, + includeUndefined: true, + possibleValues: spec.domain.values, + filterValues: spec.domain.values, + }; + } + + return { + type: DomainType.INTERVAL, + includeUndefined: true, + minValue: spec.domain.minValue, + maxValue: spec.domain.maxValue, + filterLowerValue: spec.domain.minValue, + filterUpperValue: spec.domain.maxValue, + }; +} + +export function metricSpecToDefaultFilter(spec: MetricSpec): MetricFilter { + return { + type: DomainType.INTERVAL, + includeUndefined: true, + minValue: -Infinity, + maxValue: Infinity, + filterLowerValue: -Infinity, + filterUpperValue: Infinity, + }; +} diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index 36f991ab28d..70599fcd699 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -29,9 +29,10 @@ import { } from '../../../selectors'; import {DeepReadonly} from '../../../util/types'; import { - getHparamFilterMapFromExperimentIds, - getMetricFilterMapFromExperimentIds, getDashboardHparamsAndMetricsSpecs, + getDashboardHparamFilterMap, + getDashboardMetricsFilterMap, + getDashboardDefaultFilters, } from '../../../hparams/_redux/hparams_selectors'; import { DiscreteFilter, @@ -185,6 +186,20 @@ const utils = { }, }; +export const getCurrentColumnFilters = createSelector( + getDashboardDefaultFilters, + getDashboardHparamFilterMap, + getDashboardMetricsFilterMap, + ({hparams, metrics}, hparamFilters, metricFilters) => { + return new Map([ + ...hparams, + ...metrics, + ...hparamFilters, + ...metricFilters, + ]); + } +); + const getRenderableRuns = memoize((experimentIds: string[]) => { return createSelector( getDashboardRuns(experimentIds), @@ -226,8 +241,8 @@ const getFilteredRenderableRuns = memoize((experimentIds: string[]) => { return createSelector( getRunSelectorRegexFilter, getRenderableRuns(experimentIds), - getHparamFilterMapFromExperimentIds(experimentIds), - getMetricFilterMapFromExperimentIds(experimentIds), + getDashboardHparamFilterMap, + getDashboardMetricsFilterMap, getRouteKind, (regexFilter, runItems, hparamFilters, metricFilters, routeKind) => { const regexFilteredItems = utils.filterRunItemsByRegex( @@ -245,28 +260,10 @@ const getFilteredRenderableRuns = memoize((experimentIds: string[]) => { ); }); -const getCurrentColumnFilters = memoize((experimentIds: string[]) => { - return createSelector( - getHparamFilterMapFromExperimentIds(experimentIds, true), - getMetricFilterMapFromExperimentIds(experimentIds, true), - (hparamFilters, metricFilters) => { - return new Map([...hparamFilters, ...metricFilters]); - } - ); -}); - -export const getCurrentColumnFiltersFromRoute = createSelector( - (state) => state, - getExperimentIdsFromRoute, - (state, experimentIds) => { - return getCurrentColumnFilters(experimentIds || [])(state); - } -); - export const getFilteredRenderableRunsFromRoute = createSelector( (state) => state, getExperimentIdsFromRoute, - (state, experimentIds) => { + (state, experimentIds): RunTableItem[] => { return getFilteredRenderableRuns(experimentIds || [])(state); } ); diff --git a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts index 5a0266625b2..bee49f09b79 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts +++ b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts @@ -96,7 +96,7 @@ import { } from './runs_table_component'; import {RunsTableColumn, RunTableItem} from './types'; import { - getCurrentColumnFiltersFromRoute, + getCurrentColumnFilters, getFilteredRenderableRunsFromRoute, getPotentialHparamColumns, } from '../../../metrics/views/main_view/common_selectors'; @@ -375,7 +375,7 @@ export class RunsTableContainer implements OnInit, OnDestroy { }) ); - columnFilters$ = this.store.select(getCurrentColumnFiltersFromRoute); + columnFilters$ = this.store.select(getCurrentColumnFilters); allRunsTableData$ = this.store .select(getFilteredRenderableRunsFromRoute) @@ -835,23 +835,12 @@ export class RunsTableContainer implements OnInit, OnDestroy { } addHparamFilter(event: FilterAddedEvent) { - switch (event.value.type) { - case DomainType.INTERVAL: - this.onHparamIntervalFilterChanged({ - name: event.header.name, - includeUndefined: event.value.includeUndefined, - filterLowerValue: (event.value as IntervalFilter).filterLowerValue, - filterUpperValue: (event.value as IntervalFilter).filterUpperValue, - }); - break; - case DomainType.DISCRETE: - this.onHparamDiscreteFilterChanged({ - hparamName: event.header.name, - includeUndefined: event.value.includeUndefined, - filterValues: (event.value as DiscreteFilter).filterValues, - }); - break; - } + this.store.dispatch( + hparamsActions.dashboardHparamFilterAdded({ + name: event.header.name, + filter: event.value, + }) + ); } }