Skip to content

Commit

Permalink
add new hparams selectors to get default filters
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyajones committed Aug 21, 2023
1 parent f5b0bfd commit d5fd188
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 49 deletions.
2 changes: 2 additions & 0 deletions tensorboard/webapp/hparams/_redux/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ tf_ts_library(
"utils.ts",
],
deps = [
":types",
"//tensorboard/webapp/hparams:types",
],
)
Expand Down Expand Up @@ -47,6 +48,7 @@ tf_ts_library(
"hparams_actions.ts",
],
deps = [
":types",
"//tensorboard/webapp/hparams:types",
"@npm//@ngrx/store",
],
Expand Down
11 changes: 11 additions & 0 deletions tensorboard/webapp/hparams/_redux/hparams_actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
HparamAndMetricSpec,
SessionGroup,
} from '../types';
import {HparamFilter, MetricFilter} from './types';

export const hparamsDiscreteHparamFilterChanged = createAction(
'[Hparams] Hparams Discrete Hparam Filter Changed',
Expand Down Expand Up @@ -62,3 +63,13 @@ export const hparamsFetchSessionGroupsSucceeded = createAction(
sessionGroups: SessionGroup[];
}>()
);

export const dashboardHparamFilterAdded = createAction(
'[Hparams] Dashboard Hparam Filter Added',
props<{name: string; filter: HparamFilter}>()
);

export const dashboardMetricFilterAdded = createAction(
'[Hparams] Dashboard Metric Filter Added',
props<{name: string; filter: MetricFilter}>()
);
28 changes: 28 additions & 0 deletions tensorboard/webapp/hparams/_redux/hparams_reducers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ const initialState: HparamsState = {
metrics: [],
},
dashboardSessionGroups: [],
dashboardFilters: {
hparams: new Map(),
metrics: new Map(),
},
};

const reducer: ActionReducer<HparamsState, Action> = createReducer(
Expand Down Expand Up @@ -369,6 +373,30 @@ const reducer: ActionReducer<HparamsState, Action> = createReducer(
dashboardSpecs: nextDashboardSpecs,
dashboardSessionGroups: nextDashboardSessionGroups,
};
}),
on(actions.dashboardHparamFilterAdded, (state, action) => {
return {
...state,
dashboardFilters: {
...state.dashboardFilters,
hparams: {
...state.dashboardFilters.hparams,
[action.name]: action.filter,
},
},
};
}),
on(actions.dashboardMetricFilterAdded, (state, action) => {
return {
...state,
dashboardFilters: {
...state.dashboardFilters,
metrics: {
...state.dashboardFilters.metrics,
[action.name]: action.filter,
},
},
};
})
);

Expand Down
44 changes: 43 additions & 1 deletion tensorboard/webapp/hparams/_redux/hparams_selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ import {
RunToHparamsAndMetrics,
} from '../types';
import {combineHparamAndMetricSpecs} from './hparams_selectors_utils';
import {HparamsState, HPARAMS_FEATURE_KEY} from './types';
import {
HparamsState,
HPARAMS_FEATURE_KEY,
HparamsAndMetricsFilters,
} from './types';
import {
combineDefaultHparamFilters,
combineDefaultMetricFilters,
getIdFromExperimentIds,
hparamSpecToDefaultFilter,
metricSpecToDefaultFilter,
} from './utils';

const getHparamsState =
Expand Down Expand Up @@ -232,3 +238,39 @@ export const getDashboardRunsToHparamsAndMetrics = createSelector(
return runToHparamsAndMetrics;
}
);

export const getDashboardDefaultFilters = createSelector(
getDashboardHparamsAndMetricsSpecs,
(specs): HparamsAndMetricsFilters => {
const hparams = new Map(
specs.hparams.map((hparamSpec) => {
return [hparamSpec.name, hparamSpecToDefaultFilter(hparamSpec)];
})
);

const metrics = new Map(
specs.metrics.map((metricSpec) => {
return [metricSpec.name.tag, metricSpecToDefaultFilter(metricSpec)];
})
);

return {
hparams,
metrics,
};
}
);

