diff --git a/tensorboard/webapp/hparams/BUILD b/tensorboard/webapp/hparams/BUILD index 8a38a39a921..e93ca9fd9f3 100644 --- a/tensorboard/webapp/hparams/BUILD +++ b/tensorboard/webapp/hparams/BUILD @@ -22,6 +22,7 @@ tf_ts_library( ], deps = [ "//tensorboard/webapp/runs/data_source", + "//tensorboard/webapp/widgets/data_table:types", ], ) diff --git a/tensorboard/webapp/hparams/_types.ts b/tensorboard/webapp/hparams/_types.ts index d018b5c5859..31184729304 100644 --- a/tensorboard/webapp/hparams/_types.ts +++ b/tensorboard/webapp/hparams/_types.ts @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ import { - DiscreteHparamValues, - DomainType, HparamSpec, MetricSpec, } from '../runs/data_source/runs_data_source_types'; +export {DiscreteFilter, IntervalFilter} from '../widgets/data_table/types'; + export { DatasetType, DiscreteHparamValue, @@ -33,21 +33,3 @@ export interface HparamAndMetricSpec { hparams: HparamSpec[]; metrics: MetricSpec[]; } - -export interface DiscreteFilter { - type: DomainType.DISCRETE; - includeUndefined: boolean; - possibleValues: DiscreteHparamValues; - // Subset of `possibleValues` - filterValues: DiscreteHparamValues; -} - -export interface IntervalFilter { - type: DomainType.INTERVAL; - includeUndefined: boolean; - minValue: number; - maxValue: number; - // Filter values have to be in between min and max values (inclusive). - filterLowerValue: number; - filterUpperValue: number; -} diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index 526c99fbacf..66d783085e6 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -245,6 +245,24 @@ const getFilteredRenderableRuns = memoize((experimentIds: string[]) => { ); }); +const getCurrentColumnFilters = memoize((experimentIds: string[]) => { + return createSelector( + getHparamFilterMapFromExperimentIds(experimentIds), + getMetricFilterMapFromExperimentIds(experimentIds), + (hparamFilters, metricFilters) => { + return new Map([...hparamFilters, ...metricFilters]); + } + ); +}); + +export const getCurrentColumnFiltersFromRoute = createSelector( + (state) => state, + getExperimentIdsFromRoute, + (state, experimentIds) => { + return getCurrentColumnFilters(experimentIds || [])(state); + } +); + export const getFilteredRenderableRunsFromRoute = createSelector( (state) => state, getExperimentIdsFromRoute, @@ -272,6 +290,9 @@ export const getPotentialHparamColumns = createSelector( experimentIds, }); + const hparamFilters = + getHparamFilterMapFromExperimentIds(experimentIds)(state); + return hparams.map((spec) => ({ type: ColumnHeaderType.HPARAM, name: spec.name, @@ -282,6 +303,7 @@ export const getPotentialHparamColumns = createSelector( removable: true, sortable: true, movable: true, + filterable: true, })); } ); diff --git a/tensorboard/webapp/runs/data_source/BUILD b/tensorboard/webapp/runs/data_source/BUILD index b35f4e47051..23162e3d43f 100644 --- a/tensorboard/webapp/runs/data_source/BUILD +++ b/tensorboard/webapp/runs/data_source/BUILD @@ -13,6 +13,7 @@ tf_ng_module( ], deps = [ ":backend_types", + "//tensorboard/webapp/widgets/data_table:types", "//tensorboard/webapp/webapp_data_source:http_client", "@npm//@angular/core", "@npm//rxjs", @@ -24,7 +25,7 @@ tf_ts_library( srcs = [ "runs_backend_types.ts", ], - visibility = ["//visibility:private"], + visibility = ["//tensorboard/webapp/runs/data_source:__subpackages__", "//tensorboard/webapp/widgets/data_table:__subpackages__",], ) tf_ng_module( diff --git a/tensorboard/webapp/runs/data_source/runs_data_source.ts b/tensorboard/webapp/runs/data_source/runs_data_source.ts index 729a3de3367..4ab408b45c3 100644 --- a/tensorboard/webapp/runs/data_source/runs_data_source.ts +++ b/tensorboard/webapp/runs/data_source/runs_data_source.ts @@ -20,9 +20,9 @@ import { TBHttpClient, } from '../../webapp_data_source/tb_http_client'; import * as backendTypes from './runs_backend_types'; +import {DomainType} from '../../widgets/data_table/types'; import { Domain, - DomainType, HparamsAndMetadata, HparamSpec, HparamValue, diff --git a/tensorboard/webapp/runs/data_source/runs_data_source_types.ts b/tensorboard/webapp/runs/data_source/runs_data_source_types.ts index 1865c8cf04c..2daac58d4ae 100644 --- a/tensorboard/webapp/runs/data_source/runs_data_source_types.ts +++ b/tensorboard/webapp/runs/data_source/runs_data_source_types.ts @@ -16,6 +16,9 @@ import {Injectable} from '@angular/core'; import {Observable} from 'rxjs'; import * as backendTypes from './runs_backend_types'; +import {DomainType} from '../../widgets/data_table/types'; +export {DomainType} from '../../widgets/data_table/types'; + export { BackendHparamsValueType as HparamsValueType, DatasetType, @@ -41,11 +44,6 @@ export interface RunToHparamsAndMetrics { }; } -export enum DomainType { - DISCRETE, - INTERVAL, -} - interface IntervalDomain { type: DomainType.INTERVAL; minValue: number; diff --git a/tensorboard/webapp/runs/store/runs_reducers.ts b/tensorboard/webapp/runs/store/runs_reducers.ts index 1d144cb745f..16056953875 100644 --- a/tensorboard/webapp/runs/store/runs_reducers.ts +++ b/tensorboard/webapp/runs/store/runs_reducers.ts @@ -330,6 +330,7 @@ const {initialState: uiInitialState, reducers: uiNamespaceContextedReducers} = enabled: true, sortable: true, movable: true, + filterable: false, }, ], sortingInfo: { diff --git a/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html b/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html index ea15a8261f7..b66b5a92c36 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html +++ b/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html @@ -28,10 +28,12 @@ [sortingInfo]="sortingInfo" [columnCustomizationEnabled]="true" [selectableColumns]="selectableColumns" + [columnFilters]="columnFilters" (sortDataBy)="sortDataBy.emit($event)" (orderColumns)="orderColumns.emit($event)" (addColumn)="addColumn.emit($event)" (removeColumn)="removeColumn.emit($event)" + (addFilter)="addFilter.emit($event)" > diff --git a/tensorboard/webapp/runs/views/runs_table/runs_data_table.ts b/tensorboard/webapp/runs/views/runs_table/runs_data_table.ts index 1dbddad69c5..1fa1ddb258a 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_data_table.ts +++ b/tensorboard/webapp/runs/views/runs_table/runs_data_table.ts @@ -24,8 +24,10 @@ import { TableData, SortingInfo, ColumnHeaderType, + FilterAddedEvent, + DiscreteFilter, + IntervalFilter, } from '../../../widgets/data_table/types'; - @Component({ selector: 'runs-data-table', templateUrl: 'runs_data_table.ng.html', @@ -40,6 +42,7 @@ export class RunsDataTable { @Input() regexFilter!: string; @Input() isFullScreen!: boolean; @Input() selectableColumns!: ColumnHeader[]; + @Input() columnFilters!: Map; ColumnHeaderType = ColumnHeaderType; @@ -59,6 +62,7 @@ export class RunsDataTable { }>(); @Output() removeColumn = new EventEmitter(); @Output() onSelectionDblClick = new EventEmitter(); + @Output() addFilter = new EventEmitter(); // These columns must be stored and reused to stop needless re-rendering of // the content and headers in these columns. This has been known to cause diff --git a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts index 5af7981c05d..9f9c6c3941c 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts +++ b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts @@ -68,7 +68,7 @@ import {matchRunToRegex} from '../../../util/matcher'; import {getEnableHparamsInTimeSeries} from '../../../feature_flag/store/feature_flag_selectors'; import { ColumnHeader, - ColumnHeaderType, + FilterAddedEvent, SortingInfo, SortingOrder, TableData, @@ -96,10 +96,10 @@ import { } from './runs_table_component'; import {RunsTableColumn, RunTableItem} from './types'; import { + getCurrentColumnFiltersFromRoute, getFilteredRenderableRunsFromRoute, getPotentialHparamColumns, } from '../../../metrics/views/main_view/common_selectors'; -import {RunToHParamValues} from '../../data_source/runs_data_source_types'; import {runsTableFullScreenToggled} from '../../../core/actions'; const getRunsLoading = createSelector< @@ -271,6 +271,7 @@ function matchFilter( [headers]="runsColumns$ | async" [data]="sortedRunsTableData$ | async" [selectableColumns]="selectableColumns$ | async" + [columnFilters]="columnFilters$ | async" [sortingInfo]="sortingInfo$ | async" [experimentIds]="experimentIds" [regexFilter]="regexFilter$ | async" @@ -285,6 +286,7 @@ function matchFilter( (toggleFullScreen)="toggleFullScreen()" (addColumn)="addColumn($event)" (removeColumn)="removeColumn($event)" + (addFilter)="addHparamFilter($event)" > `, host: { @@ -367,6 +369,8 @@ export class RunsTableContainer implements OnInit, OnDestroy { }) ); + columnFilters$ = this.store.select(getCurrentColumnFiltersFromRoute); + allRunsTableData$ = this.store .select(getFilteredRenderableRunsFromRoute) .pipe( @@ -823,6 +827,26 @@ export class RunsTableContainer implements OnInit, OnDestroy { useDataTable() { return this.hparamsEnabled.value && !this.forceLegacyTable; } + + addHparamFilter(event: FilterAddedEvent) { + switch (event.value.type) { + case DomainType.INTERVAL: + this.onHparamIntervalFilterChanged({ + name: event.header.name, + includeUndefined: event.value.includeUndefined, + filterLowerValue: (event.value as IntervalFilter).filterLowerValue, + filterUpperValue: (event.value as IntervalFilter).filterUpperValue, + }); + break; + case DomainType.DISCRETE: + this.onHparamDiscreteFilterChanged({ + hparamName: event.header.name, + includeUndefined: event.value.includeUndefined, + filterValues: (event.value as DiscreteFilter).filterValues, + }); + break; + } + } } export const TEST_ONLY = { diff --git a/tensorboard/webapp/widgets/data_table/BUILD b/tensorboard/webapp/widgets/data_table/BUILD index df4c373cc01..221d41290f6 100644 --- a/tensorboard/webapp/widgets/data_table/BUILD +++ b/tensorboard/webapp/widgets/data_table/BUILD @@ -51,6 +51,16 @@ tf_sass_binary( ], ) +tf_sass_binary( + name = "filter_dialogue_styles", + src = "filter_dialogue.scss", + strict_deps = False, + deps = [ + "//tensorboard/webapp:angular_material_sass_deps", + "//tensorboard/webapp/theme", + ], +) + tf_ng_module( name = "data_table", srcs = [ @@ -71,9 +81,11 @@ tf_ng_module( deps = [ ":column_selector", ":data_table_header", + ":filter_dialogue", ":types", "//tensorboard/webapp/angular:expect_angular_material_button", "//tensorboard/webapp/angular:expect_angular_material_icon", + "//tensorboard/webapp/widgets/range_input:types", "//tensorboard/webapp/metrics/views/card_renderer:scalar_card_types", "//tensorboard/webapp/widgets/custom_modal", "//tensorboard/webapp/widgets/line_chart_v2/lib:formatter", @@ -123,6 +135,28 @@ tf_ng_module( ], ) +tf_ng_module( + name = "filter_dialogue", + srcs = [ + "filter_dialogue.ts", + "filter_dialogue_module.ts", + ], + assets = [ + "filter_dialogue.ng.html", + ":filter_dialogue_styles", + ], + deps = [ + ":types", + "//tensorboard/webapp/angular:expect_angular_material_button", + "//tensorboard/webapp/angular:expect_angular_material_checkbox", + "//tensorboard/webapp/widgets/range_input", + "//tensorboard/webapp/widgets/range_input:types", + "@npm//@angular/common", + "@npm//@angular/core", + "@npm//@angular/forms", + ], +) + tf_ts_library( name = "types", srcs = [ diff --git a/tensorboard/webapp/widgets/data_table/data_table_component.ng.html b/tensorboard/webapp/widgets/data_table/data_table_component.ng.html index 2dc5674fdab..9d590b002f9 100644 --- a/tensorboard/webapp/widgets/data_table/data_table_component.ng.html +++ b/tensorboard/webapp/widgets/data_table/data_table_component.ng.html @@ -15,7 +15,7 @@
No Actions Available @@ -47,17 +47,25 @@ +