From d3277a44e38df3e93f7c09b9febe82929a9f3f36 Mon Sep 17 00:00:00 2001 From: Riley Jones <78179109+rileyajones@users.noreply.github.com> Date: Thu, 10 Aug 2023 15:19:52 -0700 Subject: [PATCH] HParams: Create hparams data source to fetch data from the hparams plugin (#6535) ## Motivation for features / changes Today we fetch data from the hparams plugin using the runs data source. The data is then written to both the runs and hparams state. Unfortunately there are a few issues with both this event structure and our current data model. The data model implies a 1:1 mapping between experiment ids and hparam/metric specs. This is not the case when in an experiment view and thus the model will need to be changed. The event structure is inconsistent with our typical redux structure and thus is hard to refactor to work the way we need it to. This is the first PR in a series I am creating to address this issue. Future PRs will: * add a new effects file, action, and hparams state entries for both a `runToHparamsAndMetrics` mapping along with a single `currentSpecs` (name pending). * remove much of the logic from runs data source and update the `getRuns` selector to populate data ## Screenshots of UI changes (or N/A) N/A --- tensorboard/webapp/hparams/BUILD | 1 + tensorboard/webapp/hparams/_redux/BUILD | 18 + .../hparams/_redux/hparams_data_source.ts | 198 +++++++++ .../_redux/hparams_data_source_test.ts | 404 ++++++++++++++++++ .../_redux/hparams_selectors_utils_test.ts | 40 ++ tensorboard/webapp/hparams/_redux/testing.ts | 5 + tensorboard/webapp/hparams/_types.ts | 17 + tensorboard/webapp/runs/data_source/BUILD | 5 +- .../runs/data_source/runs_data_source.ts | 5 +- .../runs/data_source/runs_data_source_test.ts | 16 + .../data_source/runs_data_source_types.ts | 3 +- .../webapp/runs/data_source/testing.ts | 4 + 12 files changed, 711 insertions(+), 5 deletions(-) create mode 100644 tensorboard/webapp/hparams/_redux/hparams_data_source.ts create mode 100644 tensorboard/webapp/hparams/_redux/hparams_data_source_test.ts diff --git a/tensorboard/webapp/hparams/BUILD b/tensorboard/webapp/hparams/BUILD index 8a38a39a921..c3f0b23f586 100644 --- a/tensorboard/webapp/hparams/BUILD +++ b/tensorboard/webapp/hparams/BUILD @@ -22,6 +22,7 @@ tf_ts_library( ], deps = [ "//tensorboard/webapp/runs/data_source", + "//tensorboard/webapp/runs/data_source:backend_types", ], ) diff --git a/tensorboard/webapp/hparams/_redux/BUILD b/tensorboard/webapp/hparams/_redux/BUILD index b7f7cf23c47..b04ba6be5c5 100644 --- a/tensorboard/webapp/hparams/_redux/BUILD +++ b/tensorboard/webapp/hparams/_redux/BUILD @@ -81,6 +81,20 @@ tf_ts_library( ], ) +tf_ng_module( + name = "hparams_data_source", + srcs = [ + "hparams_data_source.ts", + ], + deps = [ + "//tensorboard/webapp/hparams:types", + "//tensorboard/webapp/webapp_data_source:http_client", + "@npm//@angular/core", + "@npm//@ngrx/store", + "@npm//rxjs", + ], +) + tf_ts_library( name = "testing", testonly = True, @@ -98,6 +112,7 @@ tf_ts_library( name = "_redux_test_lib", testonly = True, srcs = [ + "hparams_data_source_test.ts", "hparams_reducers_test.ts", "hparams_selectors_test.ts", "hparams_selectors_utils_test.ts", @@ -105,14 +120,17 @@ tf_ts_library( ], deps = [ ":hparams_actions", + ":hparams_data_source", ":hparams_reducers", ":hparams_selectors", ":testing", ":utils", + "//tensorboard/webapp/angular:expect_angular_core_testing", "//tensorboard/webapp/hparams:types", "//tensorboard/webapp/runs/actions", "//tensorboard/webapp/runs/data_source:testing", "//tensorboard/webapp/runs/store:testing", + "//tensorboard/webapp/webapp_data_source:http_client_testing", "@npm//@types/jasmine", ], ) diff --git a/tensorboard/webapp/hparams/_redux/hparams_data_source.ts b/tensorboard/webapp/hparams/_redux/hparams_data_source.ts new file mode 100644 index 00000000000..a522057ccab --- /dev/null +++ b/tensorboard/webapp/hparams/_redux/hparams_data_source.ts @@ -0,0 +1,198 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {Injectable} from '@angular/core'; +import {Observable} from 'rxjs'; +import {map} from 'rxjs/operators'; + +import { + Domain, + DomainType, + BackendListSessionGroupRequest, + BackendHparamsExperimentResponse, + BackendHparamSpec, + DiscreteDomainHparamSpec, + SessionGroup, + HparamAndMetricSpec, + IntervalDomainHparamSpec, + BackendListSessionGroupResponse, + RunStatus, +} from '../types'; +import {TBHttpClient} from '../../webapp_data_source/tb_http_client'; + +const HPARAMS_HTTP_PATH_PREFIX = 'data/plugin/hparams'; + +function isHparamDiscrete( + hparam: BackendHparamSpec +): hparam is DiscreteDomainHparamSpec { + return Boolean((hparam as DiscreteDomainHparamSpec).domainDiscrete); +} + +function isHparamInterval( + hparam: BackendHparamSpec +): hparam is IntervalDomainHparamSpec { + return Boolean((hparam as IntervalDomainHparamSpec).domainInterval); +} + +function getHparamDomain(hparam: BackendHparamSpec): Domain { + if (isHparamDiscrete(hparam)) { + return { + type: DomainType.DISCRETE, + values: hparam.domainDiscrete, + }; + } + + if (isHparamInterval(hparam)) { + return { + ...hparam.domainInterval, + type: DomainType.INTERVAL, + }; + } + + return { + values: [], + type: DomainType.DISCRETE, + }; +} + +@Injectable() +export class HparamsDataSource { + constructor(private readonly http: TBHttpClient) {} + + private getPrefix(experimentIds: string[]) { + return experimentIds.length > 1 ? 'compare' : 'experiment'; + } + + private formatExperimentIds(experimentIds: string[]) { + if (experimentIds.length === 1) { + return experimentIds[0]; + } + + // The server does not send back experiment ids. Instead the response is formatted as + // `[AliasNumber] ExperimentAlias/RunName` + // By using the index as the alias we can translate associate the response with an experiment id + // Note: The experiment id itself cannot be the alias because it may contain ':' + return experimentIds.map((eid, index) => `${index}:${eid}`).join(','); + } + + fetchExperimentInfo( + experimentIds: string[] + ): Observable { + const formattedExperimentIds = this.formatExperimentIds(experimentIds); + return this.http + .post( + `/${this.getPrefix( + experimentIds + )}/${formattedExperimentIds}/${HPARAMS_HTTP_PATH_PREFIX}/experiment`, + {experimentName: formattedExperimentIds}, + {}, + 'request' + ) + .pipe( + map((response) => { + return { + hparams: response.hparamInfos.map((hparam) => { + const feHparam = { + ...hparam, + domain: getHparamDomain(hparam), + }; + + delete (feHparam as any).domainInterval; + delete (feHparam as any).domainDiscrete; + + return feHparam; + }), + metrics: response.metricInfos.map((metric) => ({ + ...metric, + tag: metric.name.tag, + })), + }; + }) + ); + } + + fetchSessionGroups( + experimentIds: string[], + hparamsAndMetricsSpecs: HparamAndMetricSpec + ): Observable { + const formattedExperimentIds = this.formatExperimentIds(experimentIds); + + const colParams: BackendListSessionGroupRequest['colParams'] = []; + + for (const hparam of hparamsAndMetricsSpecs.hparams) { + colParams.push({hparam: hparam.name}); + } + for (const mectric of hparamsAndMetricsSpecs.metrics) { + colParams.push({ + metric: mectric.name, + }); + } + + const listSessionRequestParams: BackendListSessionGroupRequest = { + experimentName: formattedExperimentIds, + allowedStatuses: [ + RunStatus.STATUS_FAILURE, + RunStatus.STATUS_RUNNING, + RunStatus.STATUS_SUCCESS, + RunStatus.STATUS_UNKNOWN, + ], + colParams, + startIndex: 0, + // arbitrary large number so it does not get clipped. + sliceSize: 1e6, + }; + + return this.http + .post( + `/${this.getPrefix( + experimentIds + )}/${formattedExperimentIds}/${HPARAMS_HTTP_PATH_PREFIX}/session_groups`, + listSessionRequestParams, + {}, + 'request' + ) + .pipe( + map((response) => + response.sessionGroups.map((sessionGroup) => { + sessionGroup.sessions = sessionGroup.sessions.map((session) => { + /* + * In single experiment mode the Session.name is equal to the runName. + * In comparison view it is `[AliasNumber] ExperimentAlias/runName` + * + * We store runs as experimentId/runName so it is necessary to prepend the experiment name + * in single experiment view. "In comparison view we pass the indeces of the experimentIds + * as the aliases in the request. That allows us to parse the indeces from the response and + * use them to lookup the correct ids from the experimentIds argument. + */ + if (experimentIds.length > 1) { + const [, ...aliasAndRunName] = session.name.split(' '); + const [experimentIndex, ...runName] = aliasAndRunName + .join(' ') + .split('/'); + session.name = [ + // This parseInt should not be necessary because JS Arrays DO support indexing by string + experimentIds[parseInt(experimentIndex)], + ...runName, + ].join('/'); + } else { + session.name = [experimentIds[0], session.name].join('/'); + } + return session; + }); + return sessionGroup; + }) + ) + ); + } +} diff --git a/tensorboard/webapp/hparams/_redux/hparams_data_source_test.ts b/tensorboard/webapp/hparams/_redux/hparams_data_source_test.ts new file mode 100644 index 00000000000..086934cdc26 --- /dev/null +++ b/tensorboard/webapp/hparams/_redux/hparams_data_source_test.ts @@ -0,0 +1,404 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {TestBed} from '@angular/core/testing'; +import { + HttpTestingController, + TBHttpClientTestingModule, +} from '../../webapp_data_source/tb_http_client_testing'; +import {HparamsDataSource} from './hparams_data_source'; +import { + BackendHparamsValueType, + DatasetType, + DomainType, + SessionGroup, + BackendHparamSpec, + BackendHparamsExperimentResponse, + BackendListSessionGroupResponse, + RunStatus, +} from '../types'; + +describe('HparamsDataSource Test', () => { + let httpMock: HttpTestingController; + let dataSource: HparamsDataSource; + + beforeEach(async () => { + await TestBed.configureTestingModule({ + imports: [TBHttpClientTestingModule], + providers: [HparamsDataSource], + }).compileComponents(); + + httpMock = TestBed.inject(HttpTestingController); + dataSource = TestBed.inject(HparamsDataSource); + }); + + describe('fetchExperimentInfo', () => { + it('uses /experiment when a single experiment id is provided', () => { + const returnValue = jasmine.createSpy(); + dataSource.fetchExperimentInfo(['eid']).subscribe(returnValue); + httpMock + .expectOne('/experiment/eid/data/plugin/hparams/experiment') + .flush(createHparamsExperimentResponse()); + expect(returnValue).toHaveBeenCalled(); + }); + + it('uses /compare when a multiple experiment ids are provided', () => { + const returnValue = jasmine.createSpy(); + dataSource.fetchExperimentInfo(['eid1', 'eid2']).subscribe(returnValue); + httpMock + .expectOne('/compare/0:eid1,1:eid2/data/plugin/hparams/experiment') + .flush(createHparamsExperimentResponse()); + expect(returnValue).toHaveBeenCalled(); + }); + + it('maps interval and discrete domains to domain', () => { + const returnValue = jasmine.createSpy(); + dataSource.fetchExperimentInfo(['eid']).subscribe(returnValue); + httpMock + .expectOne('/experiment/eid/data/plugin/hparams/experiment') + .flush(createHparamsExperimentResponse()); + expect(returnValue).toHaveBeenCalledWith({ + hparams: [ + { + description: 'describes hparams one', + displayName: 'hparams one', + name: 'hparams1', + type: BackendHparamsValueType.DATA_TYPE_STRING, + domain: { + type: DomainType.INTERVAL, + minValue: -100, + maxValue: 100, + }, + }, + { + description: 'describes hparams two', + displayName: 'hparams two', + name: 'hparams2', + type: BackendHparamsValueType.DATA_TYPE_BOOL, + domain: { + type: DomainType.DISCRETE, + values: ['foo', 'bar', 'baz'], + }, + }, + ], + metrics: [ + { + name: { + tag: 'metrics1', + group: '', + }, + tag: 'metrics1', + displayName: 'Metrics One', + description: 'describe metrics one', + datasetType: DatasetType.DATASET_UNKNOWN, + }, + { + name: { + tag: 'metrics2', + group: 'group', + }, + tag: 'metrics2', + displayName: 'Metrics Two', + description: 'describe metrics two', + datasetType: DatasetType.DATASET_TRAINING, + }, + ], + }); + }); + + it('treats missing domains as discrete domains', () => { + const returnValue = jasmine.createSpy(); + dataSource.fetchExperimentInfo(['eid']).subscribe(returnValue); + httpMock + .expectOne('/experiment/eid/data/plugin/hparams/experiment') + .flush(createHparamsExperimentNoDomainResponse()); + expect(returnValue).toHaveBeenCalledWith( + jasmine.objectContaining({ + hparams: [ + { + description: 'describes hparams one', + displayName: 'hparams one', + name: 'hparams1', + type: BackendHparamsValueType.DATA_TYPE_STRING, + domain: { + type: DomainType.DISCRETE, + values: [], + }, + }, + { + description: 'describes hparams two', + displayName: 'hparams two', + name: 'hparams2', + type: BackendHparamsValueType.DATA_TYPE_BOOL, + domain: { + type: DomainType.DISCRETE, + values: ['foo', 'bar', 'baz'], + }, + }, + ], + }) + ); + }); + }); + + describe('fetchSessionGroups', () => { + it('uses /experiment when a single experiment id is provided', () => { + const returnValue = jasmine.createSpy(); + dataSource + .fetchSessionGroups(['eid'], {hparams: [], metrics: []}) + .subscribe(returnValue); + httpMock + .expectOne('/experiment/eid/data/plugin/hparams/session_groups') + .flush(createHparamsListSessionGroupResponse()); + expect(returnValue).toHaveBeenCalled(); + }); + + it('uses /compare when a multiple experiment ids are provided', () => { + const returnValue = jasmine.createSpy(); + dataSource + .fetchSessionGroups(['eid1', 'eid2'], {hparams: [], metrics: []}) + .subscribe(returnValue); + httpMock + .expectOne('/compare/0:eid1,1:eid2/data/plugin/hparams/session_groups') + .flush(createHparamsListSessionGroupResponse()); + expect(returnValue).toHaveBeenCalled(); + }); + + it('renames Session.name in single experiment view', () => { + let sessionGroups: SessionGroup[] = []; + const callback = (resp: SessionGroup[]) => { + sessionGroups = resp; + }; + dataSource + .fetchSessionGroups(['eid'], {hparams: [], metrics: []}) + .subscribe(callback); + httpMock + .expectOne('/experiment/eid/data/plugin/hparams/session_groups') + .flush(createHparamsListSessionGroupResponse()); + expect(sessionGroups.length).toEqual(2); + expect(sessionGroups[0].sessions[0].name).toEqual('eid/run_name_1'); + }); + + it('renames Session.name to runId in comparison view', () => { + let sessionGroups: SessionGroup[] = []; + const callback = (resp: SessionGroup[]) => { + sessionGroups = resp; + }; + dataSource + .fetchSessionGroups(['eid1', 'eid2'], {hparams: [], metrics: []}) + .subscribe(callback); + + const response = createHparamsListSessionGroupResponse(); + // This is the format expected in comparison view. + response.sessionGroups[0].sessions[0].name = '[1] 0/run_name_1'; + response.sessionGroups[1].sessions[0].name = '[2] 1/run_name_2'; + httpMock + .expectOne('/compare/0:eid1,1:eid2/data/plugin/hparams/session_groups') + .flush(response); + expect(sessionGroups.length).toEqual(2); + expect(sessionGroups[0].sessions[0].name).toEqual('eid1/run_name_1'); + expect(sessionGroups[1].sessions[0].name).toEqual('eid2/run_name_2'); + }); + }); +}); + +export function createHparamsListSessionGroupResponse(): BackendListSessionGroupResponse { + return { + sessionGroups: [ + { + name: 'session_id_1', + hparams: { + hparams1: -100, + hparams2: 'bar', + }, + sessions: [ + { + endTimeSecs: 0, + metricValues: [ + { + name: { + group: '', + tag: 'metrics1', + }, + trainingStep: 1000, + value: 1, + wallTimeSecs: 0, + }, + ], + modelUri: '', + monitorUrl: '', + name: 'run_name_1', + startTimeSecs: 0, + status: RunStatus.STATUS_SUCCESS, + }, + ], + }, + { + name: 'session_id_2', + hparams: { + hparams1: 100, + hparams2: 'foo', + }, + sessions: [ + { + endTimeSecs: 0, + metricValues: [ + { + name: { + group: 'train', + tag: 'metrics1', + }, + trainingStep: 2000, + value: 0.1, + wallTimeSecs: 0, + }, + { + name: { + group: 'test', + tag: 'metrics1', + }, + trainingStep: 5000, + value: 0.6, + wallTimeSecs: 0, + }, + ], + modelUri: '', + monitorUrl: '', + name: 'run_name_2', + startTimeSecs: 0, + status: RunStatus.STATUS_SUCCESS, + }, + { + endTimeSecs: 0, + metricValues: [ + { + name: { + group: 'train', + tag: 'metrics1', + }, + trainingStep: 10000, + value: 0.3, + wallTimeSecs: 0, + }, + { + name: { + group: 'train', + tag: 'metrics2', + }, + trainingStep: 10000, + value: 0, + wallTimeSecs: 0, + }, + ], + modelUri: '', + monitorUrl: '', + name: 'run_name_2', + startTimeSecs: 0, + status: RunStatus.STATUS_RUNNING, + }, + ], + }, + ], + totalSize: 2, + }; +} + +export function createHparamsExperimentResponse(): BackendHparamsExperimentResponse { + return { + description: 'some description', + hparamInfos: [ + { + description: 'describes hparams one', + displayName: 'hparams one', + name: 'hparams1', + type: BackendHparamsValueType.DATA_TYPE_STRING, + domainInterval: {minValue: -100, maxValue: 100}, + }, + { + description: 'describes hparams two', + displayName: 'hparams two', + name: 'hparams2', + type: BackendHparamsValueType.DATA_TYPE_BOOL, + domainDiscrete: ['foo', 'bar', 'baz'], + }, + ], + metricInfos: [ + { + name: { + group: '', + tag: 'metrics1', + }, + displayName: 'Metrics One', + description: 'describe metrics one', + datasetType: DatasetType.DATASET_UNKNOWN, + }, + { + name: { + group: 'group', + tag: 'metrics2', + }, + displayName: 'Metrics Two', + description: 'describe metrics two', + datasetType: DatasetType.DATASET_TRAINING, + }, + ], + name: 'experiment name', + timeCreatedSecs: 1337, + user: 'user name', + }; +} + +export function createHparamsExperimentNoDomainResponse(): BackendHparamsExperimentResponse { + return { + description: 'some description', + hparamInfos: [ + { + description: 'describes hparams one', + displayName: 'hparams one', + name: 'hparams1', + type: BackendHparamsValueType.DATA_TYPE_STRING, + } as BackendHparamSpec, + { + description: 'describes hparams two', + displayName: 'hparams two', + name: 'hparams2', + type: BackendHparamsValueType.DATA_TYPE_BOOL, + domainDiscrete: ['foo', 'bar', 'baz'], + }, + ], + metricInfos: [ + { + name: { + group: '', + tag: 'metrics1', + }, + displayName: 'Metrics One', + description: 'describe metrics one', + datasetType: DatasetType.DATASET_UNKNOWN, + }, + { + name: { + group: 'group', + tag: 'metrics2', + }, + displayName: 'Metrics Two', + description: 'describe metrics two', + datasetType: DatasetType.DATASET_TRAINING, + }, + ], + name: 'experiment name', + timeCreatedSecs: 1337, + user: 'user name', + }; +} diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors_utils_test.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors_utils_test.ts index 47eabda41f1..ca0e44c0558 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors_utils_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors_utils_test.ts @@ -88,12 +88,20 @@ describe('runs selectors utils test', () => { hparams: [], metrics: [ { + name: { + tag: 'acc', + group: '', + }, tag: 'acc', displayName: 'Accuracy', description: '', datasetType: DatasetType.DATASET_TRAINING, }, { + name: { + tag: 'loss', + group: '', + }, tag: 'loss', displayName: 'Loss', description: '', @@ -105,6 +113,10 @@ describe('runs selectors utils test', () => { hparams: [], metrics: [ { + name: { + tag: 'acc', + group: '', + }, tag: 'acc', displayName: 'Accuracy', description: '', @@ -118,12 +130,20 @@ describe('runs selectors utils test', () => { hparams: [], metrics: [ { + name: { + tag: 'acc', + group: '', + }, tag: 'acc', displayName: 'Accuracy', description: '', datasetType: DatasetType.DATASET_TRAINING, }, { + name: { + tag: 'loss', + group: '', + }, tag: 'loss', displayName: 'Loss', description: '', @@ -233,12 +253,20 @@ describe('runs selectors utils test', () => { hparams: [], metrics: [ { + name: { + tag: 'acc', + group: '', + }, tag: 'acc', displayName: 'Accuracy', description: '', datasetType: DatasetType.DATASET_TRAINING, }, { + name: { + tag: 'acc', + group: '', + }, tag: 'acc', displayName: 'Accuracy', description: '', @@ -274,6 +302,10 @@ describe('runs selectors utils test', () => { ], metrics: [ { + name: { + tag: 'acc', + group: '', + }, tag: 'acc', displayName: 'Accuracy', description: '', @@ -298,6 +330,10 @@ describe('runs selectors utils test', () => { ], metrics: [ { + name: { + tag: 'acc', + group: ',', + }, tag: 'acc', displayName: 'Acc', description: '', @@ -324,6 +360,10 @@ describe('runs selectors utils test', () => { ], metrics: [ { + name: { + tag: 'acc', + group: '', + }, tag: 'acc', displayName: 'Accuracy or Acc', description: '', diff --git a/tensorboard/webapp/hparams/_redux/testing.ts b/tensorboard/webapp/hparams/_redux/testing.ts index 8f06ad96c4e..e672d5ca64d 100644 --- a/tensorboard/webapp/hparams/_redux/testing.ts +++ b/tensorboard/webapp/hparams/_redux/testing.ts @@ -99,6 +99,11 @@ export function buildMetricSpec( override: Partial = {} ): MetricSpec { return { + name: { + ...override?.name, + tag: 'metric', + group: 'some group', + }, tag: 'tag', displayName: 'Tag', description: 'This is a tags', diff --git a/tensorboard/webapp/hparams/_types.ts b/tensorboard/webapp/hparams/_types.ts index d018b5c5859..ccde5e297c3 100644 --- a/tensorboard/webapp/hparams/_types.ts +++ b/tensorboard/webapp/hparams/_types.ts @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ import { + HparamValue, + MetricValue, DiscreteHparamValues, DomainType, HparamSpec, @@ -27,8 +29,23 @@ export { HparamSpec, HparamsValueType, MetricSpec, + Domain, + HparamValue, + RunToHparamsAndMetrics, } from '../runs/data_source/runs_data_source_types'; +export { + HparamSpec as BackendHparamSpec, + DiscreteDomainHparamSpec, + IntervalDomainHparamSpec, + BackendHparamsExperimentResponse, + BackendListSessionGroupResponse, + BackendListSessionGroupRequest, + BackendHparamsValueType, + RunStatus, + SessionGroup, +} from '../runs/data_source/runs_backend_types'; + export interface HparamAndMetricSpec { hparams: HparamSpec[]; metrics: MetricSpec[]; diff --git a/tensorboard/webapp/runs/data_source/BUILD b/tensorboard/webapp/runs/data_source/BUILD index b35f4e47051..b93da6a6b4d 100644 --- a/tensorboard/webapp/runs/data_source/BUILD +++ b/tensorboard/webapp/runs/data_source/BUILD @@ -24,7 +24,10 @@ tf_ts_library( srcs = [ "runs_backend_types.ts", ], - visibility = ["//visibility:private"], + visibility = [ + "//tensorboard/webapp/hparams:__pkg__", + "//tensorboard/webapp/runs/data_source:__subpackages__", + ], ) tf_ng_module( diff --git a/tensorboard/webapp/runs/data_source/runs_data_source.ts b/tensorboard/webapp/runs/data_source/runs_data_source.ts index 729a3de3367..23ee022c2cf 100644 --- a/tensorboard/webapp/runs/data_source/runs_data_source.ts +++ b/tensorboard/webapp/runs/data_source/runs_data_source.ts @@ -67,10 +67,9 @@ function transformBackendHparamSpec( function transformBackendMetricSpec( metricInfo: backendTypes.MetricSpec ): MetricSpec { - const {name, ...otherSpec} = metricInfo; return { - ...otherSpec, - tag: name.tag, + ...metricInfo, + tag: metricInfo.name.tag, }; } diff --git a/tensorboard/webapp/runs/data_source/runs_data_source_test.ts b/tensorboard/webapp/runs/data_source/runs_data_source_test.ts index 12af554bfcd..461935d6ce9 100644 --- a/tensorboard/webapp/runs/data_source/runs_data_source_test.ts +++ b/tensorboard/webapp/runs/data_source/runs_data_source_test.ts @@ -97,12 +97,20 @@ describe('TBRunsDataSource test', () => { ], metricSpecs: [ { + name: { + tag: 'metrics1', + group: '', + }, tag: 'metrics1', displayName: 'Metrics One', description: 'describe metrics one', datasetType: types.DatasetType.DATASET_UNKNOWN, }, { + name: { + tag: 'metrics2', + group: 'group', + }, tag: 'metrics2', displayName: 'Metrics Two', description: 'describe metrics two', @@ -203,12 +211,20 @@ describe('TBRunsDataSource test', () => { ], metricSpecs: [ { + name: { + tag: 'metrics1', + group: '', + }, tag: 'metrics1', displayName: 'Metrics One', description: 'describe metrics one', datasetType: types.DatasetType.DATASET_UNKNOWN, }, { + name: { + tag: 'metrics2', + group: 'group', + }, tag: 'metrics2', displayName: 'Metrics Two', description: 'describe metrics two', diff --git a/tensorboard/webapp/runs/data_source/runs_data_source_types.ts b/tensorboard/webapp/runs/data_source/runs_data_source_types.ts index 1865c8cf04c..fa33f26cc42 100644 --- a/tensorboard/webapp/runs/data_source/runs_data_source_types.ts +++ b/tensorboard/webapp/runs/data_source/runs_data_source_types.ts @@ -63,7 +63,8 @@ export interface HparamSpec domain: Domain; } -export interface MetricSpec extends Omit { +// TODO(rileyajones) merge these types by deleting this one. +export interface MetricSpec extends backendTypes.MetricSpec { tag: string; } diff --git a/tensorboard/webapp/runs/data_source/testing.ts b/tensorboard/webapp/runs/data_source/testing.ts index e47b0b95f2a..2245cc8e6b5 100644 --- a/tensorboard/webapp/runs/data_source/testing.ts +++ b/tensorboard/webapp/runs/data_source/testing.ts @@ -44,6 +44,10 @@ export function buildHparamsAndMetadata( ], metricSpecs: [ { + name: { + tag: 'metric', + group: 'some group', + }, tag: 'metric', displayName: 'Metric', description: 'This is a metric',