export const getDashboardHparamFilterMap = createSelector(
getHparamsState,
(state) => {
return state.dashboardFilters.hparams;
}
);

export const getDashboardMetricsFilterMap = createSelector(
getHparamsState,
(state) => {
return state.dashboardFilters.metrics;
}
);
18 changes: 12 additions & 6 deletions tensorboard/webapp/hparams/_redux/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,22 @@ import {
SessionGroup,
} from '../_types';

export type HparamFilter = DiscreteFilter | IntervalFilter;
export type MetricFilter = IntervalFilter;

export interface HparamsAndMetricsFilters {
hparams: Map<string, HparamFilter>;
metrics: Map<string, MetricFilter>;
}

export interface HparamsMetricsAndFilters {
hparam: {
specs: HparamSpec[];
defaultFilters: Map<string, DiscreteFilter | IntervalFilter>;
defaultFilters: Map<string, HparamFilter>;
};
metric: {
specs: MetricSpec[];
defaultFilters: Map<string, IntervalFilter>;
defaultFilters: Map<string, MetricFilter>;
};
}

Expand All @@ -47,6 +55,7 @@ export interface HparamsState {
specs: ExperimentToHparams;
dashboardSpecs: HparamAndMetricSpec;
dashboardSessionGroups: SessionGroup[];
dashboardFilters: HparamsAndMetricsFilters;
/**
* RATIONALE: we do not use the NamespaceContextedState because of the following reasons.
* - RunsTable which uses the state renders both on the dashboard view and the
Expand All @@ -61,10 +70,7 @@ export interface HparamsState {
* separate for the list and the dashboard views, but them shared is not too bad.
*/
filters: {
[id: string]: {
hparams: Map<string, DiscreteFilter | IntervalFilter>;
metrics: Map<string, IntervalFilter>;
};
[id: string]: HparamsAndMetricsFilters;
};
}

Expand Down
34 changes: 34 additions & 0 deletions tensorboard/webapp/hparams/_redux/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ import {
DiscreteHparamValue,
DiscreteHparamValues,
DomainType,
HparamSpec,
IntervalFilter,
MetricSpec,
} from '../types';
import {HparamFilter, MetricFilter} from './types';

