Skip to content

Commit

Permalink
Hparams: refactor arguments provided to getDashboardRuns (tensorflo…
Browse files Browse the repository at this point in the history
…w#6555)

## Motivation for features / changes
As per a discussion around tensorflow#6544 we determined that the dashboard should
not be requesting runs per experiment but should instead be getting all
runs for all experiments shown.
This greatly simplifies a lot of the logic used to retrieve runs and run
related data.
  • Loading branch information
rileyajones authored and yatbear committed Aug 25, 2023
1 parent f229893 commit 3565328
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 232 deletions.
5 changes: 5 additions & 0 deletions tensorboard/webapp/experiments/store/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ tf_ts_library(
],
deps = [
":types",
"//tensorboard/webapp/app_routing/store",
"//tensorboard/webapp/experiments:types",
"//tensorboard/webapp/types",
"@npm//@ngrx/store",
Expand Down Expand Up @@ -65,6 +66,10 @@ tf_ts_library(
":selectors",
":testing",
":types",
"//tensorboard/webapp:app_state",
"//tensorboard/webapp/app_routing:types",
"//tensorboard/webapp/app_routing/store:testing",
"//tensorboard/webapp/testing:utils",
"//tensorboard/webapp/types",
"@npm//@types/jasmine",
],
Expand Down
31 changes: 17 additions & 14 deletions tensorboard/webapp/experiments/store/experiments_selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
ExperimentsState,
EXPERIMENTS_FEATURE_KEY,
} from './experiments_types';
import {getExperimentIdsFromRoute} from '../../app_routing/store/app_routing_selectors';

const getExperimentsState = createFeatureSelector<ExperimentsState>(
EXPERIMENTS_FEATURE_KEY
Expand Down Expand Up @@ -46,18 +47,20 @@ export const getExperiment = createSelector(
);

/**
* Returns Observable that emits an object mapping the provided
* experiment ids to experiment names.
* Returns the names of all experiments present on the current dashboard.
*/
export const getExperimentNames = (experimentIds: string[]) =>
createSelector(
getDataState,
(state: ExperimentsDataState): Record<string, string> =>
experimentIds
.map((experimentId) => state.experimentMap[experimentId])
.filter(Boolean)
.reduce((map, experiment) => {
map[experiment.id] = experiment.name;
return map;
}, {} as Record<string, string>)
);
export const getDashboardExperimentNames = createSelector(
getDataState,
getExperimentIdsFromRoute,
(
state: ExperimentsDataState,
experimentIds: string[] | null
): Record<string, string> =>
(experimentIds ?? [])
.map((experimentId) => state.experimentMap[experimentId])
.filter(Boolean)
.reduce((map, experiment) => {
map[experiment.id] = experiment.name;
return map;
}, {} as Record<string, string>)
);
48 changes: 37 additions & 11 deletions tensorboard/webapp/experiments/store/experiments_selectors_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import * as selectors from './experiments_selectors';
import {State} from './experiments_types';
import {State} from '../../app_state';
import {buildExperiment, buildStateFromExperimentsState} from './testing';
import {buildMockState} from '../../testing/utils';
import {
buildAppRoutingState,
buildStateFromAppRoutingState,
} from '../../app_routing/store/testing';
import {RouteKind} from '../../app_routing/types';

describe('experiments selectors', () => {
describe('#getExperiment', () => {
Expand Down Expand Up @@ -49,32 +55,52 @@ describe('experiments selectors', () => {
});
});

describe('#getExperimentNames', () => {
describe('#getDashboardExperimentNames', () => {
let state: State;

beforeEach(() => {
const foo = buildExperiment({id: 'foo', name: 'foo name'});
const bar = buildExperiment({id: 'bar', name: 'bar name'});

state = buildStateFromExperimentsState({
data: {
experimentMap: {foo, bar},
},
state = buildMockState({
...buildStateFromAppRoutingState(
buildAppRoutingState({
activeRoute: {
routeKind: RouteKind.COMPARE_EXPERIMENT,
params: {
experimentIds: 'exp1:foo,exp2:bar,exp3:baz',
},
},
})
),
...buildStateFromExperimentsState({
data: {
experimentMap: {foo, bar},
},
}),
});
});

it('translates experiment ids to experiment names', () => {
expect(
selectors.getExperimentNames(['foo', 'bar', 'baz'])(state)
).toEqual({
expect(selectors.getDashboardExperimentNames(state)).toEqual({
foo: 'foo name',
bar: 'bar name',
});
});

it('returns an empty object when no experiments are provided', () => {
expect(selectors.getExperimentNames([])(state)).toEqual({});
expect(selectors.getExperimentNames(['abc', '123'])(state)).toEqual({});
state = {
...state,
...buildStateFromAppRoutingState(
buildAppRoutingState({
activeRoute: {
routeKind: RouteKind.EXPERIMENTS,
params: {},
},
})
),
};
expect(selectors.getDashboardExperimentNames(state)).toEqual({});
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ import {
HeaderToggleInfo,
XAxisType,
} from '../../types';
import {getFilteredRenderableRunsIdsFromRoute} from '../main_view/common_selectors';
import {getFilteredRenderableRunsIds} from '../main_view/common_selectors';
import {CardRenderer} from '../metrics_view_types';
import {getTagDisplayName} from '../utils';
import {DataDownloadDialogContainer} from './data_download_dialog_container';
Expand Down Expand Up @@ -499,7 +499,7 @@ export class ScalarCardContainer implements CardRenderer, OnInit, OnDestroy {
combineLatestWith(
this.store.select(getCurrentRouteRunSelection),
this.store.select(getEnableHparamsInTimeSeries),
this.store.select(getFilteredRenderableRunsIdsFromRoute),
this.store.select(getFilteredRenderableRunsIds),
this.store.select(getRunColorMap),
this.store.select(getMetricsScalarSmoothing)
),
Expand Down
40 changes: 20 additions & 20 deletions tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ describe('scalar card', () => {
new Map([['run1', true]])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1'])
);
store.overrideSelector(selectors.getMetricsXAxisType, XAxisType.STEP);
Expand Down Expand Up @@ -764,7 +764,7 @@ describe('scalar card', () => {
new Map([['run1', true]])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1'])
);
});
Expand Down Expand Up @@ -2574,7 +2574,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2'])
);
store.overrideSelector(getCardStateMap, {
Expand Down Expand Up @@ -2996,7 +2996,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2'])
);

Expand Down Expand Up @@ -3066,7 +3066,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2'])
);

Expand Down Expand Up @@ -3145,7 +3145,7 @@ describe('scalar card', () => {
new Map([['run1', true]])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1'])
);

Expand Down Expand Up @@ -3211,7 +3211,7 @@ describe('scalar card', () => {
new Map([['run1', true]])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1'])
);

Expand Down Expand Up @@ -3278,7 +3278,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2'])
);

Expand Down Expand Up @@ -3330,7 +3330,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2'])
);

Expand Down Expand Up @@ -3382,7 +3382,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2'])
);

Expand Down Expand Up @@ -3432,7 +3432,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2'])
);
store.overrideSelector(getMetricsLinkedTimeSelection, {
Expand Down Expand Up @@ -3481,7 +3481,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2'])
);
store.overrideSelector(selectors.getExperimentIdToExperimentAliasMap, {
Expand Down Expand Up @@ -3539,7 +3539,7 @@ describe('scalar card', () => {
new Map([['run1', true]])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1'])
);
store.overrideSelector(getMetricsLinkedTimeSelection, {
Expand Down Expand Up @@ -3586,7 +3586,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2', 'run3'])
);

Expand Down Expand Up @@ -3635,7 +3635,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2', 'run3'])
);

Expand Down Expand Up @@ -3692,7 +3692,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2', 'run3', 'run4', 'run5', 'run6', 'run7'])
);

Expand Down Expand Up @@ -3754,7 +3754,7 @@ describe('scalar card', () => {
);

store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2', 'run3', 'run4', 'run5', 'run6', 'run7'])
);

Expand Down Expand Up @@ -3840,7 +3840,7 @@ describe('scalar card', () => {
);

store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1'])
);

Expand Down Expand Up @@ -3952,7 +3952,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2', 'run3', 'run4'])
);

Expand Down Expand Up @@ -3981,7 +3981,7 @@ describe('scalar card', () => {
])
);
store.overrideSelector(
commonSelectors.getFilteredRenderableRunsIdsFromRoute,
commonSelectors.getFilteredRenderableRunsIds,
new Set(['run1', 'run2'])
);
const fixture = createComponent('card1');
Expand Down
Loading

0 comments on commit 3565328

Please sign in to comment.