diff --git a/tensorboard/webapp/app_routing/types.ts b/tensorboard/webapp/app_routing/types.ts index 8d84687737..20efe88067 100644 --- a/tensorboard/webapp/app_routing/types.ts +++ b/tensorboard/webapp/app_routing/types.ts @@ -38,6 +38,21 @@ export enum RouteKind { NOT_SET, } +export type DashboardRoute = + | RouteKind.EXPERIMENT + | RouteKind.COMPARE_EXPERIMENT + | RouteKind.CARD; + +export function isDashboardRoute( + routeKind: RouteKind +): routeKind is DashboardRoute { + return ( + routeKind === RouteKind.EXPERIMENT || + routeKind === RouteKind.COMPARE_EXPERIMENT || + routeKind === RouteKind.CARD + ); +} + export const DEFAULT_EXPERIMENT_ID = 'defaultExperimentId'; /** diff --git a/tensorboard/webapp/runs/effects/runs_effects.ts b/tensorboard/webapp/runs/effects/runs_effects.ts index b4c3042635..147f61f72c 100644 --- a/tensorboard/webapp/runs/effects/runs_effects.ts +++ b/tensorboard/webapp/runs/effects/runs_effects.ts @@ -15,24 +15,27 @@ limitations under the License. import {Injectable} from '@angular/core'; import {Actions, createEffect, ofType} from '@ngrx/effects'; import {Store} from '@ngrx/store'; -import {forkJoin, merge, Observable, of, throwError} from 'rxjs'; +import {forkJoin, merge, Observable, of, throwError, zip} from 'rxjs'; import { catchError, distinctUntilChanged, filter, map, mergeMap, + switchMap, take, tap, withLatestFrom, } from 'rxjs/operators'; import {areSameRouteKindAndExperiments} from '../../app_routing'; import {navigated} from '../../app_routing/actions'; -import {RouteKind} from '../../app_routing/types'; +import {RouteKind, isDashboardRoute} from '../../app_routing/types'; import {State} from '../../app_state'; import * as coreActions from '../../core/actions'; import { getActiveRoute, + getRouteKind, + getEnableHparamsInTimeSeries, getExperimentIdsFromRoute, getRuns, getRunsLoadState, @@ -264,15 +267,33 @@ export class RunsEffects { ); } + private maybeFetchHparamsMetadata( + experimentId: string + ): Observable { + return this.store.select(getEnableHparamsInTimeSeries).pipe( + withLatestFrom(this.store.select(getRouteKind)), + switchMap(([hparamsInTimeSeries, routeKind]) => { + if (hparamsInTimeSeries && isDashboardRoute(routeKind)) { + return of({ + hparamSpecs: [], + metricSpecs: [], + runToHparamsAndMetrics: {}, + }); + } + return this.runsDataSource.fetchHparamsMetadata(experimentId); + }) + ); + } + private fetchRunsForExperiment(experimentId: string): Observable<{ fromRemote: true; experimentId: string; runs: Run[]; metadata: HparamsAndMetadata; }> { - return forkJoin([ + return zip([ this.runsDataSource.fetchRuns(experimentId), - this.runsDataSource.fetchHparamsMetadata(experimentId), + this.maybeFetchHparamsMetadata(experimentId), ]).pipe( map(([runs, metadata]) => { return {fromRemote: true, experimentId, runs, metadata}; diff --git a/tensorboard/webapp/runs/effects/runs_effects_test.ts b/tensorboard/webapp/runs/effects/runs_effects_test.ts index f27c830ab1..de38326dbb 100644 --- a/tensorboard/webapp/runs/effects/runs_effects_test.ts +++ b/tensorboard/webapp/runs/effects/runs_effects_test.ts @@ -28,7 +28,9 @@ import {State} from '../../app_state'; import * as coreActions from '../../core/actions'; import { getActiveRoute, + getEnableHparamsInTimeSeries, getExperimentIdsFromRoute, + getRouteKind, getRuns, getRunsLoadState, } from '../../selectors'; @@ -62,6 +64,7 @@ describe('runs_effects', () => { let dispatchSpy: jasmine.Spy; let actualActions: Action[]; let selectSpy: jasmine.Spy; + let fetchHparamsSpy: jasmine.Spy; function flushFetchRuns(requestIndex: number, runs: Run[]) { expect(fetchRunsSubjects.length).toBeGreaterThan(requestIndex); @@ -113,7 +116,8 @@ describe('runs_effects', () => { }); fetchHparamsMetadataSubjects = []; - spyOn(runsDataSource, 'fetchHparamsMetadata').and.callFake(() => { + fetchHparamsSpy = spyOn(runsDataSource, 'fetchHparamsMetadata'); + fetchHparamsSpy.and.callFake(() => { const subject = new ReplaySubject(1); fetchHparamsMetadataSubjects.push(subject); return subject; @@ -231,6 +235,15 @@ describe('runs_effects', () => { }); }); + it('does not fetch hparam data when enableHparamsInTimeSeries is true when on a dashboard route', () => { + store.overrideSelector(getEnableHparamsInTimeSeries, true); + store.overrideSelector(getRouteKind, RouteKind.EXPERIMENT); + store.refreshState(); + + action.next(actions.runTableShown({experimentIds: ['a']})); + expect(fetchHparamsSpy).not.toHaveBeenCalled(); + }); + it('fires FAILED action when failed to fetch runs', () => { action.next(actions.runTableShown({experimentIds: ['a']})); const expectedExperimentId = 'a';