export function getIdFromExperimentIds(experimentIds: string[]): string {
return JSON.stringify([...experimentIds].sort());
Expand Down Expand Up @@ -147,3 +150,34 @@ export function combineDefaultMetricFilters(

return intervalMetrics;
}

export function hparamSpecToDefaultFilter(spec: HparamSpec): HparamFilter {
if (spec.domain.type === DomainType.DISCRETE) {
return {
type: DomainType.DISCRETE,
includeUndefined: true,
possibleValues: spec.domain.values,
filterValues: spec.domain.values,
};
}

return {
type: DomainType.INTERVAL,
includeUndefined: true,
minValue: spec.domain.minValue,
maxValue: spec.domain.maxValue,
filterLowerValue: spec.domain.minValue,
filterUpperValue: spec.domain.maxValue,
};
}

export function metricSpecToDefaultFilter(spec: MetricSpec): MetricFilter {
return {
type: DomainType.INTERVAL,
includeUndefined: true,
minValue: -Infinity,
maxValue: Infinity,
filterLowerValue: -Infinity,
filterUpperValue: Infinity,
};
}
43 changes: 20 additions & 23 deletions tensorboard/webapp/metrics/views/main_view/common_selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ import {
} from '../../../selectors';
import {DeepReadonly} from '../../../util/types';
import {
getHparamFilterMapFromExperimentIds,
getMetricFilterMapFromExperimentIds,
getDashboardHparamsAndMetricsSpecs,
getDashboardHparamFilterMap,
getDashboardMetricsFilterMap,
getDashboardDefaultFilters,
} from '../../../hparams/_redux/hparams_selectors';
import {
DiscreteFilter,
Expand Down Expand Up @@ -185,6 +186,20 @@ const utils = {
},
};

export const getCurrentColumnFilters = createSelector(
getDashboardDefaultFilters,
getDashboardHparamFilterMap,
getDashboardMetricsFilterMap,
({hparams, metrics}, hparamFilters, metricFilters) => {
return new Map([
...hparams,
...metrics,
...hparamFilters,
...metricFilters,
]);
}
);

const getRenderableRuns = memoize((experimentIds: string[]) => {
return createSelector(
getDashboardRuns(experimentIds),
Expand Down Expand Up @@ -226,8 +241,8 @@ const getFilteredRenderableRuns = memoize((experimentIds: string[]) => {
return createSelector(
getRunSelectorRegexFilter,
getRenderableRuns(experimentIds),
getHparamFilterMapFromExperimentIds(experimentIds),
getMetricFilterMapFromExperimentIds(experimentIds),
getDashboardHparamFilterMap,
getDashboardMetricsFilterMap,
getRouteKind,
(regexFilter, runItems, hparamFilters, metricFilters, routeKind) => {
const regexFilteredItems = utils.filterRunItemsByRegex(
Expand All @@ -245,28 +260,10 @@ const getFilteredRenderableRuns = memoize((experimentIds: string[]) => {
);
});

const getCurrentColumnFilters = memoize((experimentIds: string[]) => {
return createSelector(
getHparamFilterMapFromExperimentIds(experimentIds, true),
getMetricFilterMapFromExperimentIds(experimentIds, true),
(hparamFilters, metricFilters) => {
return new Map([...hparamFilters, ...metricFilters]);
}
);
});

export const getCurrentColumnFiltersFromRoute = createSelector(
(state) => state,
getExperimentIdsFromRoute,
(state, experimentIds) => {
return getCurrentColumnFilters(experimentIds || [])(state);
}
);

export const getFilteredRenderableRunsFromRoute = createSelector(
(state) => state,
getExperimentIdsFromRoute,
(state, experimentIds) => {
(state, experimentIds): RunTableItem[] => {
return getFilteredRenderableRuns(experimentIds || [])(state);
}
);
Expand Down
27 changes: 8 additions & 19 deletions tensorboard/webapp/runs/views/runs_table/runs_table_container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ import {
} from './runs_table_component';
import {RunsTableColumn, RunTableItem} from './types';
import {
getCurrentColumnFiltersFromRoute,
getCurrentColumnFilters,
getFilteredRenderableRunsFromRoute,
getPotentialHparamColumns,
} from '../../../metrics/views/main_view/common_selectors';
Expand Down Expand Up @@ -375,7 +375,7 @@ export class RunsTableContainer implements OnInit, OnDestroy {
})
);

columnFilters$ = this.store.select(getCurrentColumnFiltersFromRoute);
columnFilters$ = this.store.select(getCurrentColumnFilters);

allRunsTableData$ = this.store
.select(getFilteredRenderableRunsFromRoute)
Expand Down Expand Up @@ -835,23 +835,12 @@ export class RunsTableContainer implements OnInit, OnDestroy {
}

addHparamFilter(event: FilterAddedEvent) {
switch (event.value.type) {
case DomainType.INTERVAL:
this.onHparamIntervalFilterChanged({
name: event.header.name,
includeUndefined: event.value.includeUndefined,
filterLowerValue: (event.value as IntervalFilter).filterLowerValue,
filterUpperValue: (event.value as IntervalFilter).filterUpperValue,
});
break;
case DomainType.DISCRETE:
this.onHparamDiscreteFilterChanged({
hparamName: event.header.name,
includeUndefined: event.value.includeUndefined,
filterValues: (event.value as DiscreteFilter).filterValues,
});
break;
}
this.store.dispatch(
hparamsActions.dashboardHparamFilterAdded({
name: event.header.name,
filter: event.value,
})
);
}
}

Expand Down

0 comments on commit d5fd188

Please sign in to comment.