From 91a637e2979b8d11a8a64e82cf2b3d900b84e43d Mon Sep 17 00:00:00 2001 From: Riley Jones <78179109+rileyajones@users.noreply.github.com> Date: Wed, 23 Aug 2023 14:41:59 -0700 Subject: [PATCH] Hparams: refactor arguments provided to `getDashboardRuns` (#6555) ## Motivation for features / changes As per a discussion around #6544 we determined that the dashboard should not be requesting runs per experiment but should instead be getting all runs for all experiments shown. This greatly simplifies a lot of the logic used to retrieve runs and run related data. --- tensorboard/webapp/experiments/store/BUILD | 5 + .../store/experiments_selectors.ts | 31 ++-- .../store/experiments_selectors_test.ts | 48 +++++-- .../card_renderer/scalar_card_container.ts | 4 +- .../views/card_renderer/scalar_card_test.ts | 40 +++--- .../views/main_view/common_selectors.ts | 134 ++++++++---------- .../views/main_view/common_selectors_test.ts | 98 +++++-------- tensorboard/webapp/runs/store/BUILD | 1 + .../webapp/runs/store/runs_selectors.ts | 49 ++++--- .../webapp/runs/store/runs_selectors_test.ts | 49 +++++-- .../views/runs_table/runs_table_container.ts | 34 +++-- .../runs/views/runs_table/runs_table_test.ts | 8 +- 12 files changed, 269 insertions(+), 232 deletions(-) diff --git a/tensorboard/webapp/experiments/store/BUILD b/tensorboard/webapp/experiments/store/BUILD index 2cef1f9c3a6..f9347dad644 100644 --- a/tensorboard/webapp/experiments/store/BUILD +++ b/tensorboard/webapp/experiments/store/BUILD @@ -26,6 +26,7 @@ tf_ts_library( ], deps = [ ":types", + "//tensorboard/webapp/app_routing/store", "//tensorboard/webapp/experiments:types", "//tensorboard/webapp/types", "@npm//@ngrx/store", @@ -65,6 +66,10 @@ tf_ts_library( ":selectors", ":testing", ":types", + "//tensorboard/webapp:app_state", + "//tensorboard/webapp/app_routing:types", + "//tensorboard/webapp/app_routing/store:testing", + "//tensorboard/webapp/testing:utils", "//tensorboard/webapp/types", "@npm//@types/jasmine", ], diff --git a/tensorboard/webapp/experiments/store/experiments_selectors.ts b/tensorboard/webapp/experiments/store/experiments_selectors.ts index 64af2a336e1..c63d6e99ff5 100644 --- a/tensorboard/webapp/experiments/store/experiments_selectors.ts +++ b/tensorboard/webapp/experiments/store/experiments_selectors.ts @@ -19,6 +19,7 @@ import { ExperimentsState, EXPERIMENTS_FEATURE_KEY, } from './experiments_types'; +import {getExperimentIdsFromRoute} from '../../app_routing/store/app_routing_selectors'; const getExperimentsState = createFeatureSelector( EXPERIMENTS_FEATURE_KEY @@ -46,18 +47,20 @@ export const getExperiment = createSelector( ); /** - * Returns Observable that emits an object mapping the provided - * experiment ids to experiment names. + * Returns the names of all experiments present on the current dashboard. */ -export const getExperimentNames = (experimentIds: string[]) => - createSelector( - getDataState, - (state: ExperimentsDataState): Record => - experimentIds - .map((experimentId) => state.experimentMap[experimentId]) - .filter(Boolean) - .reduce((map, experiment) => { - map[experiment.id] = experiment.name; - return map; - }, {} as Record) - ); +export const getDashboardExperimentNames = createSelector( + getDataState, + getExperimentIdsFromRoute, + ( + state: ExperimentsDataState, + experimentIds: string[] | null + ): Record => + (experimentIds ?? []) + .map((experimentId) => state.experimentMap[experimentId]) + .filter(Boolean) + .reduce((map, experiment) => { + map[experiment.id] = experiment.name; + return map; + }, {} as Record) +); diff --git a/tensorboard/webapp/experiments/store/experiments_selectors_test.ts b/tensorboard/webapp/experiments/store/experiments_selectors_test.ts index cc7d1ce6094..ac0397213b8 100644 --- a/tensorboard/webapp/experiments/store/experiments_selectors_test.ts +++ b/tensorboard/webapp/experiments/store/experiments_selectors_test.ts @@ -13,8 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ import * as selectors from './experiments_selectors'; -import {State} from './experiments_types'; +import {State} from '../../app_state'; import {buildExperiment, buildStateFromExperimentsState} from './testing'; +import {buildMockState} from '../../testing/utils'; +import { + buildAppRoutingState, + buildStateFromAppRoutingState, +} from '../../app_routing/store/testing'; +import {RouteKind} from '../../app_routing/types'; describe('experiments selectors', () => { describe('#getExperiment', () => { @@ -49,32 +55,52 @@ describe('experiments selectors', () => { }); }); - describe('#getExperimentNames', () => { + describe('#getDashboardExperimentNames', () => { let state: State; beforeEach(() => { const foo = buildExperiment({id: 'foo', name: 'foo name'}); const bar = buildExperiment({id: 'bar', name: 'bar name'}); - state = buildStateFromExperimentsState({ - data: { - experimentMap: {foo, bar}, - }, + state = buildMockState({ + ...buildStateFromAppRoutingState( + buildAppRoutingState({ + activeRoute: { + routeKind: RouteKind.COMPARE_EXPERIMENT, + params: { + experimentIds: 'exp1:foo,exp2:bar,exp3:baz', + }, + }, + }) + ), + ...buildStateFromExperimentsState({ + data: { + experimentMap: {foo, bar}, + }, + }), }); }); it('translates experiment ids to experiment names', () => { - expect( - selectors.getExperimentNames(['foo', 'bar', 'baz'])(state) - ).toEqual({ + expect(selectors.getDashboardExperimentNames(state)).toEqual({ foo: 'foo name', bar: 'bar name', }); }); it('returns an empty object when no experiments are provided', () => { - expect(selectors.getExperimentNames([])(state)).toEqual({}); - expect(selectors.getExperimentNames(['abc', '123'])(state)).toEqual({}); + state = { + ...state, + ...buildStateFromAppRoutingState( + buildAppRoutingState({ + activeRoute: { + routeKind: RouteKind.EXPERIMENTS, + params: {}, + }, + }) + ), + }; + expect(selectors.getDashboardExperimentNames(state)).toEqual({}); }); }); }); diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts index c5f91c1395a..ba367888240 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts @@ -100,7 +100,7 @@ import { HeaderToggleInfo, XAxisType, } from '../../types'; -import {getFilteredRenderableRunsIdsFromRoute} from '../main_view/common_selectors'; +import {getFilteredRenderableRunsIds} from '../main_view/common_selectors'; import {CardRenderer} from '../metrics_view_types'; import {getTagDisplayName} from '../utils'; import {DataDownloadDialogContainer} from './data_download_dialog_container'; @@ -499,7 +499,7 @@ export class ScalarCardContainer implements CardRenderer, OnInit, OnDestroy { combineLatestWith( this.store.select(getCurrentRouteRunSelection), this.store.select(getEnableHparamsInTimeSeries), - this.store.select(getFilteredRenderableRunsIdsFromRoute), + this.store.select(getFilteredRenderableRunsIds), this.store.select(getRunColorMap), this.store.select(getMetricsScalarSmoothing) ), diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts index cad08e37f92..d8d1f2d6b72 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts @@ -485,7 +485,7 @@ describe('scalar card', () => { new Map([['run1', true]]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1']) ); store.overrideSelector(selectors.getMetricsXAxisType, XAxisType.STEP); @@ -764,7 +764,7 @@ describe('scalar card', () => { new Map([['run1', true]]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1']) ); }); @@ -2574,7 +2574,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2']) ); store.overrideSelector(getCardStateMap, { @@ -2996,7 +2996,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2']) ); @@ -3066,7 +3066,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2']) ); @@ -3145,7 +3145,7 @@ describe('scalar card', () => { new Map([['run1', true]]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1']) ); @@ -3211,7 +3211,7 @@ describe('scalar card', () => { new Map([['run1', true]]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1']) ); @@ -3278,7 +3278,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2']) ); @@ -3330,7 +3330,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2']) ); @@ -3382,7 +3382,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2']) ); @@ -3432,7 +3432,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2']) ); store.overrideSelector(getMetricsLinkedTimeSelection, { @@ -3481,7 +3481,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2']) ); store.overrideSelector(selectors.getExperimentIdToExperimentAliasMap, { @@ -3539,7 +3539,7 @@ describe('scalar card', () => { new Map([['run1', true]]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1']) ); store.overrideSelector(getMetricsLinkedTimeSelection, { @@ -3586,7 +3586,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2', 'run3']) ); @@ -3635,7 +3635,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2', 'run3']) ); @@ -3692,7 +3692,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2', 'run3', 'run4', 'run5', 'run6', 'run7']) ); @@ -3754,7 +3754,7 @@ describe('scalar card', () => { ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2', 'run3', 'run4', 'run5', 'run6', 'run7']) ); @@ -3840,7 +3840,7 @@ describe('scalar card', () => { ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1']) ); @@ -3952,7 +3952,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2', 'run3', 'run4']) ); @@ -3981,7 +3981,7 @@ describe('scalar card', () => { ]) ); store.overrideSelector( - commonSelectors.getFilteredRenderableRunsIdsFromRoute, + commonSelectors.getFilteredRenderableRunsIds, new Set(['run1', 'run2']) ); const fixture = createComponent('card1'); diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index 2797422d79a..e152b0e988e 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -20,18 +20,18 @@ import { getMetricsTagMetadata, getExperimentIdsFromRoute, getExperimentIdToExperimentAliasMap, - getExperimentNames, getRunColorMap, getRunSelectorRegexFilter, getRouteKind, getDashboardRuns, getColumnHeadersForCard, + getDashboardExperimentNames, } from '../../../selectors'; import {DeepReadonly} from '../../../util/types'; import { - getHparamFilterMapFromExperimentIds, - getMetricFilterMapFromExperimentIds, + getDashboardMetricsFilterMap, getDashboardHparamsAndMetricsSpecs, + getDashboardHparamFilterMap, } from '../../../hparams/_redux/hparams_selectors'; import { DiscreteFilter, @@ -185,76 +185,70 @@ const utils = { }, }; -const getRenderableRuns = memoize((experimentIds: string[]) => { - return createSelector( - getDashboardRuns(experimentIds), - getExperimentNames(experimentIds), - getCurrentRouteRunSelection, - getRunColorMap, - getExperimentIdToExperimentAliasMap, - ( - runs, - experimentNames, - selectionMap, - colorMap, - experimentIdToAlias - ): Array => { - return runs.map((run) => { - const hparamMap: RunTableItem['hparams'] = new Map(); - (run.hparams || []).forEach((hparam) => { - hparamMap.set(hparam.name, hparam.value); - }); - const metricMap: RunTableItem['metrics'] = new Map(); - (run.metrics || []).forEach((metric) => { - metricMap.set(metric.tag, metric.value); - }); - return { - run, - experimentName: experimentNames[run.experimentId] || '', - experimentAlias: experimentIdToAlias[run.experimentId], - selected: Boolean(selectionMap && selectionMap.get(run.id)), - runColor: colorMap[run.id], - hparams: hparamMap, - metrics: metricMap, - }; +const getRenderableRuns = createSelector( + getDashboardRuns, + getDashboardExperimentNames, + getCurrentRouteRunSelection, + getRunColorMap, + getExperimentIdToExperimentAliasMap, + ( + runs, + experimentNames, + selectionMap, + colorMap, + experimentIdToAlias + ): Array => { + return runs.map((run) => { + const hparamMap: RunTableItem['hparams'] = new Map(); + (run.hparams || []).forEach((hparam) => { + hparamMap.set(hparam.name, hparam.value); }); - } - ); -}); - -const getFilteredRenderableRuns = memoize((experimentIds: string[]) => { - return createSelector( - getRunSelectorRegexFilter, - getRenderableRuns(experimentIds), - getHparamFilterMapFromExperimentIds(experimentIds), - getMetricFilterMapFromExperimentIds(experimentIds), - getRouteKind, - (regexFilter, runItems, hparamFilters, metricFilters, routeKind) => { - const regexFilteredItems = utils.filterRunItemsByRegex( - runItems, - regexFilter, - routeKind === RouteKind.COMPARE_EXPERIMENT - ); + const metricMap: RunTableItem['metrics'] = new Map(); + (run.metrics || []).forEach((metric) => { + metricMap.set(metric.tag, metric.value); + }); + return { + run, + experimentName: experimentNames[run.experimentId] || '', + experimentAlias: experimentIdToAlias[run.experimentId], + selected: Boolean(selectionMap && selectionMap.get(run.id)), + runColor: colorMap[run.id], + hparams: hparamMap, + metrics: metricMap, + }; + }); + } +); - return utils.filterRunItemsByHparamAndMetricFilter( - regexFilteredItems, - hparamFilters, - metricFilters - ); - } - ); -}); +export const getFilteredRenderableRuns = createSelector( + getRunSelectorRegexFilter, + getRenderableRuns, + getDashboardHparamFilterMap, + getDashboardMetricsFilterMap, + getRouteKind, + ( + regexFilter, + runItems, + hparamFilters, + metricFilters, + routeKind + ): RunTableItem[] => { + const regexFilteredItems = utils.filterRunItemsByRegex( + runItems, + regexFilter, + routeKind === RouteKind.COMPARE_EXPERIMENT + ); -export const getFilteredRenderableRunsFromRoute = createSelector( - (state) => state, - getExperimentIdsFromRoute, - (state, experimentIds) => { - return getFilteredRenderableRuns(experimentIds || [])(state); + return utils.filterRunItemsByHparamAndMetricFilter( + regexFilteredItems, + hparamFilters, + metricFilters + ); } ); -export const getFilteredRenderableRunsIdsFromRoute = createSelector( - getFilteredRenderableRunsFromRoute, +export const getFilteredRenderableRunsIds = createSelector( + getFilteredRenderableRuns, (filteredRenderableRuns) => { return new Set(filteredRenderableRuns.map(({run: {id}}) => id)); } @@ -292,12 +286,8 @@ export const getAllPotentialColumnsForCard = memoize((cardId: string) => { ); }); -export const factories = { - getRenderableRuns, - getFilteredRenderableRuns, -}; - export const TEST_ONLY = { + getRenderableRuns, getRenderableCardIdsWithMetadata, getScalarTagsForRunSelection, utils, diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index e6e653d54c8..898e50ac275 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -175,10 +175,11 @@ describe('common selectors', () => { activeRoute: { routeKind: RouteKind.EXPERIMENT, params: { + experimentId: 'defaultExperimentId', experimentIds: 'foo:exp1,bar:exp2', }, }, - } as any, + }, hparams: { dashboardSpecs: { hparams: [buildHparamSpec({name: 'foo', displayName: 'Foo'})], @@ -710,33 +711,23 @@ describe('common selectors', () => { }); describe('getRenderableRuns', () => { - it('returns all runs associated with experiment', () => { - const exp1Result = selectors.factories.getRenderableRuns(['exp1'])(state); - expect(exp1Result.length).toEqual(2); - expect(exp1Result[0].run).toEqual({...run1, experimentId: 'exp1'}); - expect(exp1Result[1].run).toEqual({...run2, experimentId: 'exp1'}); - - const exp2Result = selectors.factories.getRenderableRuns(['exp2'])(state); - expect(exp2Result.length).toEqual(3); - expect(exp2Result[0].run).toEqual({...run2, experimentId: 'exp2'}); - expect(exp2Result[1].run).toEqual({...run3, experimentId: 'exp2'}); - expect(exp2Result[2].run).toEqual({...run4, experimentId: 'exp2'}); - }); - - it('returns two runs when a run is associated with multiple experiments', () => { - const result = selectors.factories.getRenderableRuns(['exp1', 'exp2'])( - state - ); - expect(result.length).toEqual(5); - expect(result[0].run).toEqual({...run1, experimentId: 'exp1'}); - expect(result[1].run).toEqual({...run2, experimentId: 'exp1'}); - expect(result[2].run).toEqual({...run2, experimentId: 'exp2'}); - expect(result[3].run).toEqual({...run3, experimentId: 'exp2'}); - expect(result[4].run).toEqual({...run4, experimentId: 'exp2'}); - }); - - it('returns empty list when no experiments are provided', () => { - expect(selectors.factories.getRenderableRuns([])(state)).toEqual([]); + it('returns all runs associated with each experiment', () => { + state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; + const results = selectors.TEST_ONLY.getRenderableRuns(state); + expect(results.length).toEqual(5); + expect(results[0].run).toEqual({...run1, experimentId: 'exp1'}); + expect(results[1].run).toEqual({...run2, experimentId: 'exp1'}); + expect(results[2].run).toEqual({...run2, experimentId: 'exp2'}); + expect(results[3].run).toEqual({...run3, experimentId: 'exp2'}); + expect(results[4].run).toEqual({...run4, experimentId: 'exp2'}); + }); + + it('returns empty list when route does not contain experiments', () => { + state.app_routing!.activeRoute = { + routeKind: RouteKind.EXPERIMENTS, + params: {}, + }; + expect(selectors.TEST_ONLY.getRenderableRuns(state)).toEqual([]); }); }); @@ -918,18 +909,18 @@ describe('common selectors', () => { describe('getFilteredRenderableRuns', () => { it('does not use experiment alias when route is not compare', () => { state.runs!.data.regexFilter = 'foo'; - const result = selectors.factories.getFilteredRenderableRuns(['exp1'])( - state - ); + state.app_routing!.activeRoute = { + routeKind: RouteKind.EXPERIMENT, + params: {experimentIds: 'exp1'}, + }; + const result = selectors.getFilteredRenderableRuns(state); expect(result).toEqual([]); }); it('uses experiment alias when route is compare', () => { state.runs!.data.regexFilter = 'foo'; state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; - const result = selectors.factories.getFilteredRenderableRuns(['exp1'])( - state - ); + const result = selectors.getFilteredRenderableRuns(state); expect(result.length).toEqual(2); expect(result[0].run.name).toEqual('run 1'); expect(result[1].run.name).toEqual('run 2'); @@ -940,47 +931,32 @@ describe('common selectors', () => { selectors.TEST_ONLY.utils, 'filterRunItemsByHparamAndMetricFilter' ).and.callThrough(); - const results = selectors.factories.getFilteredRenderableRuns(['exp1'])( - state - ); + state.app_routing!.activeRoute = { + routeKind: RouteKind.EXPERIMENT, + params: {experimentIds: 'exp1'}, + }; + const results = selectors.getFilteredRenderableRuns(state); expect(spy).toHaveBeenCalledOnceWith(results, new Map(), new Map()); }); it('returns empty list when no experiments are provided', () => { - expect(selectors.factories.getFilteredRenderableRuns([])(state)).toEqual( - [] - ); - }); - }); - - describe('getFilteredRenderableRunsFromRoute', () => { - it('calls getFilteredRenderableRuns with experiment ids from the route when in compare view', () => { - state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; - const result = selectors.getFilteredRenderableRunsFromRoute(state); - expect(result).toEqual( - selectors.factories.getFilteredRenderableRuns(['exp1', 'exp2'])(state) - ); - }); - - it('calls getFilteredRenderableRuns with experiment ids from the route when in single experiment view', () => { - const result = selectors.getFilteredRenderableRunsFromRoute(state); - expect(result).toEqual( - selectors.factories.getFilteredRenderableRuns(['defaultExperimentId'])( - state - ) - ); + state.app_routing!.activeRoute = { + routeKind: RouteKind.EXPERIMENTS, + params: {}, + }; + expect(selectors.getFilteredRenderableRuns(state)).toEqual([]); }); }); describe('getFilteredRenderableRunsIdsFromRoute', () => { it('returns a set of run ids from the route when in compare view', () => { state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; - const result = selectors.getFilteredRenderableRunsIdsFromRoute(state); + const result = selectors.getFilteredRenderableRunsIds(state); expect(result).toEqual(new Set(['1', '2', '3', '4'])); }); it('returns a set of run ids from the route when in single experiment view', () => { - const result = selectors.getFilteredRenderableRunsIdsFromRoute(state); + const result = selectors.getFilteredRenderableRunsIds(state); expect(result).toEqual(new Set()); }); }); diff --git a/tensorboard/webapp/runs/store/BUILD b/tensorboard/webapp/runs/store/BUILD index 740e9fe6523..ef603b75ece 100644 --- a/tensorboard/webapp/runs/store/BUILD +++ b/tensorboard/webapp/runs/store/BUILD @@ -45,6 +45,7 @@ tf_ts_library( deps = [ ":types", ":utils", + "//tensorboard/webapp/app_routing/store", "//tensorboard/webapp/hparams:types", "//tensorboard/webapp/hparams/_redux:hparams_selectors", "//tensorboard/webapp/runs:types", diff --git a/tensorboard/webapp/runs/store/runs_selectors.ts b/tensorboard/webapp/runs/store/runs_selectors.ts index 9804a189c61..adab7e71e5d 100644 --- a/tensorboard/webapp/runs/store/runs_selectors.ts +++ b/tensorboard/webapp/runs/store/runs_selectors.ts @@ -29,6 +29,7 @@ import {createGroupBy} from './utils'; import {ColumnHeader, SortingInfo} from '../../widgets/data_table/types'; import {getDashboardRunsToHparamsAndMetrics} from '../../hparams/_redux/hparams_selectors'; import {RunToHparamsAndMetrics} from '../../hparams/types'; +import {getExperimentIdsFromRoute} from '../../app_routing/store/app_routing_selectors'; const getRunsState = createFeatureSelector(RUNS_FEATURE_KEY); @@ -90,29 +91,33 @@ export const getRuns = createSelector( * @param experimentIds * @returns */ -export const getDashboardRuns = (experimentIds: string[]) => - createSelector( - getDataState, - getDashboardRunsToHparamsAndMetrics, - ( - state: RunsDataState, - runsToHparamsAndMetrics: RunToHparamsAndMetrics - ): Array => { - return experimentIds - .map((experimentId) => { - return (state.runIds[experimentId] || []) - .filter((id) => Boolean(state.runMetadata[id])) - .map((runId) => { - const run = {...state.runMetadata[runId], experimentId}; - run.hparams = runsToHparamsAndMetrics[runId]?.hparams ?? null; - run.metrics = runsToHparamsAndMetrics[runId]?.metrics ?? null; - - return run; - }); - }) - .flat(); +export const getDashboardRuns = createSelector( + getDataState, + getExperimentIdsFromRoute, + getDashboardRunsToHparamsAndMetrics, + ( + state: RunsDataState, + experimentIds: string[] | null, + runsToHparamsAndMetrics: RunToHparamsAndMetrics + ): Array => { + if (!experimentIds) { + return []; } - ); + return experimentIds + .map((experimentId) => { + return (state.runIds[experimentId] || []) + .filter((id) => Boolean(state.runMetadata[id])) + .map((runId) => { + const run = {...state.runMetadata[runId], experimentId}; + run.hparams = runsToHparamsAndMetrics[runId]?.hparams ?? null; + run.metrics = runsToHparamsAndMetrics[runId]?.metrics ?? null; + + return run; + }); + }) + .flat(); + } +); /** * Returns Observable that emits runs list for an experiment. diff --git a/tensorboard/webapp/runs/store/runs_selectors_test.ts b/tensorboard/webapp/runs/store/runs_selectors_test.ts index 7d2d624df44..a50b9fd7380 100644 --- a/tensorboard/webapp/runs/store/runs_selectors_test.ts +++ b/tensorboard/webapp/runs/store/runs_selectors_test.ts @@ -205,6 +205,14 @@ describe('runs_selectors', () => { describe('#getDashboardRuns', () => { it('returns runs', () => { const state = buildMockState({ + ...buildStateFromAppRoutingState( + buildAppRoutingState({ + activeRoute: { + routeKind: RouteKind.EXPERIMENT, + params: {experimentId: 'eid'}, + }, + }) + ), ...buildStateFromRunsState( buildRunsState({ runIds: { @@ -216,7 +224,7 @@ describe('runs_selectors', () => { }) ), }); - expect(selectors.getDashboardRuns(['eid'])(state)).toEqual([ + expect(selectors.getDashboardRuns(state)).toEqual([ { ...buildRun({ id: 'run1', @@ -226,8 +234,16 @@ describe('runs_selectors', () => { ]); }); - it('returns runs for the ones that has metadata', () => { + it('returns runs that have metadata', () => { const state = buildMockState({ + ...buildStateFromAppRoutingState( + buildAppRoutingState({ + activeRoute: { + routeKind: RouteKind.EXPERIMENT, + params: {experimentId: 'eid'}, + }, + }) + ), ...buildStateFromRunsState( buildRunsState({ runIds: { @@ -239,7 +255,7 @@ describe('runs_selectors', () => { }) ), }); - expect(selectors.getDashboardRuns(['eid'])(state)).toEqual([ + expect(selectors.getDashboardRuns(state)).toEqual([ { ...buildRun({ id: 'run1', @@ -250,8 +266,17 @@ describe('runs_selectors', () => { }); it('returns empty list if experiment id does not exist', () => { - const state = buildMockState(); - expect(selectors.getDashboardRuns(['i_do_not_exist'])(state)).toEqual([]); + const state = buildMockState( + buildStateFromAppRoutingState( + buildAppRoutingState({ + activeRoute: { + routeKind: RouteKind.EXPERIMENTS, + params: {}, + }, + }) + ) + ); + expect(selectors.getDashboardRuns(state)).toEqual([]); }); it('includes dashboard hparams data', () => { @@ -260,7 +285,7 @@ describe('runs_selectors', () => { buildAppRoutingState({ activeRoute: { routeKind: RouteKind.EXPERIMENT, - params: {}, + params: {experimentId: 'eid'}, }, }) ), @@ -291,7 +316,7 @@ describe('runs_selectors', () => { }) ), }); - expect(selectors.getDashboardRuns(['eid'])(state)).toEqual([ + expect(selectors.getDashboardRuns(state)).toEqual([ { ...buildRun({ id: 'run1', @@ -308,6 +333,14 @@ describe('runs_selectors', () => { it('never returns hparams or metric data from run metadata', () => { const state = buildMockState({ + ...buildStateFromAppRoutingState( + buildAppRoutingState({ + activeRoute: { + routeKind: RouteKind.EXPERIMENT, + params: {experimentId: 'eid'}, + }, + }) + ), ...buildStateFromRunsState( buildRunsState({ runIds: { @@ -324,7 +357,7 @@ describe('runs_selectors', () => { ), }); - const response = selectors.getDashboardRuns(['eid'])(state)[0]; + const response = selectors.getDashboardRuns(state)[0]; expect(response.hparams).toBeNull(); expect(response.metrics).toBeNull(); }); diff --git a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts index 7cb65bdcd8d..70ce1acda5e 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts +++ b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts @@ -96,7 +96,7 @@ import { } from './runs_table_component'; import {RunsTableColumn, RunTableItem} from './types'; import { - getFilteredRenderableRunsFromRoute, + getFilteredRenderableRuns, getPotentialHparamColumns, } from '../../../metrics/views/main_view/common_selectors'; import {runsTableFullScreenToggled} from '../../../core/actions'; @@ -372,23 +372,21 @@ export class RunsTableContainer implements OnInit, OnDestroy { }) ); - allRunsTableData$ = this.store - .select(getFilteredRenderableRunsFromRoute) - .pipe( - map((filteredRenderableRuns) => { - return filteredRenderableRuns.map((runTableItem) => { - const tableData: TableData = { - ...Object.fromEntries(runTableItem.hparams.entries()), - id: runTableItem.run.id, - run: runTableItem.run.name, - experimentAlias: runTableItem.experimentAlias, - selected: runTableItem.selected, - color: runTableItem.runColor, - }; - return tableData; - }); - }) - ); + allRunsTableData$ = this.store.select(getFilteredRenderableRuns).pipe( + map((filteredRenderableRuns) => { + return filteredRenderableRuns.map((runTableItem) => { + const tableData: TableData = { + ...Object.fromEntries(runTableItem.hparams.entries()), + id: runTableItem.run.id, + run: runTableItem.run.name, + experimentAlias: runTableItem.experimentAlias, + selected: runTableItem.selected, + color: runTableItem.runColor, + }; + return tableData; + }); + }) + ); private readonly ngUnsubscribe = new Subject(); diff --git a/tensorboard/webapp/runs/views/runs_table/runs_table_test.ts b/tensorboard/webapp/runs/views/runs_table/runs_table_test.ts index db3f3307abf..25c4ffe29aa 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_table_test.ts +++ b/tensorboard/webapp/runs/views/runs_table/runs_table_test.ts @@ -110,7 +110,7 @@ import { ColumnHeaderType, SortingOrder, } from '../../../widgets/data_table/types'; -import {getFilteredRenderableRunsFromRoute} from '../../../metrics/views/main_view/common_selectors'; +import {getFilteredRenderableRuns} from '../../../metrics/views/main_view/common_selectors'; @Injectable() class ColorPickerTestHelper { @@ -3213,7 +3213,7 @@ describe('runs_table', () => { it('passes run name, experiment alias, selected value, and color to data table', () => { // To make sure we only return the runs when called with the right props. const selectSpy = spyOn(store, 'select').and.callThrough(); - selectSpy.withArgs(getFilteredRenderableRunsFromRoute).and.returnValue( + selectSpy.withArgs(getFilteredRenderableRuns).and.returnValue( of([ { run: buildRun({id: 'book1', name: "The Philosopher's Stone"}), @@ -3276,7 +3276,7 @@ describe('runs_table', () => { ]) ); - selectSpy.withArgs(getFilteredRenderableRunsFromRoute).and.returnValue( + selectSpy.withArgs(getFilteredRenderableRuns).and.returnValue( of([ { run: run1, @@ -3338,7 +3338,7 @@ describe('runs_table', () => { }, ]); - store.overrideSelector(getFilteredRenderableRunsFromRoute, [ + store.overrideSelector(getFilteredRenderableRuns, [ { run: run1, experimentAlias: {aliasNumber: 1, aliasText: 'bbb'},