From b408561058932b76382ee0ac590d7fd1b17825f8 Mon Sep 17 00:00:00 2001 From: Riley Jones <78179109+rileyajones@users.noreply.github.com> Date: Mon, 21 Aug 2023 10:48:43 -0700 Subject: [PATCH] Hparam: Add selectors to retrieve data written by the new hparams effects (#6544) ## Motivation for features / changes This is a continuation of #6540 Now that data is being retrieved from the hparams plugin and written to the store I am adding a pair of selectors to retrieve it and inject it into the the data retrieved from `getRuns` and `getRunsForExperimentId`. The logic used to translate `SessionGroups` to `Record` is very similar to that being used by the old code in `runs_data_source`. I could imagine this being refactored to retrieve the run/group relationships from another part of the store in the future in order to better integrate with our plans to stop relying on the metrics portion of the hparams payload. ## Note For Reviewer(s) A large majority of the lines changed in this PR are indent changes resulting from changing the method used to mock the state during tests. I have opted not to add additional test to the runs or common selectors (though I did have to fix 16) as the intention is for this to be a drop in replacement with no new features added. The number of test failures which I had to fix left me feeling confident of the existing test coverage. --- tensorboard/webapp/BUILD | 3 + tensorboard/webapp/alert/BUILD | 1 - tensorboard/webapp/alert/types.ts | 5 +- .../views/alert_display_snackbar_container.ts | 4 +- tensorboard/webapp/app_state.ts | 8 +- tensorboard/webapp/hparams/_redux/BUILD | 1 + .../hparams/_redux/hparams_selectors.ts | 56 +- .../hparams/_redux/hparams_selectors_test.ts | 235 +++++++++ tensorboard/webapp/hparams/_redux/testing.ts | 53 ++ tensorboard/webapp/hparams/_types.ts | 2 + .../views/main_view/common_selectors.ts | 14 +- .../views/main_view/common_selectors_test.ts | 39 +- tensorboard/webapp/runs/store/BUILD | 5 + .../webapp/runs/store/runs_selectors.ts | 38 +- .../webapp/runs/store/runs_selectors_test.ts | 488 +++++++++++------- tensorboard/webapp/testing/BUILD | 5 + tensorboard/webapp/testing/utils.ts | 52 +- tensorboard/webapp/util/BUILD | 1 + tensorboard/webapp/util/ui_selectors_test.ts | 31 +- 19 files changed, 783 insertions(+), 258 deletions(-) diff --git a/tensorboard/webapp/BUILD b/tensorboard/webapp/BUILD index d167c713f4c..6c0d9d46be6 100644 --- a/tensorboard/webapp/BUILD +++ b/tensorboard/webapp/BUILD @@ -142,6 +142,8 @@ tf_ng_module( "app_state.ts", ], deps = [ + "//tensorboard/plugins/debugger_v2/tf_debugger_v2_plugin/store:types", + "//tensorboard/webapp/alert/store:types", "//tensorboard/webapp/app_routing/store:types", "//tensorboard/webapp/core/store", "//tensorboard/webapp/experiments/store:types", @@ -149,6 +151,7 @@ tf_ng_module( "//tensorboard/webapp/hparams:types", "//tensorboard/webapp/metrics/store:types", "//tensorboard/webapp/notification_center/_redux:types", + "//tensorboard/webapp/persistent_settings/_redux:types", "//tensorboard/webapp/runs/store:types", "//tensorboard/webapp/settings", ], diff --git a/tensorboard/webapp/alert/BUILD b/tensorboard/webapp/alert/BUILD index 0a98e6532dd..d80c37551eb 100644 --- a/tensorboard/webapp/alert/BUILD +++ b/tensorboard/webapp/alert/BUILD @@ -37,7 +37,6 @@ tf_ts_library( "types.ts", ], deps = [ - "//tensorboard/webapp:app_state", "@npm//@ngrx/store", "@npm//rxjs", ], diff --git a/tensorboard/webapp/alert/types.ts b/tensorboard/webapp/alert/types.ts index fd797f4f4ba..61f855cce7d 100644 --- a/tensorboard/webapp/alert/types.ts +++ b/tensorboard/webapp/alert/types.ts @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {Action, Store} from '@ngrx/store'; -import {State} from '../app_state'; +import {Action} from '@ngrx/store'; /** * An alert structure used when creating newly reported alerts. @@ -35,7 +34,7 @@ export interface AlertReport { * when the followup action is requested, a newly created Promise will be * awaited, and the resulting action is dispatched. */ - getFollowupAction: (store: Store) => Promise; + getFollowupAction: () => Promise; }; } diff --git a/tensorboard/webapp/alert/views/alert_display_snackbar_container.ts b/tensorboard/webapp/alert/views/alert_display_snackbar_container.ts index 9f34469c082..139e064b788 100644 --- a/tensorboard/webapp/alert/views/alert_display_snackbar_container.ts +++ b/tensorboard/webapp/alert/views/alert_display_snackbar_container.ts @@ -43,9 +43,7 @@ export class AlertDisplaySnackbarContainer { async onActionButtonClicked() { this.snackBarRef.dismiss(); - const followupAction = await this.alert.followupAction!.getFollowupAction( - this.store - ); + const followupAction = await this.alert.followupAction!.getFollowupAction(); this.store.dispatch(followupAction); } diff --git a/tensorboard/webapp/app_state.ts b/tensorboard/webapp/app_state.ts index 47ce6a756a0..47b0da38355 100644 --- a/tensorboard/webapp/app_state.ts +++ b/tensorboard/webapp/app_state.ts @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import {State as DebuggerState} from '../plugins/debugger_v2/tf_debugger_v2_plugin/store/debugger_types'; +import {State as AlertState} from './alert/store/alert_types'; import {State as AppRoutingState} from './app_routing/store/app_routing_types'; import {State as CoreState} from './core/store/core_types'; import {State as ExperimentsState} from './experiments/store/experiments_types'; @@ -20,6 +22,7 @@ import {State as FeatureFlagState} from './feature_flag/store/feature_flag_types import {State as HparamsState} from './hparams/types'; import {State as MetricsState} from './metrics/store/metrics_types'; import {State as NotificationState} from './notification_center/_redux/notification_center_types'; +import {State as PersistentSettingsState} from './persistent_settings/_redux/persistent_settings_types'; import {State as RunsState} from './runs/store/runs_types'; import {State as SettingsState} from './settings'; @@ -31,4 +34,7 @@ export type State = AppRoutingState & MetricsState & RunsState & SettingsState & - NotificationState; + NotificationState & + DebuggerState & + AlertState & + PersistentSettingsState; diff --git a/tensorboard/webapp/hparams/_redux/BUILD b/tensorboard/webapp/hparams/_redux/BUILD index d2b0fc296da..6439918303f 100644 --- a/tensorboard/webapp/hparams/_redux/BUILD +++ b/tensorboard/webapp/hparams/_redux/BUILD @@ -168,6 +168,7 @@ tf_ts_library( "//tensorboard/webapp/runs/data_source:testing", "//tensorboard/webapp/runs/store:testing", "//tensorboard/webapp/testing:utils", + "//tensorboard/webapp/util:types", "//tensorboard/webapp/webapp_data_source:http_client_testing", "@npm//@ngrx/effects", "@npm//@ngrx/store", diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts index 66831c5149d..d7dc87650bd 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts @@ -13,7 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ import {createFeatureSelector, createSelector} from '@ngrx/store'; -import {DiscreteFilter, HparamAndMetricSpec, IntervalFilter} from '../types'; +import { + DiscreteFilter, + HparamAndMetricSpec, + HparamValue, + IntervalFilter, + RunToHparamsAndMetrics, +} from '../types'; import {combineHparamAndMetricSpecs} from './hparams_selectors_utils'; import {HparamsState, HPARAMS_FEATURE_KEY} from './types'; import { @@ -178,3 +184,51 @@ export const getExperimentsHparamsAndMetricsSpecs = createSelector( ); } ); + +export const getDashboardHparamsAndMetricsSpecs = createSelector( + getHparamsState, + (state: HparamsState) => { + return state.dashboardSpecs; + } +); + +export const getDashboardRunsToHparamsAndMetrics = createSelector( + getHparamsState, + (state): RunToHparamsAndMetrics => { + const runToHparamsAndMetrics: RunToHparamsAndMetrics = {}; + + for (const sessionGroup of state.dashboardSessionGroups) { + const hparams: HparamValue[] = Object.entries(sessionGroup.hparams).map( + (keyValue) => { + const [hparam, value] = keyValue; + return {name: hparam, value}; + } + ); + + for (const session of sessionGroup.sessions) { + runToHparamsAndMetrics[session.name] = { + metrics: [], + hparams, + }; + + for (const metricValue of session.metricValues) { + const runId = metricValue.name.group + ? `${session.name}/${metricValue.name.group}` + : session.name; + + const hparamsAndMetrics = runToHparamsAndMetrics[runId] || { + metrics: [], + hparams, + }; + hparamsAndMetrics.metrics.push({ + tag: metricValue.name.tag, + trainingStep: metricValue.trainingStep, + value: metricValue.value, + }); + runToHparamsAndMetrics[runId] = hparamsAndMetrics; + } + } + } + return runToHparamsAndMetrics; + } +); diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts index 2f93838a744..e4422d5d4c6 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts @@ -21,6 +21,8 @@ import { buildHparamsState, buildIntervalFilter, buildMetricSpec, + buildMetricsValue, + buildSessionGroup, buildSpecs, buildStateFromHparamsState, } from './testing'; @@ -466,6 +468,239 @@ describe('hparams/_redux/hparams_selectors_test', () => { }); }); + describe('#getDashboardHparamsAndMetricsSpecs', () => { + it('returns dashboard specs', () => { + const state = buildStateFromHparamsState( + buildHparamsState({ + dashboardSpecs: { + hparams: [buildHparamSpec({name: 'foo'})], + metrics: [buildMetricSpec({tag: 'bar'})], + }, + }) + ); + + expect(selectors.getDashboardHparamsAndMetricsSpecs(state)).toEqual({ + hparams: [buildHparamSpec({name: 'foo'})], + metrics: [buildMetricSpec({tag: 'bar'})], + }); + }); + }); + + describe('#getDashboardRunsToHparamsAndMetrics', () => { + it('contains entry for each runId/group', () => { + const mockSessionGroups = [ + buildSessionGroup({ + name: 'session_group_1', + hparams: { + hp1: 1, + hp2: true, + hp3: 'foo', + }, + sessions: [ + { + name: 'exp1/run1', + metricValues: [ + buildMetricsValue({name: {tag: 'foo', group: '1'}, value: 2}), + buildMetricsValue({ + name: {tag: 'bar', group: '2'}, + value: 103, + trainingStep: 4, + }), + buildMetricsValue({ + name: {tag: 'bar', group: '2'}, + value: 107, + trainingStep: 5, + }), + buildMetricsValue({name: {tag: 'abc123', group: ''}, value: 2}), + ], + }, + { + name: 'exp1/run2', + metricValues: [ + buildMetricsValue({name: {tag: 'foo', group: '1'}, value: 3}), + buildMetricsValue({name: {tag: 'bar', group: '2'}, value: 104}), + buildMetricsValue({name: {tag: 'baz', group: '3'}, value: 201}), + ], + }, + ], + }), + buildSessionGroup({ + name: 'session_group_2', + hparams: { + hp1: 2, + hp2: false, + hp3: 'bar', + }, + sessions: [ + { + name: 'exp1/run3', + metricValues: [ + buildMetricsValue({name: {tag: 'foo', group: '1'}, value: 4}), + buildMetricsValue({name: {tag: 'bar', group: '2'}, value: 105}), + ], + }, + ], + }), + buildSessionGroup({ + name: 'session_group_3', + hparams: { + hp4: 'hyperparameter4', + }, + sessions: [ + { + name: 'exp1/run4', + metricValues: [], + }, + ], + }), + buildSessionGroup({ + name: 'session_group_4', + hparams: { + hp1: 7, + hp2: false, + hp3: 'foobar', + }, + sessions: [ + { + name: 'exp2/run1', + metricValues: [ + buildMetricsValue({name: {tag: 'foo', group: '1'}, value: 4}), + buildMetricsValue({name: {tag: 'bar', group: '2'}, value: 105}), + buildMetricsValue({ + name: {tag: 'baz', group: '2'}, + value: 1000, + }), + ], + }, + ], + }), + ]; + + const state = buildStateFromHparamsState( + buildHparamsState({ + dashboardSessionGroups: mockSessionGroups, + }) + ); + + expect(selectors.getDashboardRunsToHparamsAndMetrics(state)).toEqual({ + 'exp1/run1': { + metrics: [{tag: 'abc123', trainingStep: 0, value: 2}], + hparams: [ + {name: 'hp1', value: 1}, + {name: 'hp2', value: true}, + {name: 'hp3', value: 'foo'}, + ], + }, + 'exp1/run1/1': { + metrics: [{tag: 'foo', trainingStep: 0, value: 2}], + hparams: [ + {name: 'hp1', value: 1}, + {name: 'hp2', value: true}, + {name: 'hp3', value: 'foo'}, + ], + }, + 'exp1/run1/2': { + metrics: [ + {tag: 'bar', trainingStep: 4, value: 103}, + {tag: 'bar', trainingStep: 5, value: 107}, + ], + hparams: [ + {name: 'hp1', value: 1}, + {name: 'hp2', value: true}, + {name: 'hp3', value: 'foo'}, + ], + }, + 'exp1/run2': { + metrics: [], + hparams: [ + {name: 'hp1', value: 1}, + {name: 'hp2', value: true}, + {name: 'hp3', value: 'foo'}, + ], + }, + 'exp1/run2/1': { + metrics: [{tag: 'foo', trainingStep: 0, value: 3}], + hparams: [ + {name: 'hp1', value: 1}, + {name: 'hp2', value: true}, + {name: 'hp3', value: 'foo'}, + ], + }, + 'exp1/run2/2': { + metrics: [{tag: 'bar', trainingStep: 0, value: 104}], + hparams: [ + {name: 'hp1', value: 1}, + {name: 'hp2', value: true}, + {name: 'hp3', value: 'foo'}, + ], + }, + 'exp1/run2/3': { + metrics: [{tag: 'baz', trainingStep: 0, value: 201}], + hparams: [ + {name: 'hp1', value: 1}, + {name: 'hp2', value: true}, + {name: 'hp3', value: 'foo'}, + ], + }, + 'exp1/run3': { + metrics: [], + hparams: [ + {name: 'hp1', value: 2}, + {name: 'hp2', value: false}, + {name: 'hp3', value: 'bar'}, + ], + }, + 'exp1/run3/1': { + metrics: [{tag: 'foo', trainingStep: 0, value: 4}], + hparams: [ + {name: 'hp1', value: 2}, + {name: 'hp2', value: false}, + {name: 'hp3', value: 'bar'}, + ], + }, + 'exp1/run3/2': { + metrics: [{tag: 'bar', trainingStep: 0, value: 105}], + hparams: [ + {name: 'hp1', value: 2}, + {name: 'hp2', value: false}, + {name: 'hp3', value: 'bar'}, + ], + }, + 'exp1/run4': { + metrics: [], + hparams: [{name: 'hp4', value: 'hyperparameter4'}], + }, + 'exp2/run1': { + metrics: [], + hparams: [ + {name: 'hp1', value: 7}, + {name: 'hp2', value: false}, + {name: 'hp3', value: 'foobar'}, + ], + }, + 'exp2/run1/1': { + metrics: [{tag: 'foo', trainingStep: 0, value: 4}], + hparams: [ + {name: 'hp1', value: 7}, + {name: 'hp2', value: false}, + {name: 'hp3', value: 'foobar'}, + ], + }, + 'exp2/run1/2': { + metrics: [ + {tag: 'bar', trainingStep: 0, value: 105}, + {tag: 'baz', trainingStep: 0, value: 1000}, + ], + hparams: [ + {name: 'hp1', value: 7}, + {name: 'hp2', value: false}, + {name: 'hp3', value: 'foobar'}, + ], + }, + }); + }); + }); + it('does not use default filters when includeDefaults is false', () => { const state = buildStateFromHparamsState( buildHparamsState({ diff --git a/tensorboard/webapp/hparams/_redux/testing.ts b/tensorboard/webapp/hparams/_redux/testing.ts index 88f0e5a23fb..76894db4004 100644 --- a/tensorboard/webapp/hparams/_redux/testing.ts +++ b/tensorboard/webapp/hparams/_redux/testing.ts @@ -19,9 +19,13 @@ import { DiscreteFilter, DomainType, HparamSpec, + HparamValue, HparamsValueType, IntervalFilter, MetricSpec, + MetricsValue, + RunStatus, + Session, SessionGroup, } from '../_types'; import { @@ -143,3 +147,52 @@ export function buildIntervalFilter( ...override, }; } + +export function buildMetricsValue( + override: DeepPartial = {} +): MetricsValue { + return { + trainingStep: 0, + value: 1, + wallTimeSecs: 123, + ...override, + name: { + tag: override.name?.tag ?? 'someTag', + group: override.name?.group ?? 'someGroup', + }, + }; +} + +export function buildHparamValue(override: Partial): HparamValue { + return { + name: 'some_hparam', + value: 4, + ...override, + }; +} + +export function buildSession(override: DeepPartial = {}): Session { + return { + name: 'someExperiment/someRun', + modelUri: '', + monitorUrl: '', + startTimeSecs: 123, + endTimeSecs: 456, + status: RunStatus.STATUS_UNKNOWN, + ...override, + metricValues: [...(override.metricValues ?? [])].map(buildMetricsValue), + }; +} + +export function buildSessionGroup( + override: DeepPartial +): SessionGroup { + return { + name: 'some_session_group', + ...override, + hparams: { + ...override.hparams, + } as any, + sessions: (override.sessions ?? []).map(buildSession), + }; +} diff --git a/tensorboard/webapp/hparams/_types.ts b/tensorboard/webapp/hparams/_types.ts index 01df0df1383..70ceafffe25 100644 --- a/tensorboard/webapp/hparams/_types.ts +++ b/tensorboard/webapp/hparams/_types.ts @@ -42,6 +42,8 @@ export { BackendHparamsValueType, RunStatus, SessionGroup, + Session, + MetricsValue, } from '../runs/data_source/runs_backend_types'; export interface HparamAndMetricSpec { diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index 526c99fbacf..2797422d79a 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -24,14 +24,14 @@ import { getRunColorMap, getRunSelectorRegexFilter, getRouteKind, - getRunsFromExperimentIds, + getDashboardRuns, getColumnHeadersForCard, } from '../../../selectors'; import {DeepReadonly} from '../../../util/types'; import { getHparamFilterMapFromExperimentIds, getMetricFilterMapFromExperimentIds, - getExperimentsHparamsAndMetricsSpecs, + getDashboardHparamsAndMetricsSpecs, } from '../../../hparams/_redux/hparams_selectors'; import { DiscreteFilter, @@ -187,7 +187,7 @@ const utils = { const getRenderableRuns = memoize((experimentIds: string[]) => { return createSelector( - getRunsFromExperimentIds(experimentIds), + getDashboardRuns(experimentIds), getExperimentNames(experimentIds), getCurrentRouteRunSelection, getRunColorMap, @@ -261,17 +261,13 @@ export const getFilteredRenderableRunsIdsFromRoute = createSelector( ); export const getPotentialHparamColumns = createSelector( - (state: State) => state, + getDashboardHparamsAndMetricsSpecs, getExperimentIdsFromRoute, - (state, experimentIds): ColumnHeader[] => { + ({hparams}, experimentIds): ColumnHeader[] => { if (!experimentIds) { return []; } - const {hparams} = getExperimentsHparamsAndMetricsSpecs(state, { - experimentIds, - }); - return hparams.map((spec) => ({ type: ColumnHeaderType.HPARAM, name: spec.name, diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index b87b7576ede..e6e653d54c8 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ import {RouteKind} from '../../../app_routing'; import { - buildSpecs, buildHparamSpec, buildMetricSpec, } from '../../../hparams/_redux/testing'; @@ -181,23 +180,18 @@ describe('common selectors', () => { }, } as any, hparams: { - specs: buildSpecs('defaultExperimentId', { - hparam: { - specs: [buildHparamSpec({name: 'foo', displayName: 'Foo'})], - defaultFilters: new Map(), - }, - metric: { - specs: [buildMetricSpec({displayName: 'Bar'})], - defaultFilters: new Map(), - }, - }), + dashboardSpecs: { + hparams: [buildHparamSpec({name: 'foo', displayName: 'Foo'})], + metrics: [buildMetricSpec({displayName: 'Bar'})], + }, + dashboardSessionGroups: [], } as any, }); }); describe('getScalarTagsForRunSelection', () => { it('returns all tags containing scalar data when no runs are selected', () => { - const state = { + const state = buildMockState({ ...appStateFromMetricsState( buildMetricsState({ tagMetadata: { @@ -239,14 +233,14 @@ describe('common selectors', () => { } ) ), - }; + }); expect(selectors.TEST_ONLY.getScalarTagsForRunSelection(state)).toEqual( new Set(['tag-1', 'tag-2']) ); }); it('returns only tags containing selected runs when some runs are selected', () => { - const state = { + const state = buildMockState({ ...appStateFromMetricsState( buildMetricsState({ tagMetadata: { @@ -276,6 +270,7 @@ describe('common selectors', () => { }), }) ), + ...buildStateFromRunsState( buildRunsState( { @@ -292,7 +287,7 @@ describe('common selectors', () => { } ) ), - }; + }); expect(selectors.TEST_ONLY.getScalarTagsForRunSelection(state)).toEqual( new Set(['tag-2']) ); @@ -301,7 +296,7 @@ describe('common selectors', () => { describe('getRenderableCardIdsWithMetadata', () => { it('returns all tags containing scalar data when no runs are selected', () => { - const state = { + const state = buildMockState({ ...appStateFromMetricsState( buildMetricsState({ cardList: ['card1', 'card2'], @@ -359,7 +354,7 @@ describe('common selectors', () => { } ) ), - }; + }); expect( selectors.TEST_ONLY.getRenderableCardIdsWithMetadata(state) ).toEqual([ @@ -381,7 +376,7 @@ describe('common selectors', () => { describe('getSortedRenderableCardIdsWithMetadata', () => { it('shows empty scalar cards when hideEmptyCards is false', () => { - const state = { + const state = buildMockState({ ...appStateFromMetricsState( buildMetricsState({ cardList: ['card1', 'card2', 'card3'], @@ -449,7 +444,7 @@ describe('common selectors', () => { } ) ), - }; + }); expect(selectors.getSortedRenderableCardIdsWithMetadata(state)).toEqual([ { cardId: 'card1', @@ -473,7 +468,7 @@ describe('common selectors', () => { }); it('hides empty scalar cards when hideEmptyCards is true', () => { - const state = { + const state = buildMockState({ ...appStateFromMetricsState( buildMetricsState({ cardList: ['card1', 'card2', 'card3'], @@ -541,7 +536,7 @@ describe('common selectors', () => { } ) ), - }; + }); expect(selectors.getSortedRenderableCardIdsWithMetadata(state)).toEqual([ { cardId: 'card1', @@ -1012,7 +1007,7 @@ describe('common selectors', () => { }); it('sets name as display name when a display name is not provided', () => { - state.hparams!.specs['defaultExperimentId'].hparam.specs.push( + state.hparams!.dashboardSpecs.hparams.push( buildHparamSpec({name: 'bar', displayName: ''}) ); expect(selectors.getPotentialHparamColumns(state)).toEqual([ diff --git a/tensorboard/webapp/runs/store/BUILD b/tensorboard/webapp/runs/store/BUILD index 3f4c3fe25dd..740e9fe6523 100644 --- a/tensorboard/webapp/runs/store/BUILD +++ b/tensorboard/webapp/runs/store/BUILD @@ -45,6 +45,8 @@ tf_ts_library( deps = [ ":types", ":utils", + "//tensorboard/webapp/hparams:types", + "//tensorboard/webapp/hparams/_redux:hparams_selectors", "//tensorboard/webapp/runs:types", "//tensorboard/webapp/types", "//tensorboard/webapp/types:ui", @@ -98,11 +100,14 @@ tf_ts_library( "//tensorboard/webapp/app_routing:testing", "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/app_routing/actions", + "//tensorboard/webapp/app_routing/store:testing", + "//tensorboard/webapp/hparams:testing", "//tensorboard/webapp/runs:types", "//tensorboard/webapp/runs/actions", "//tensorboard/webapp/runs/data_source", "//tensorboard/webapp/runs/data_source:testing", "//tensorboard/webapp/testing:lang", + "//tensorboard/webapp/testing:utils", "//tensorboard/webapp/types", "//tensorboard/webapp/types:ui", "//tensorboard/webapp/widgets/data_table:types", diff --git a/tensorboard/webapp/runs/store/runs_selectors.ts b/tensorboard/webapp/runs/store/runs_selectors.ts index 5dc28ffd018..9804a189c61 100644 --- a/tensorboard/webapp/runs/store/runs_selectors.ts +++ b/tensorboard/webapp/runs/store/runs_selectors.ts @@ -27,6 +27,8 @@ import { } from './runs_types'; import {createGroupBy} from './utils'; import {ColumnHeader, SortingInfo} from '../../widgets/data_table/types'; +import {getDashboardRunsToHparamsAndMetrics} from '../../hparams/_redux/hparams_selectors'; +import {RunToHparamsAndMetrics} from '../../hparams/types'; const getRunsState = createFeatureSelector(RUNS_FEATURE_KEY); @@ -69,6 +71,8 @@ export const getRun = createSelector( /** * Returns Observable that emits runs list for an experiment. + * This is intended to be used in the experiment_list page. + * TODO(rileyajones) remove usage of this selector from the timeseries dashboard. */ export const getRuns = createSelector( getDataState, @@ -80,19 +84,33 @@ export const getRuns = createSelector( } ); -export const getRunsFromExperimentIds = (experimentIds: string[]) => +/** + * Get the runs used on the dashboard. + * TODO(rileyajones) get the experiment ids from the state rather than as an argument. + * @param experimentIds + * @returns + */ +export const getDashboardRuns = (experimentIds: string[]) => createSelector( getDataState, - (state: RunsDataState): Array => { - return experimentIds.reduce((runs, experimentId) => { - (state.runIds[experimentId] || []) - .filter((id) => Boolean(state.runMetadata[id])) - .forEach((runId) => { - runs.push({...state.runMetadata[runId], experimentId}); - }); + getDashboardRunsToHparamsAndMetrics, + ( + state: RunsDataState, + runsToHparamsAndMetrics: RunToHparamsAndMetrics + ): Array => { + return experimentIds + .map((experimentId) => { + return (state.runIds[experimentId] || []) + .filter((id) => Boolean(state.runMetadata[id])) + .map((runId) => { + const run = {...state.runMetadata[runId], experimentId}; + run.hparams = runsToHparamsAndMetrics[runId]?.hparams ?? null; + run.metrics = runsToHparamsAndMetrics[runId]?.metrics ?? null; - return runs; - }, [] as Array); + return run; + }); + }) + .flat(); } ); diff --git a/tensorboard/webapp/runs/store/runs_selectors_test.ts b/tensorboard/webapp/runs/store/runs_selectors_test.ts index 65626fa6144..7d2d624df44 100644 --- a/tensorboard/webapp/runs/store/runs_selectors_test.ts +++ b/tensorboard/webapp/runs/store/runs_selectors_test.ts @@ -12,6 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import { + buildAppRoutingState, + buildStateFromAppRoutingState, +} from '../../app_routing/store/testing'; +import {RouteKind} from '../../app_routing/types'; +import { + buildSessionGroup, + buildStateFromHparamsState, + buildHparamsState, +} from '../../hparams/testing'; +import {buildMockState} from '../../testing/utils'; import {DataLoadState} from '../../types/data'; import {SortDirection} from '../../types/ui'; import {ColumnHeaderType, SortingOrder} from '../../widgets/data_table/types'; @@ -27,15 +38,17 @@ describe('runs_selectors', () => { }); it('returns runIdToExpId', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runIdToExpId: { - run1: 'eid1', - run2: 'eid1', - run3: 'eid2', - }, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runIdToExpId: { + run1: 'eid1', + run2: 'eid1', + run3: 'eid2', + }, + }) + ), + }); expect(selectors.getRunIdToExperimentId(state)).toEqual({ run1: 'eid1', run2: 'eid1', @@ -50,15 +63,17 @@ describe('runs_selectors', () => { }); it('returns eid', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runIdToExpId: { - run1: 'eid1', - run2: 'eid1', - run3: 'eid2', - }, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runIdToExpId: { + run1: 'eid1', + run2: 'eid1', + run3: 'eid2', + }, + }) + ), + }); expect( selectors.getExperimentIdForRunId(state, { runId: 'run1', @@ -77,11 +92,13 @@ describe('runs_selectors', () => { }); it('returns `null` if the runId is unknown', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runIdToExpId: {run1: 'eid1'}, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runIdToExpId: {run1: 'eid1'}, + }) + ), + }); expect( selectors.getExperimentIdForRunId(state, { runId: 'run4', @@ -97,13 +114,15 @@ describe('runs_selectors', () => { }); it('returns run', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runMetadata: { - run1: buildRun({id: 'run1'}), - }, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runMetadata: { + run1: buildRun({id: 'run1'}), + }, + }) + ), + }); expect(selectors.getRun(state, {runId: 'run1'})).toEqual( buildRun({ @@ -113,13 +132,15 @@ describe('runs_selectors', () => { }); it('returns `null` if run with `runId` does not exist', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runMetadata: { - run1: buildRun({id: 'run1'}), - }, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runMetadata: { + run1: buildRun({id: 'run1'}), + }, + }) + ), + }); expect(selectors.getRun(state, {runId: 'run10'})).toBe(null); }); @@ -132,16 +153,18 @@ describe('runs_selectors', () => { }); it('returns runs', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runIds: { - eid: ['run1'], - }, - runMetadata: { - run1: buildRun({id: 'run1'}), - }, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runIds: { + eid: ['run1'], + }, + runMetadata: { + run1: buildRun({id: 'run1'}), + }, + }) + ), + }); expect(selectors.getRuns(state, {experimentId: 'eid'})).toEqual([ buildRun({ id: 'run1', @@ -150,16 +173,18 @@ describe('runs_selectors', () => { }); it('returns runs for the ones that has metadata', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runIds: { - eid: ['run1', 'run2'], - }, - runMetadata: { - run1: buildRun({id: 'run1'}), - }, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runIds: { + eid: ['run1', 'run2'], + }, + runMetadata: { + run1: buildRun({id: 'run1'}), + }, + }) + ), + }); expect(selectors.getRuns(state, {experimentId: 'eid'})).toEqual([ buildRun({ id: 'run1', @@ -168,7 +193,7 @@ describe('runs_selectors', () => { }); it('returns empty list if experiment id does not exist', () => { - const state = buildStateFromRunsState(buildRunsState()); + const state = buildMockState(); expect( selectors.getRuns(state, { experimentId: 'i_do_not_exist', @@ -177,19 +202,21 @@ describe('runs_selectors', () => { }); }); - describe('#getRunsFromExperimentIds', () => { + describe('#getDashboardRuns', () => { it('returns runs', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runIds: { - eid: ['run1'], - }, - runMetadata: { - run1: buildRun({id: 'run1'}), - }, - }) - ); - expect(selectors.getRunsFromExperimentIds(['eid'])(state)).toEqual([ + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runIds: { + eid: ['run1'], + }, + runMetadata: { + run1: buildRun({id: 'run1'}), + }, + }) + ), + }); + expect(selectors.getDashboardRuns(['eid'])(state)).toEqual([ { ...buildRun({ id: 'run1', @@ -200,17 +227,19 @@ describe('runs_selectors', () => { }); it('returns runs for the ones that has metadata', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runIds: { - eid: ['run1', 'run2'], - }, - runMetadata: { - run1: buildRun({id: 'run1'}), - }, - }) - ); - expect(selectors.getRunsFromExperimentIds(['eid'])(state)).toEqual([ + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runIds: { + eid: ['run1', 'run2'], + }, + runMetadata: { + run1: buildRun({id: 'run1'}), + }, + }) + ), + }); + expect(selectors.getDashboardRuns(['eid'])(state)).toEqual([ { ...buildRun({ id: 'run1', @@ -221,10 +250,83 @@ describe('runs_selectors', () => { }); it('returns empty list if experiment id does not exist', () => { - const state = buildStateFromRunsState(buildRunsState()); - expect( - selectors.getRunsFromExperimentIds(['i_do_not_exist'])(state) - ).toEqual([]); + const state = buildMockState(); + expect(selectors.getDashboardRuns(['i_do_not_exist'])(state)).toEqual([]); + }); + + it('includes dashboard hparams data', () => { + const state = buildMockState({ + ...buildStateFromAppRoutingState( + buildAppRoutingState({ + activeRoute: { + routeKind: RouteKind.EXPERIMENT, + params: {}, + }, + }) + ), + ...buildStateFromRunsState( + buildRunsState({ + runIds: { + eid: ['run1', 'run2'], + }, + runMetadata: { + run1: buildRun({id: 'run1'}), + }, + }) + ), + ...buildStateFromHparamsState( + buildHparamsState({ + dashboardSessionGroups: [ + buildSessionGroup({ + name: 'some_session_group', + hparams: {hp1: 'foo', hp2: 'bar'}, + sessions: [ + { + name: 'run1', + metricValues: [], + } as any, + ], + }), + ], + }) + ), + }); + expect(selectors.getDashboardRuns(['eid'])(state)).toEqual([ + { + ...buildRun({ + id: 'run1', + hparams: [ + {name: 'hp1', value: 'foo'}, + {name: 'hp2', value: 'bar'}, + ], + metrics: [], + }), + experimentId: 'eid', + }, + ]); + }); + + it('never returns hparams or metric data from run metadata', () => { + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runIds: { + eid: ['run1', 'run2'], + }, + runMetadata: { + run1: buildRun({ + id: 'run1', + hparams: [{name: 'foo', value: '1'}], + metrics: [{tag: 'm1', value: 4}], + }), + }, + }) + ), + }); + + const response = selectors.getDashboardRuns(['eid'])(state)[0]; + expect(response.hparams).toBeNull(); + expect(response.metrics).toBeNull(); }); }); @@ -235,20 +337,22 @@ describe('runs_selectors', () => { }); it('returns runIds', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runIds: { - eid: ['run1', 'run2'], - }, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runIds: { + eid: ['run1', 'run2'], + }, + }) + ), + }); expect( selectors.getRunIdsForExperiment(state, {experimentId: 'eid'}) ).toEqual(['run1', 'run2']); }); it('returns empty list if experiment id does not exist', () => { - const state = buildStateFromRunsState(buildRunsState()); + const state = buildMockState(); expect( selectors.getRunIdsForExperiment(state, { experimentId: 'i_do_not_exist', @@ -264,14 +368,16 @@ describe('runs_selectors', () => { }); it('returns a map from RunId to Run', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runMetadata: { - run1: buildRun({id: 'run1'}), - run2: buildRun({id: 'run2'}), - }, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runMetadata: { + run1: buildRun({id: 'run1'}), + run2: buildRun({id: 'run2'}), + }, + }) + ), + }); expect(selectors.getRunMap(state)).toEqual( new Map([ @@ -282,11 +388,13 @@ describe('runs_selectors', () => { }); it('returns an empty map if there are no runs', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runMetadata: {}, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runMetadata: {}, + }) + ), + }); expect(selectors.getRunMap(state)).toEqual(new Map()); }); @@ -304,9 +412,11 @@ describe('runs_selectors', () => { lastLoadedTimeInMs: 1337, }; - const state = buildStateFromRunsState( - buildRunsState({runsLoadState: {id1: loadState}}) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({runsLoadState: {id1: loadState}}) + ), + }); expect( selectors.getRunsLoadState(state, { experimentId: 'id1', @@ -315,13 +425,15 @@ describe('runs_selectors', () => { }); it('returns NOT_LOADED state if experiment id does not exist', () => { - const state = buildStateFromRunsState( - buildRunsState({ - runsLoadState: { - id1: {state: DataLoadState.FAILED, lastLoadedTimeInMs: 1337}, - }, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + runsLoadState: { + id1: {state: DataLoadState.FAILED, lastLoadedTimeInMs: 1337}, + }, + }) + ), + }); expect(selectors.getRunsLoadState(state, {experimentId: 'id2'})).toEqual({ lastLoadedTimeInMs: null, state: DataLoadState.NOT_LOADED, @@ -336,8 +448,8 @@ describe('runs_selectors', () => { }); it('returns selection map of runId passed', () => { - const state = buildStateFromRunsState( - buildRunsState( + const state = buildMockState({ + runs: buildRunsState( {}, { selectionState: new Map([ @@ -345,8 +457,8 @@ describe('runs_selectors', () => { ['r2', true], ]), } - ) - ); + ), + }); const actual = selectors.getRunSelectionMap(state); expect(actual).toEqual( @@ -451,14 +563,16 @@ describe('runs_selectors', () => { }); it('returns override map', () => { - const state = buildStateFromRunsState( - buildRunsState({ - defaultRunColorIdForGroupBy: new Map([ - ['foo', 1], - ['bar', 2], - ]), - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + defaultRunColorIdForGroupBy: new Map([ + ['foo', 1], + ['bar', 2], + ]), + }) + ), + }); expect(selectors.getDefaultRunColorIdMap(state)).toEqual( new Map([ @@ -476,13 +590,15 @@ describe('runs_selectors', () => { }); it('returns groupBy set by user when it is present', () => { - const state = buildStateFromRunsState( - buildRunsState({ - colorGroupRegexString: 'hello', - initialGroupBy: {key: GroupByKey.RUN}, - userSetGroupByKey: GroupByKey.REGEX, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + colorGroupRegexString: 'hello', + initialGroupBy: {key: GroupByKey.RUN}, + userSetGroupByKey: GroupByKey.REGEX, + }) + ), + }); expect(selectors.getRunUserSetGroupBy(state)).toEqual({ key: GroupByKey.REGEX, @@ -491,12 +607,14 @@ describe('runs_selectors', () => { }); it('returns null if user never has set one', () => { - const state = buildStateFromRunsState( - buildRunsState({ - initialGroupBy: {key: GroupByKey.RUN}, - userSetGroupByKey: null, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + initialGroupBy: {key: GroupByKey.RUN}, + userSetGroupByKey: null, + }) + ), + }); expect(selectors.getRunUserSetGroupBy(state)).toEqual(null); }); @@ -510,13 +628,15 @@ describe('runs_selectors', () => { }); it('returns groupBy set by user when it is present', () => { - const state = buildStateFromRunsState( - buildRunsState({ - colorGroupRegexString: 'hello', - initialGroupBy: {key: GroupByKey.RUN}, - userSetGroupByKey: GroupByKey.REGEX, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + colorGroupRegexString: 'hello', + initialGroupBy: {key: GroupByKey.RUN}, + userSetGroupByKey: GroupByKey.REGEX, + }) + ), + }); expect(selectors.getRunGroupBy(state)).toEqual({ key: GroupByKey.REGEX, @@ -525,13 +645,15 @@ describe('runs_selectors', () => { }); it('returns groupBy set by user with regexString overridden', () => { - const state = buildStateFromRunsState( - buildRunsState({ - colorGroupRegexString: '', - initialGroupBy: {key: GroupByKey.REGEX, regexString: 'hello'}, - userSetGroupByKey: GroupByKey.REGEX, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + colorGroupRegexString: '', + initialGroupBy: {key: GroupByKey.REGEX, regexString: 'hello'}, + userSetGroupByKey: GroupByKey.REGEX, + }) + ), + }); expect(selectors.getRunGroupBy(state)).toEqual({ key: GroupByKey.REGEX, @@ -540,12 +662,14 @@ describe('runs_selectors', () => { }); it('returns initial group by if user never has set one', () => { - const state = buildStateFromRunsState( - buildRunsState({ - initialGroupBy: {key: GroupByKey.RUN}, - userSetGroupByKey: null, - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + initialGroupBy: {key: GroupByKey.RUN}, + userSetGroupByKey: null, + }) + ), + }); expect(selectors.getRunGroupBy(state)).toEqual({ key: GroupByKey.RUN, @@ -560,27 +684,31 @@ describe('runs_selectors', () => { }); it('returns regex string when it is group by regex', () => { - const state = buildStateFromRunsState( - buildRunsState({ - colorGroupRegexString: 'foo(\\d+)', - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + colorGroupRegexString: 'foo(\\d+)', + }) + ), + }); expect(selectors.getColorGroupRegexString(state)).toEqual('foo(\\d+)'); }); it('returns default empty string if user never has set one', () => { - const state = buildStateFromRunsState(buildRunsState({})); + const state = buildMockState(); expect(selectors.getColorGroupRegexString(state)).toEqual(''); }); it('returns regex string even if it is not user set groupby', () => { - const state = buildStateFromRunsState( - buildRunsState({ - colorGroupRegexString: 'foo(\\d+)', - }) - ); + const state = buildMockState({ + ...buildStateFromRunsState( + buildRunsState({ + colorGroupRegexString: 'foo(\\d+)', + }) + ), + }); expect(selectors.getColorGroupRegexString(state)).toEqual('foo(\\d+)'); }); @@ -588,8 +716,8 @@ describe('runs_selectors', () => { describe('#getRunsTableHeaders', () => { it('returns the runs table headers', () => { - const state = buildStateFromRunsState( - buildRunsState( + const state = buildMockState({ + runs: buildRunsState( {}, { runsTableHeaders: [ @@ -607,8 +735,8 @@ describe('runs_selectors', () => { }, ], } - ) - ); + ), + }); expect(selectors.getRunsTableHeaders(state)).toEqual([ { type: ColumnHeaderType.RUN, @@ -628,8 +756,8 @@ describe('runs_selectors', () => { describe('#getRunsTableSortingInfo', () => { it('returns the runs data table sorting info', () => { - const state = buildStateFromRunsState( - buildRunsState( + const state = buildMockState({ + runs: buildRunsState( {}, { sortingInfo: { @@ -637,8 +765,8 @@ describe('runs_selectors', () => { order: SortingOrder.ASCENDING, }, } - ) - ); + ), + }); expect(selectors.getRunsTableSortingInfo(state)).toEqual({ name: 'run', order: SortingOrder.ASCENDING, diff --git a/tensorboard/webapp/testing/BUILD b/tensorboard/webapp/testing/BUILD index 36b5fd52f19..7fa69f955fe 100644 --- a/tensorboard/webapp/testing/BUILD +++ b/tensorboard/webapp/testing/BUILD @@ -66,12 +66,15 @@ tf_ts_library( "utils.ts", ], deps = [ + "//tensorboard/plugins/debugger_v2/tf_debugger_v2_plugin/store:types", "//tensorboard/plugins/debugger_v2/tf_debugger_v2_plugin/testing", "//tensorboard/webapp:app_state", "//tensorboard/webapp/alert/store:testing", + "//tensorboard/webapp/alert/store:types", "//tensorboard/webapp/angular:expect_ngrx_store_testing", "//tensorboard/webapp/app_routing/store:testing", "//tensorboard/webapp/app_routing/store:types", + "//tensorboard/webapp/core/store", "//tensorboard/webapp/core/testing", "//tensorboard/webapp/experiments/store:testing", "//tensorboard/webapp/experiments/store:types", @@ -84,9 +87,11 @@ tf_ts_library( "//tensorboard/webapp/notification_center/_redux:testing", "//tensorboard/webapp/notification_center/_redux:types", "//tensorboard/webapp/persistent_settings/_redux:testing", + "//tensorboard/webapp/persistent_settings/_redux:types", "//tensorboard/webapp/runs/store:testing", "//tensorboard/webapp/runs/store:types", "//tensorboard/webapp/settings:testing", + "//tensorboard/webapp/settings/_redux", "//tensorboard/webapp/util:types", "@npm//@ngrx/store", ], diff --git a/tensorboard/webapp/testing/utils.ts b/tensorboard/webapp/testing/utils.ts index dde6d15a307..97beab4ee78 100644 --- a/tensorboard/webapp/testing/utils.ts +++ b/tensorboard/webapp/testing/utils.ts @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ import {provideMockStore} from '@ngrx/store/testing'; -import {createState as createDebuggerState} from '../../plugins/debugger_v2/tf_debugger_v2_plugin/testing'; +import { + createState as buildStateFromDebuggerState, + createDebuggerState, +} from '../../plugins/debugger_v2/tf_debugger_v2_plugin/testing'; import { buildAlertState, buildStateFromAlertState, @@ -24,7 +27,10 @@ import { buildStateFromAppRoutingState, } from '../app_routing/store/testing'; import {State} from '../app_state'; -import {createState as createCoreState} from '../core/testing'; +import { + createState as buildStateFromCoreState, + createCoreState, +} from '../core/testing'; import {EXPERIMENTS_FEATURE_KEY} from '../experiments/store/experiments_types'; import { buildExperimentState, @@ -33,7 +39,7 @@ import { import {FEATURE_FLAG_FEATURE_KEY} from '../feature_flag/store/feature_flag_types'; import { buildFeatureFlagState, - buildState as buildStateFromFeatureFlagsState, + buildState as buildStateFromFeatureFlagState, } from '../feature_flag/store/testing'; import { buildHparamsState, @@ -53,22 +59,40 @@ import { import {RUNS_FEATURE_KEY} from '../runs/store/runs_types'; import {buildRunsState, buildStateFromRunsState} from '../runs/store/testing'; import { - createState as createSettings, + createState as buildStateFromSettingsState, createSettingsState, } from '../settings/testing'; import {HPARAMS_FEATURE_KEY} from '../hparams/_redux/types'; +import {ALERT_FEATURE_KEY} from '../alert/store/alert_types'; +import {PERSISTENT_SETTINGS_FEATURE_KEY} from '../persistent_settings/_redux/persistent_settings_types'; +import {SETTINGS_FEATURE_KEY} from '../settings/_redux/settings_types'; +import {CORE_FEATURE_KEY} from '../core/store/core_types'; +import {DEBUGGER_FEATURE_KEY} from '../../plugins/debugger_v2/tf_debugger_v2_plugin/store/debugger_types'; + +type PartialOverrides = { + [K in keyof State]?: Partial; +}; -export function buildMockState(overrides: Partial = {}): State { +export function buildMockState(overrides: PartialOverrides = {}): State { return { - ...createDebuggerState(), - ...buildFeatureFlagState(overrides[FEATURE_FLAG_FEATURE_KEY]), - ...buildStateFromAlertState(buildAlertState({})), - ...buildStateFromPersistentSettingsState(buildPersistentSettingsState({})), - ...createCoreState(), + ...buildStateFromDebuggerState( + createDebuggerState(overrides[DEBUGGER_FEATURE_KEY] ?? {}) + ), + ...buildStateFromAlertState( + buildAlertState(overrides[ALERT_FEATURE_KEY] ?? {}) + ), + ...buildStateFromPersistentSettingsState( + buildPersistentSettingsState( + overrides[PERSISTENT_SETTINGS_FEATURE_KEY] ?? {} + ) + ), + ...buildStateFromCoreState(createCoreState(overrides[CORE_FEATURE_KEY])), ...appStateFromMetricsState( buildMetricsState(overrides[METRICS_FEATURE_KEY]) ), - ...createSettings(createSettingsState()), + ...buildStateFromSettingsState( + createSettingsState(overrides[SETTINGS_FEATURE_KEY]) + ), ...buildStateFromRunsState( buildRunsState( overrides[RUNS_FEATURE_KEY]?.data, @@ -81,12 +105,14 @@ export function buildMockState(overrides: Partial = {}): State { ...buildStateFromAppRoutingState( buildAppRoutingState(overrides[APP_ROUTING_FEATURE_KEY]) ), - ...buildStateFromFeatureFlagsState(buildFeatureFlagState()), + ...buildStateFromFeatureFlagState( + buildFeatureFlagState(overrides[FEATURE_FLAG_FEATURE_KEY]) + ), ...buildStateFromHparamsState( buildHparamsState(overrides[HPARAMS_FEATURE_KEY]) ), ...buildStateFromNotificationState( - buildNotificationState(overrides[NOTIFICATION_FEATURE_KEY] || {}) + buildNotificationState(overrides[NOTIFICATION_FEATURE_KEY] ?? {}) ), }; } diff --git a/tensorboard/webapp/util/BUILD b/tensorboard/webapp/util/BUILD index c76f154c860..9bbf2de4f37 100644 --- a/tensorboard/webapp/util/BUILD +++ b/tensorboard/webapp/util/BUILD @@ -123,6 +123,7 @@ tf_ts_library( "//tensorboard/webapp/feature_flag/store:testing", "//tensorboard/webapp/runs/store:testing", "//tensorboard/webapp/settings:testing", + "//tensorboard/webapp/testing:utils", "@npm//@ngrx/store", "@npm//@types/jasmine", ], diff --git a/tensorboard/webapp/util/ui_selectors_test.ts b/tensorboard/webapp/util/ui_selectors_test.ts index 6e161b8067d..cd9fffdde8a 100644 --- a/tensorboard/webapp/util/ui_selectors_test.ts +++ b/tensorboard/webapp/util/ui_selectors_test.ts @@ -26,8 +26,8 @@ import { } from '../experiments/store/testing'; import { buildFeatureFlag, - buildFeatureFlagState, buildState as buildStateFromFeatureFlagState, + buildFeatureFlagState, } from '../feature_flag/store/testing'; import { buildRun, @@ -44,9 +44,10 @@ import { import { buildColorPalette, createSettings as buildSettings, + createState as appStateFromSettingsState, createSettingsState as buildSettingsState, - createState as buildStateFromSettingsState, } from '../settings/testing'; +import {buildMockState} from '../testing/utils'; import {ColorPalette} from './colors'; import {getCurrentRouteRunSelection, getRunColorMap} from './ui_selectors'; @@ -63,7 +64,7 @@ describe('ui_selectors test', () => { describe('#getCurrentRouteRunSelection', () => { it('returns selection map of current experiments', () => { - const state = { + const state = buildMockState({ ...buildStateFromAppRoutingState( // The route only contains experiments 123 and 234 buildAppRoutingState({ @@ -100,7 +101,7 @@ describe('ui_selectors test', () => { }, }) ), - }; + }); // Runs form experiment 345 are not included in the final result. expect(getCurrentRouteRunSelection(state)).toEqual( @@ -112,7 +113,7 @@ describe('ui_selectors test', () => { }); it('returns null if current route does not have experimentIds', () => { - const state = { + const state = buildMockState({ ...buildStateFromAppRoutingState( buildAppRoutingState({ activeRoute: buildRoute({ @@ -144,14 +145,14 @@ describe('ui_selectors test', () => { }, }) ), - }; + }); expect(getCurrentRouteRunSelection(state)).toBeNull(); }); describe('regex filter', () => { it('filters runs based on regex and run name', () => { - const state = { + const state = buildMockState({ ...buildStateFromAppRoutingState( buildAppRoutingState({ activeRoute: buildRoute({ @@ -194,7 +195,7 @@ describe('ui_selectors test', () => { }, }) ), - }; + }); expect(getCurrentRouteRunSelection(state)).toEqual( new Map([ @@ -206,7 +207,7 @@ describe('ui_selectors test', () => { }); it('filters run name and alias in compare mode', () => { - const state = { + const state = buildMockState({ ...buildStateFromAppRoutingState( buildAppRoutingState({ activeRoute: buildRoute({ @@ -260,7 +261,7 @@ describe('ui_selectors test', () => { }, }) ), - }; + }); expect(getCurrentRouteRunSelection(state)).toEqual( new Map([ @@ -283,7 +284,7 @@ describe('ui_selectors test', () => { }); it('does not violently throw when an experiment metadata is null', () => { - const state = { + const state = buildMockState({ ...buildStateFromAppRoutingState( buildAppRoutingState({ activeRoute: buildRoute({ @@ -331,7 +332,7 @@ describe('ui_selectors test', () => { }, }) ), - }; + }); expect(getCurrentRouteRunSelection(state)).toEqual( new Map([ @@ -352,14 +353,14 @@ describe('ui_selectors test', () => { colorPalette: ColorPalette = buildColorPalette(), useDarkMode: boolean = false ): State { - return { + return buildMockState({ ...buildStateFromRunsState( buildRunsState({ defaultRunColorIdForGroupBy, runColorOverrideForGroupBy, }) ), - ...buildStateFromSettingsState( + ...appStateFromSettingsState( buildSettingsState({ settings: buildSettings({colorPalette}), }) @@ -371,7 +372,7 @@ describe('ui_selectors test', () => { }), }) ), - }; + }); } it('returns color from color id from the default PALETTE', () => {