From 5420d0d4ff07e132db170ebc16fc6a2255ef2d3c Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Wed, 4 Dec 2024 11:33:40 -0800 Subject: [PATCH] Frontend for Inference Profile. PiperOrigin-RevId: 702799979 --- frontend/app/common/interfaces/data_table.ts | 39 +++ .../app/components/inference_profile/BUILD | 42 +++ .../inference_profile.ng.html | 133 ++++++++ .../inference_profile/inference_profile.scss | 27 ++ .../inference_profile/inference_profile.ts | 234 ++++++++++++++ .../inference_profile_module.ts | 24 ++ frontend/app/components/main_page/BUILD | 1 + .../components/main_page/main_page_module.ts | 4 + .../app/services/data_service/data_service.ts | 7 +- .../tensorboard_plugin_profile/convert/BUILD | 10 + .../convert/inference_stats_proto_to_gviz.py | 297 ++++++++++++++++++ .../convert/raw_to_tool_data.py | 5 + .../tensorboard_plugin_profile/protobuf/BUILD | 2 + .../protobuf/inference_stats.proto | 285 +++++++++++++++++ 14 files changed, 1109 insertions(+), 1 deletion(-) create mode 100644 frontend/app/components/inference_profile/BUILD create mode 100644 frontend/app/components/inference_profile/inference_profile.ng.html create mode 100644 frontend/app/components/inference_profile/inference_profile.scss create mode 100644 frontend/app/components/inference_profile/inference_profile.ts create mode 100644 frontend/app/components/inference_profile/inference_profile_module.ts create mode 100644 plugin/tensorboard_plugin_profile/convert/inference_stats_proto_to_gviz.py create mode 100644 plugin/tensorboard_plugin_profile/protobuf/inference_stats.proto diff --git a/frontend/app/common/interfaces/data_table.ts b/frontend/app/common/interfaces/data_table.ts index adfd10688..3876f85ca 100644 --- a/frontend/app/common/interfaces/data_table.ts +++ b/frontend/app/common/interfaces/data_table.ts @@ -384,3 +384,42 @@ export type DataTableUnion = SimpleDataTable|FrameworkOpStatsData| TfFunctionExplanationTable|TfFunctionDataTable|MetaHostOpTable|HostOpTable| GeneralAnalysis|InputPipelineAnalysis|InputPipelineHostAnalysis| RunEnvironment|RecommendationResult|RecommendationResult; + +/** The base interface for a property of inference latency. */ +export declare interface InferenceLatencyProperty { + sessionsPerSecond?: string; +} + +/** The metadata property for Inference Profile. */ +export declare interface InferenceProfileMetadataProperty { + modelIdList?: string; + hasBatching?: string; + hasTensorPattern?: string; +} + +/** The metadata for Inference Profile. */ +export declare interface InferenceProfileMetadata extends SimpleDataTable { + p: InferenceProfileMetadataProperty; +} + +/** The data property of Inference Profile. */ +export declare interface InferenceProfileDataProperty { + throughput?: string; + averageLatencyMs?: string; + tableExplanation?: string; + hasBatchingParam?: string; + batchingParamNumBatchThreads?: string; + batchingParamMaxBatchSize?: string; + batchingParamBatchTimeoutMicros?: string; + batchingParamMaxEnqueuedBatches?: string; + batchingParamAllowedBatchSizes?: string; +} + +/** The data of Inference Profile. */ +export declare interface InferenceProfileData extends SimpleDataTable { + p: InferenceProfileDataProperty; +} + +/** All Inference Stats page data table type. 
*/ +export type InferenceProfileTable = + |InferenceProfileMetadata|InferenceProfileData; diff --git a/frontend/app/components/inference_profile/BUILD b/frontend/app/components/inference_profile/BUILD new file mode 100644 index 000000000..36cfc0612 --- /dev/null +++ b/frontend/app/components/inference_profile/BUILD @@ -0,0 +1,42 @@ +load("@io_bazel_rules_sass//:defs.bzl", "sass_binary") +load("//defs:defs.bzl", "xprof_ng_module") + +package(default_visibility = ["//frontend:internal"]) + +xprof_ng_module( + name = "inference_profile", + srcs = [ + "inference_profile.ts", + "inference_profile_module.ts", + ], + assets = [ + ":inference_profile_css", + "inference_profile.ng.html", + ], + deps = [ + "@npm//@angular/common", + "@npm//@angular/core", + "@npm//@angular/router", + "@npm//@ngrx/store", + "@npm//@types/google.visualization", + "@npm//rxjs", + "@org_xprof//frontend/app/common/angular:angular_material_core", + "@org_xprof//frontend/app/common/angular:angular_material_progress_bar", + "@org_xprof//frontend/app/common/angular:angular_material_select", + "@org_xprof//frontend/app/common/angular:angular_material_sidenav", + "@org_xprof//frontend/app/common/interfaces", + "@org_xprof//frontend/app/common/utils", + "@org_xprof//frontend/app/components/chart/table", + "@org_xprof//frontend/app/services/data_service", + "@org_xprof//frontend/app/store", + ], +) + +sass_binary( + name = "inference_profile_css", + src = "inference_profile.scss", + sourcemap = False, + deps = [ + "@org_xprof//frontend/app/styles:common", + ], +) diff --git a/frontend/app/components/inference_profile/inference_profile.ng.html b/frontend/app/components/inference_profile/inference_profile.ng.html new file mode 100644 index 000000000..349a8235d --- /dev/null +++ b/frontend/app/components/inference_profile/inference_profile.ng.html @@ -0,0 +1,133 @@ + +
+

Model Selection

+
+ + Select Model + + + {{ id }} + + + +
+
+ + +
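<!-- The model dropdown above is populated from allModelIds, which parseMetadata() in
     inference_profile.ts fills from the "Model Name" column of the metadata table
     returned for this tool. -->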
+

Request-level analysis

+
+ +
+ Request throughput: {{ allRequestProperties[selectedIndex]?.throughput }} requests/sec +
+
+ Request average Latency: {{ allRequestProperties[selectedIndex]?.averageLatencyMs }} ms +
+ + +
+

All time-related columns in this table are displayed in milliseconds.

+

Request details table. Percentiles are calculated based on column:

+ + + + {{ column }} + + + + + +
+
+
+
+ + +
+

Batch-level analysis

+
+ +
+ Batch throughput: {{ allBatchProperties[selectedIndex]?.throughput }} batches/sec +
+
+ Batch average Latency: {{ allBatchProperties[selectedIndex]?.averageLatencyMs }} ms +
+ + +
+ This model is controlled by the following batching parameters:
+
    +
  • + Number of batching threads: + {{ allBatchProperties[selectedIndex]?.batchingParamNumBatchThreads }} +
  • + +
  • + Batch timeout in microseconds: + {{ allBatchProperties[selectedIndex]?.batchingParamBatchTimeoutMicros }} +
  • + +
  • + Maximum size of a batch: + {{ allBatchProperties[selectedIndex]?.batchingParamMaxBatchSize }} +
  • + +
  • + Maximum number of enqueued batches: + {{ allBatchProperties[selectedIndex]?.batchingParamMaxEnqueuedBatches }} +
  • + +
  • + Sizes that are allowed to form a batch: + {{ allBatchProperties[selectedIndex]?.batchingParamAllowedBatchSizes }} +
  • +
+
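<!-- The batching parameters above come from the BatchingParameters message collected
     from TFstreamz; _generate_batch_table() in inference_stats_proto_to_gviz.py copies
     them from ModelIdDatabase.id_to_batching_params into the table's custom properties
     (batchingParamNumBatchThreads, batchingParamMaxBatchSize, and so on). -->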
+ + +
+

Batch details table. Percentiles are calculated based on column:

+ + + + {{ column }} + + + + + +
+
+
+
+ + +
+

Tensor transfer analysis

+
+
+ + +
+
+
+
+
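The throughput, latency, and batching-parameter bindings in this template read from the per-model custom-properties object (InferenceProfileDataProperty in data_table.ts). A rough sketch of the shape the template expects, with purely illustrative values (the real dictionary is assembled by _generate_batch_table in the converter added later in this patch):

    # Illustrative only: custom properties attached to one per-model batch table.
    # Keys mirror InferenceProfileDataProperty; all values arrive as strings.
    example_batch_properties = {
        "throughput": "118.4",             # "Batch throughput: ... batches/sec"
        "averageLatencyMs": "12.310",      # "Batch average latency: ... ms"
        "hasBatchingParam": "true",        # gates the batching-parameters block above
        "batchingParamNumBatchThreads": "4",
        "batchingParamMaxBatchSize": "32",
        "batchingParamBatchTimeoutMicros": "1000",
        "batchingParamMaxEnqueuedBatches": "100",
        "batchingParamAllowedBatchSizes": "[8, 16, 32]",
    }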
diff --git a/frontend/app/components/inference_profile/inference_profile.scss b/frontend/app/components/inference_profile/inference_profile.scss new file mode 100644 index 000000000..9230b773c --- /dev/null +++ b/frontend/app/components/inference_profile/inference_profile.scss @@ -0,0 +1,27 @@ +@import 'frontend/app/styles/common'; + +:host { + padding: 20px; +} + +.section-container { + margin: 20px; +} + +.model-selector { + width: 50%; + min-width: 300px; +} + +.row { + margin-top: 10px; +} + +.session_text { + width: 100%; +} + +mat-progress-bar { + width: 300px; + margin: 0 auto; +} diff --git a/frontend/app/components/inference_profile/inference_profile.ts b/frontend/app/components/inference_profile/inference_profile.ts new file mode 100644 index 000000000..406e07494 --- /dev/null +++ b/frontend/app/components/inference_profile/inference_profile.ts @@ -0,0 +1,234 @@ +import {Component, OnDestroy} from '@angular/core'; +import {ActivatedRoute} from '@angular/router'; +import {Store} from '@ngrx/store'; +import {InferenceProfileData, InferenceProfileDataProperty, InferenceProfileMetadata, InferenceProfileTable,} from 'org_xprof/frontend/app/common/interfaces/data_table'; +import {NavigationEvent} from 'org_xprof/frontend/app/common/interfaces/navigation_event'; +import {setLoadingState} from 'org_xprof/frontend/app/common/utils/utils'; +import {DataService} from 'org_xprof/frontend/app/services/data_service/data_service'; +import {setCurrentToolStateAction} from 'org_xprof/frontend/app/store/actions'; +import {ReplaySubject} from 'rxjs'; +import {takeUntil} from 'rxjs/operators'; + +/** An inference profile component. */ +@Component({ + standalone: false, + selector: 'inference-profile', + templateUrl: './inference_profile.ng.html', + styleUrls: ['./inference_profile.css'], +}) +export class InferenceProfile implements OnDestroy { + readonly tool = 'inference_profile'; + run = ''; + tag = ''; + host = ''; + /** Handles on-destroy Subject, used to unsubscribe. */ + private readonly destroyed = new ReplaySubject(1); + + // All the model IDs and data. + hasBatching: boolean = false; + hasTensorPattern: boolean = false; + allModelIds: string[] = []; + allRequestTables: google.visualization.DataTable[] = []; + allRequestProperties: InferenceProfileDataProperty[] = []; + allBatchTables: google.visualization.DataTable[] = []; + allBatchProperties: InferenceProfileDataProperty[] = []; + allTensorPatternTables: google.visualization.DataTable[] = []; + allTensorPatternProperties: InferenceProfileDataProperty[] = []; + // Selected model. + selectedIndex = 0; + isInitialLoad = true; + loading = true; + + requestView?: google.visualization.DataView; + batchView?: google.visualization.DataView; + tensorPatternView?: google.visualization.DataView; + + // Names of the columns that can be used to compute percentile. + requestPercentileColumns: string[] = []; + batchPercentileColumns: string[] = []; + // By default, request table and batch table both select the first column + // "Latency" to compute percentile. 
+ requestPercentileIndex = 0; + batchPercentileIndex = 0; + + constructor( + route: ActivatedRoute, + private readonly dataService: DataService, + private readonly store: Store<{}>, + ) { + route.params.pipe(takeUntil(this.destroyed)).subscribe((params) => { + if (params as NavigationEvent) { + this.run = (params as NavigationEvent).run || ''; + this.tag = (params as NavigationEvent).tag || 'inference_profile'; + this.host = (params as NavigationEvent).host || ''; + } + this.update(); + }); + this.store.dispatch(setCurrentToolStateAction({currentTool: this.tool})); + } + + parseMetadata(metadataOrNull: InferenceProfileTable) { + if (!metadataOrNull) return false; + const metadata = (metadataOrNull as InferenceProfileMetadata).p; + this.hasBatching = metadata.hasBatching === 'true'; + this.hasTensorPattern = metadata.hasTensorPattern === 'true'; + + let parseSuccess = true; + this.allModelIds = []; + for (const row of metadataOrNull.rows || []) { + const modelName = row.c?.[0]?.v; + if (modelName) { + this.allModelIds.push(String(modelName)); + } else { + parseSuccess = false; + } + } + return parseSuccess; + } + + parseData(data: InferenceProfileTable[]) { + // is a list of tables. + // The first table is a metadat table, which contains model ids and whether + // this inference job is using batching. + // By default there is one table per model. + // If batching is enabled, there is one additional table per model for + // batching related metrics. + // If tensor pattern is recorded, there is one additional table per model + // for tensor pattern results. + if (!data || data.length <= 1) return false; + if (!this.parseMetadata(data[0])) return false; + // Check the number of inference tables is correct, and parse table data. + let expectedNum = 1 + this.allModelIds.length; + if (this.hasBatching) expectedNum += this.allModelIds.length; + if (this.hasTensorPattern) expectedNum += this.allModelIds.length; + console.log('expectedNum', expectedNum); + console.log('data', data.length); + if (data.length !== expectedNum) return false; + console.log('data', data); + for (let i = 1; i < data.length; ) { + const requestData = data[i] as InferenceProfileData; + this.allRequestTables.push( + new google.visualization.DataTable(requestData), + ); + this.allRequestProperties.push( + requestData.p as InferenceProfileDataProperty, + ); + i++; + + if (this.hasBatching) { + const batchData = data[i] as InferenceProfileData; + this.allBatchTables.push(new google.visualization.DataTable(batchData)); + this.allBatchProperties.push( + batchData.p as InferenceProfileDataProperty, + ); + i++; + } + + if (this.hasTensorPattern) { + const tensorPatternData = data[i] as InferenceProfileData; + this.allTensorPatternTables.push( + new google.visualization.DataTable(tensorPatternData), + ); + this.allTensorPatternProperties.push( + tensorPatternData.p as InferenceProfileDataProperty, + ); + i++; + } + } + console.log('allRequestTables', this.allRequestTables); + console.log('allRequestProperties', this.allRequestProperties); + console.log('allBatchTables', this.allBatchTables); + console.log('allBatchProperties', this.allBatchProperties); + console.log('allTensorPatternTables', this.allTensorPatternTables); + console.log('allTensorPatternProperties', this.allTensorPatternProperties); + + return true; + } + + updateView() { + this.requestView = new google.visualization.DataView( + this.allRequestTables[this.selectedIndex], + ); + // Percentile, Request ID, Batch ID and Trace Viewer URL are not sortable, + // all the 
other columns are sortable. + this.requestPercentileColumns = []; + for (let i = 0; i < this.requestView.getNumberOfColumns(); i++) { + const label = this.requestView.getColumnLabel(i); + if ( + label !== 'Request ID' && + label !== 'Percentile' && + label !== 'Trace Viewer URL' + ) { + this.requestPercentileColumns.push(label); + } + } + + this.batchPercentileColumns = []; + if (this.allBatchTables.length !== 0) { + this.batchView = new google.visualization.DataView( + this.allBatchTables[this.selectedIndex], + ); + // Percentile, Request ID, Batch ID and Trace Viewer URL are not sortable, + // all the other columns are sortable. + for (let i = 0; i < this.batchView.getNumberOfColumns(); i++) { + const label = this.batchView.getColumnLabel(i); + if ( + label !== 'Batch ID' && + label !== 'Percentile' && + label !== 'Trace Viewer URL' + ) { + this.batchPercentileColumns.push(label); + } + } + } + + if (this.allTensorPatternTables.length !== 0) { + this.tensorPatternView = new google.visualization.DataView( + this.allTensorPatternTables[this.selectedIndex], + ); + } + } + + update() { + if (this.isInitialLoad) { + setLoadingState(true, this.store, 'Loading inference profile data'); + } + // Clear the old data from previous update(). + this.allRequestTables = []; + this.allBatchTables = []; + + this.dataService + .getData( + this.run, + this.tag, + this.host, + new Map([ + [ + 'request_column', + this.requestPercentileColumns[this.requestPercentileIndex], + ], + [ + 'batch_column', + this.batchPercentileColumns[this.batchPercentileIndex], + ], + ]), + ) + .pipe(takeUntil(this.destroyed)) + .subscribe((data) => { + if (this.isInitialLoad) { + setLoadingState(false, this.store); + this.isInitialLoad = false; + } + this.loading = false; + if (!this.parseData(data as InferenceProfileTable[])) return; + this.updateView(); + }); + } + + ngOnDestroy() { + // Unsubscribes all pending subscriptions. 
+ setLoadingState(false, this.store); + this.destroyed.next(); + this.destroyed.complete(); + } +} diff --git a/frontend/app/components/inference_profile/inference_profile_module.ts b/frontend/app/components/inference_profile/inference_profile_module.ts new file mode 100644 index 000000000..44fce76d7 --- /dev/null +++ b/frontend/app/components/inference_profile/inference_profile_module.ts @@ -0,0 +1,24 @@ +import {CommonModule} from '@angular/common'; +import {NgModule} from '@angular/core'; +import {MatOptionModule} from '@angular/material/core'; +import {MatProgressBarModule} from '@angular/material/progress-bar'; +import {MatSelectModule} from '@angular/material/select'; +import {MatSidenavModule} from '@angular/material/sidenav'; +import {TableModule} from 'org_xprof/frontend/app/components/chart/table/table_module'; + +import {InferenceProfile} from './inference_profile'; + +@NgModule({ + imports: [ + CommonModule, + TableModule, + MatOptionModule, + MatSelectModule, + MatSidenavModule, + MatProgressBarModule, + ], + declarations: [InferenceProfile], + exports: [InferenceProfile], +}) +export class InferenceProfileModule { +} diff --git a/frontend/app/components/main_page/BUILD b/frontend/app/components/main_page/BUILD index 33ed38933..6dfaf74b0 100644 --- a/frontend/app/components/main_page/BUILD +++ b/frontend/app/components/main_page/BUILD @@ -30,6 +30,7 @@ xprof_ng_module( "@org_xprof//frontend/app/components/framework_op_stats:framework_op_stats_adapter", "@org_xprof//frontend/app/components/graph_viewer", "@org_xprof//frontend/app/components/hlo_stats", + "@org_xprof//frontend/app/components/inference_profile", "@org_xprof//frontend/app/components/input_pipeline", "@org_xprof//frontend/app/components/kernel_stats:kernel_stats_adapter", "@org_xprof//frontend/app/components/memory_profile", diff --git a/frontend/app/components/main_page/main_page_module.ts b/frontend/app/components/main_page/main_page_module.ts index 52c49d230..6bb5ec2d4 100644 --- a/frontend/app/components/main_page/main_page_module.ts +++ b/frontend/app/components/main_page/main_page_module.ts @@ -14,6 +14,8 @@ import {GraphViewer} from 'org_xprof/frontend/app/components/graph_viewer/graph_ import {GraphViewerModule} from 'org_xprof/frontend/app/components/graph_viewer/graph_viewer_module'; import {HloStats} from 'org_xprof/frontend/app/components/hlo_stats/hlo_stats'; import {HloStatsModule} from 'org_xprof/frontend/app/components/hlo_stats/hlo_stats_module'; +import {InferenceProfile} from 'org_xprof/frontend/app/components/inference_profile/inference_profile'; +import {InferenceProfileModule} from 'org_xprof/frontend/app/components/inference_profile/inference_profile_module'; import {InputPipeline} from 'org_xprof/frontend/app/components/input_pipeline/input_pipeline'; import {InputPipelineModule} from 'org_xprof/frontend/app/components/input_pipeline/input_pipeline_module'; import {KernelStatsAdapter, KernelStatsAdapterModule} from 'org_xprof/frontend/app/components/kernel_stats/kernel_stats_adapter'; @@ -63,6 +65,7 @@ export const routes: Routes = [ {path: 'trace_viewer^', component: TraceViewer}, {path: 'graph_viewer^', component: GraphViewer}, {path: 'dcn_collective_stats^', component: DcnCollectiveStats}, + {path: 'inference_profile^', component: InferenceProfile}, {path: 'hlo_stats', component: HloStats}, {path: 'hlo_stats^', component: HloStats}, {path: '**', component: EmptyPage}, @@ -92,6 +95,7 @@ export const routes: Routes = [ FrameworkOpStatsAdapterModule, DcnCollectiveStatsModule, HloStatsModule, 
+ InferenceProfileModule, RouterModule.forRoot(routes), ], exports: [MainPage] diff --git a/frontend/app/services/data_service/data_service.ts b/frontend/app/services/data_service/data_service.ts index 5ca3c8868..35e89d834 100644 --- a/frontend/app/services/data_service/data_service.ts +++ b/frontend/app/services/data_service/data_service.ts @@ -74,7 +74,9 @@ export class DataService { return this.httpClient.get(this.pathPrefix + RUN_TOOLS_API, {params}); } - getData(run: string, tag: string, host: string): Observable { + getData( + run: string, tag: string, host: string, + parameters: Map = new Map()): Observable { if (this.isLocalDevelopment) { if (tag.startsWith('overview_page')) { return of(mockData.DATA_PLUGIN_PROFILE_OVERVIEW_PAGE_DATA) @@ -109,6 +111,9 @@ export class DataService { } const params = new HttpParams().set('run', run).set('tag', tag).set('host', host); + parameters.forEach((value, key) => { + params.set(key, value); + }); return this.httpClient.get(this.pathPrefix + DATA_API, {params}) as Observable; } diff --git a/plugin/tensorboard_plugin_profile/convert/BUILD b/plugin/tensorboard_plugin_profile/convert/BUILD index 26d574db0..e450bde98 100644 --- a/plugin/tensorboard_plugin_profile/convert/BUILD +++ b/plugin/tensorboard_plugin_profile/convert/BUILD @@ -247,6 +247,7 @@ py_library( deps = [ ":dcn_collective_stats_proto_to_gviz", ":hlo_stats_proto_to_gviz", + ":inference_stats_proto_to_gviz", ":input_pipeline_proto_to_gviz", ":kernel_stats_proto_to_gviz", ":overview_page_proto_to_gviz", @@ -256,3 +257,12 @@ py_library( "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2", ], ) + +py_library( + name = "inference_stats_proto_to_gviz", + srcs = ["inference_stats_proto_to_gviz.py"], + deps = [ + requirement("gviz_api"), + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2", + ], +) diff --git a/plugin/tensorboard_plugin_profile/convert/inference_stats_proto_to_gviz.py b/plugin/tensorboard_plugin_profile/convert/inference_stats_proto_to_gviz.py new file mode 100644 index 000000000..b54d0bdc9 --- /dev/null +++ b/plugin/tensorboard_plugin_profile/convert/inference_stats_proto_to_gviz.py @@ -0,0 +1,297 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""For conversion of InferenceStats proto to gviz tables. 
+ +Usage: + gviz_data_tables = generate_all_chart_tables(inference_stats) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gviz_api + +from tensorboard_plugin_profile.protobuf import inference_stats_pb2 + + +def pico_to_milli(ps: float) -> float: + """Converts picoseconds to milliseconds.""" + return ps / 1e9 + + +def _add_request_details( + request_detail: inference_stats_pb2.RequestDetail, + percentile: str, + request_id: str, + has_batching: bool, + is_tpu: bool, + throughput: str, +): + """Generates the request details row.""" + row = [ + percentile, + request_id, + pico_to_milli(request_detail.end_time_ps - request_detail.start_time_ps), + ] + if has_batching: + row.append(request_detail.batching_request_size) + row.append(pico_to_milli(request_detail.batching_request_delay_ps)) + row.append(throughput) + if is_tpu: + row.append(pico_to_milli(request_detail.host_preprocessing_ps)) + row.append(pico_to_milli(request_detail.host_runtime_ps)) + row.append(pico_to_milli(request_detail.write_to_device_time_ps)) + row.append(pico_to_milli(request_detail.read_from_device_time_ps)) + row.append(pico_to_milli(request_detail.device_time_ps)) + row.append(pico_to_milli(request_detail.host_postprocessing_ps)) + row.append(pico_to_milli(request_detail.idle_time_ps)) + return row + + +def _create_request_table( + per_model_stats: inference_stats_pb2.PerModelInferenceStats, + has_batching: bool, + is_tpu: bool, +): + """Generates the request table.""" + columns = [ + ("percentile", "string", "Percentile"), + ("request_id", "string", "Request ID"), + ("latency_ms", "number", "Latency (ms)"), + ] + if has_batching: + columns.append(("batching_request_size", "number", "Batching Request Size")) + columns.append( + ("host_batch_information", "number", "Host Batch Information") + ) + columns.append(("throughput", "string", "Throughput")) + if is_tpu: + columns.append(("host_preprocessing", "number", "Host Preprocessing")) + columns.append(("host_runtime", "number", "Host Runtime")) + columns.append(("data_transfer_h2d", "number", "Data transfer H2D")) + columns.append(("data_transfer_d2h", "number", "Data transfer D2H")) + columns.append(("device_compute", "number", "Device compute")) + columns.append(("host_postprocess", "number", "Host Postprocessing")) + columns.append(("idle_time", "number", "Idle Time")) + data = [] + for request_detail in per_model_stats.per_batch_size_aggregated_result: + data.append( + _add_request_details( + request_detail.aggregated_request_result, + "Batch Size {}".format(request_detail.batch_size), + "N/A", + has_batching, + is_tpu, + "{:.1f}".format(request_detail.batch_throughput), + ) + ) + data.append( + _add_request_details( + per_model_stats.aggregated_request_detail, + "Aggregated", + "N/A", + has_batching, + is_tpu, + "{:.1f}".format(per_model_stats.request_throughput), + ) + ) + custom_properties = { + "throughput": "{:.1f}".format(per_model_stats.request_throughput), + "averageLatencyMs": "{:.3f}".format( + per_model_stats.request_average_latency_us / 1e3 + ), + } + return gviz_api.DataTable(columns, data, custom_properties) + + +def _generate_batch_details( + batch_detail: inference_stats_pb2.BatchDetail, + percentile: str, + batch_id: str, + throughput: str, +): + """Generates the batch details row.""" + return [ + percentile, + batch_id, + batch_detail.end_time_ps - batch_detail.start_time_ps, + batch_detail.padding_amount, + batch_detail.batch_size_after_padding, + 
(batch_detail.batch_size_after_padding - batch_detail.padding_amount) + / batch_detail.batch_size_after_padding, + batch_detail.batch_delay_ps, + throughput, + ] + + +def _generate_batch_table( + per_model_stats: inference_stats_pb2.PerModelInferenceStats, + model_id_database: inference_stats_pb2.ModelIdDatabase, + model_id: str, +): + """Generates the batch table.""" + columns = [ + ("percentile", "string", "Percentile"), + ("batch_id", "string", "Batch ID"), + ("latency", "number", "Latency (ms)"), + ("padding_amount", "number", "Padding Amount"), + ("batch_size_after_padding", "number", "Batch Size After Padding"), + ("batching_efficiency", "number", "Batch Efficiency"), + ("batch_delay_ms", "number", "Batch Delay (ms)"), + ("throughput", "string", "Throughput"), + ] + data = [] + properties = {} + properties["throughput"] = "{:.1f}".format(per_model_stats.batch_throughput) + properties["averageLatencyMs"] = "{:.3f}".format( + per_model_stats.batch_average_latency_us / 1e3 + ) + + if model_id in model_id_database.id_to_batching_params: + params = model_id_database.id_to_batching_params[model_id] + properties["hasBatchingParam"] = "true" + properties["batchingParamNumBatchThreads"] = str(params.num_batch_threads) + properties["batchingParamMaxBatchSize"] = str(params.max_batch_size) + properties["batchingParamBatchTimeoutMicros"] = str( + params.batch_timeout_micros + ) + properties["batchingParamMaxEnqueuedBatches"] = str( + params.max_enqueued_batches + ) + properties["batchingParamAllowedBatchSizes"] = str( + params.allowed_batch_sizes + ) + else: + properties["hasBatchingParam"] = "false" + for batch_detail in per_model_stats.per_batch_size_aggregated_result: + data.append( + _generate_batch_details( + batch_detail.aggregated_batch_result, + "Batch Size {}".format(batch_detail.batch_size), + "N/A", + "{:.1f}".format(batch_detail.batch_throughput), + ) + ) + data.append( + _generate_batch_details( + per_model_stats.aggregated_batch_detail, + "Aggregated", + "N/A", + "{:.1f}".format(per_model_stats.batch_throughput), + ) + ) + return gviz_api.DataTable(columns, data, properties) + + +def _generate_tensor_pattern_table( + per_model_inference_stats: inference_stats_pb2.PerModelInferenceStats, + tensor_pattern_db: inference_stats_pb2.TensorPatternDatabase, +): + """Generates the tensor pattern table.""" + table_description = [ + ("id", "number", "ID"), + ("tensor_pattern", "string", "Tensor Pattern"), + ("count", "number", "Number of Occurrence"), + ("percentile", "string", "Linearize/Delinearize latency"), + ] + data = [] + for counter, aggregated_result in enumerate( + per_model_inference_stats.tensor_transfer_aggregated_result.tensor_pattern_results + ): + tensor_pattern = tensor_pattern_db.tensor_pattern[ + aggregated_result.tensor_pattern_index + ] + data.append([ + counter, + tensor_pattern, + aggregated_result.count, + aggregated_result.linearize_delinearize_percentile_time, + ]) + return gviz_api.DataTable(table_description, data) + + +def _generate_per_model_inference_table( + inference_stats: inference_stats_pb2.InferenceStats, + sorted_model_ids: list[str], + has_batching: bool, + is_tpu: bool, +): + """Generates the per model inference table.""" + tables = [] + for model_id in sorted_model_ids: + try: + model_index = inference_stats.model_id_db.id_to_index[model_id] + per_model_stats = inference_stats.inference_stats_per_model[model_index] + tables.append( + _create_request_table(per_model_stats, has_batching, is_tpu) + ) + if has_batching: + tables.append( + 
_generate_batch_table( + per_model_stats, inference_stats.model_id_db, model_id + ) + ) + if inference_stats.tensor_pattern_db.tensor_pattern: + tables.append( + _generate_tensor_pattern_table( + per_model_stats, inference_stats.tensor_pattern_db + ) + ) + except KeyError: + continue + return tables + + +def generate_all_chart_tables( + inference_stats: inference_stats_pb2.InferenceStats, +): + """Converts a InferenceStats proto to gviz DataTables.""" + sorted_model_ids = [x for x in inference_stats.model_id_db.ids] + sorted_model_ids.sort() + has_batching = False + for _, per_model_stats in inference_stats.inference_stats_per_model.items(): + if per_model_stats.batch_details: + has_batching = True + break + is_tpu = True + table_properties = { + "hasBatching": "{}".format(has_batching).lower(), + "hasTensorPattern": "false", + } + columns = [ + ("model_name", "string", "Model Name"), + ] + data = [] + for model_id in sorted_model_ids: + data.append([model_id]) + return [ + gviz_api.DataTable(columns, data, table_properties), + *_generate_per_model_inference_table( + inference_stats, + sorted_model_ids, + has_batching, + is_tpu, + ), + ] + + +def to_json(raw_data): + """Converts a serialized DcnCollectiveAnalysis string to json.""" + inference_stats = inference_stats_pb2.InferenceStats() + inference_stats.ParseFromString(raw_data) + all_chart_tables = generate_all_chart_tables(inference_stats) + json_join = ",".join(x.ToJSon() if x else "{}" for x in all_chart_tables) + return "[" + json_join + "]" diff --git a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py index 095e35402..57c0f7909 100644 --- a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py +++ b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py @@ -28,6 +28,7 @@ from tensorflow.python.profiler.internal import _pywrap_profiler_plugin # pylint: disable=g-direct-tensorflow-import from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz from tensorboard_plugin_profile.convert import hlo_stats_proto_to_gviz +from tensorboard_plugin_profile.convert import inference_stats_proto_to_gviz from tensorboard_plugin_profile.convert import input_pipeline_proto_to_gviz from tensorboard_plugin_profile.convert import kernel_stats_proto_to_gviz from tensorboard_plugin_profile.convert import overview_page_proto_to_gviz @@ -204,6 +205,10 @@ def xspace_to_tool_data( raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) if success: data = dcn_collective_stats_proto_to_gviz.to_json(raw_data) + elif tool == 'inference_profile': + raw_data, success = xspace_wrapper_func(xspace_paths, tool) + if success: + data = inference_stats_proto_to_gviz.to_json(raw_data) else: logger.warning('%s is not a known xplane tool', tool) return data, content_type diff --git a/plugin/tensorboard_plugin_profile/protobuf/BUILD b/plugin/tensorboard_plugin_profile/protobuf/BUILD index 9d0487004..124bac7e7 100644 --- a/plugin/tensorboard_plugin_profile/protobuf/BUILD +++ b/plugin/tensorboard_plugin_profile/protobuf/BUILD @@ -13,6 +13,7 @@ proto_library( "dcn_slack_analysis.proto", "diagnostics.proto", "hlo_stats.proto", + "inference_stats.proto", "input_pipeline.proto", "kernel_stats.proto", "overview_page.proto", @@ -30,6 +31,7 @@ py_proto_library( "dcn_slack_analysis.proto", "diagnostics.proto", "hlo_stats.proto", + "inference_stats.proto", "input_pipeline.proto", "kernel_stats.proto", "overview_page.proto", diff --git 
a/plugin/tensorboard_plugin_profile/protobuf/inference_stats.proto b/plugin/tensorboard_plugin_profile/protobuf/inference_stats.proto new file mode 100644 index 000000000..47f378dbe --- /dev/null +++ b/plugin/tensorboard_plugin_profile/protobuf/inference_stats.proto @@ -0,0 +1,285 @@ +// This proto is used for inference-specific analysis. +syntax = "proto2"; + +package tensorflow.profiler; + +message TensorEventDetail { + // The index of the tensor pattern in TensorPatternDatabase. + optional int32 tensor_pattern_index = 1; + + // The owner of this TensorEventDetail. + enum TensorEventOwner { + // Unknown. This should not happen in production code. + UNKNOWN = 0; + + // Owned by the request. + REQUEST = 1; + + // Owned by the batch. + BATCH = 2; + } + + // If batching is enabled, the TensorEventDetails in BatchDetail will have + // owner = BATCH, and they are counted when calculating statistics like the + // number of occurrence for each tensor pattern. The TensorEventDetails in + // RequestDetail will have owner = BATCH, which means the tensor events + // actually happen in the batch, and they are not counted when calculating + // various statistics. + // If batching is not enabled, the TensorEventDetail will only appear in + // RequestDetail and the owner will only be REQUEST. + optional TensorEventOwner owner = 2; + + // Total time in picosecs spent on linearize and delinearize tensors. + optional uint64 linearize_delinearize_time_ps = 3; +} + +// Detail of a user facing request. +// Next ID: 22 +message RequestDetail { + // Request id. + optional int64 request_id = 10 [default = -1]; + + // An index to the model_id inside InferenceStats below. Storing index + // instead of string to save space. It will be -1 if the model id is not + // given. + optional int32 model_id_index = 8 [default = -1]; + + // Start-time of the request in picosecs. + optional uint64 start_time_ps = 1 [default = 0]; + + // End-time of the request in picosecs. + optional uint64 end_time_ps = 2 [default = 0]; + + // Total time in picosecs in this request spent on device. + optional uint64 device_time_ps = 7 [default = 0]; + + // Total time in picosecs in this request spent on writes to device. + optional uint64 write_to_device_time_ps = 5 [default = 0]; + + // Total time in picosecs in this request spent on reads from device. + optional uint64 read_from_device_time_ps = 6 [default = 0]; + + // If this inference request is running in batching mode, record the latency + // between a request is scheduled and is processed in a batch. Otherwise, it + // will always be 0. + optional uint64 batching_request_delay_ps = 12 [default = 0]; + + // Batch ids related to this request. + repeated int64 related_batch_ids = 11; + + // If this inference request is running in batching mode, record the size of + // the request. Otherwise, it will always be 0. + optional int32 batching_request_size = 13; + + // Detailed breakdown for host side activities of a request. + // Total time in picosecs spent on host preprocessing. + optional uint64 host_preprocessing_ps = 14; + + // Total time in picosecs spent on host batch formation. + optional uint64 host_batch_formation_ps = 15; + + // Total time in picosecs spent on host runtime. + optional uint64 host_runtime_ps = 16; + + // Total time in picosecs spent on host postprocessing. + optional uint64 host_postprocessing_ps = 17; + + // Tensor event details. + // One request can have multiple TensorEventDetails because it might be + // split into multiple batches for execution. 
+ repeated TensorEventDetail tensor_event_details = 18; + + // Host index for this request. + optional int32 host_id = 19; + + // Percentile of this request in all requests in the profile duration. + optional double percentile = 20; + + // The time no event associated with. It could be that the machine was idle or + // executing some events which were not traced. + optional double idle_time_ps = 21; + + // Were device_start_time_ps, device_end_time_ps, session_id + reserved 3, 4, 9; +} + +// Detail of a batch. +// Next ID: 12 +message BatchDetail { + // Batch id. + optional int64 batch_id = 1 [default = -1]; + + // Start time of the batch in picosecs. + optional uint64 start_time_ps = 2 [default = 0]; + + // End time of the batch in picosecs. + optional uint64 end_time_ps = 3 [default = 0]; + + // The latency between "start time of the first request in this batch", and + // the time this batch is processed. + optional uint64 batch_delay_ps = 5 [default = 0]; + + // Request ids related to this batch. + repeated int64 related_request_ids = 4; + + // Size of padding. + optional int32 padding_amount = 6; + + // Size of a batch after padding. + optional int32 batch_size_after_padding = 7; + + // Model ID of this batch. This is the same model_id as any of the request in + // this batch. All the requests from the same batch must share the same + // model_id. + optional int32 model_id_index = 8; + + // Tensor event details. + optional TensorEventDetail tensor_event_detail = 9; + + // Host index for this batch. + optional int32 host_id = 10; + + // Percentile of this batch in all batches in the profile duration. + optional double percentile = 11; +} + +// Per-host data for inference analysis. +message PerHostInferenceStats { + // A list of requests selected for inference analysis on this host. + // This list is in ascending order of the request duration. + repeated RequestDetail request_details = 3; + + // A list of batches selected for inference analysis on this host. + // This list is in ascending order of the batch duration. + repeated BatchDetail batch_details = 5; + + reserved 1, 2, 4, 6; + + // were session_run_times, sessions_per_second, requests_per_second, + // batches_per_second. +} + +// Per-model aggregated result of tensor transfer. +message TensorTransferAggregatedResult { + message TensorPatternResult { + // The index of the tensor pattern in TensorPatternDatabase. + optional int32 tensor_pattern_index = 1; + + // The number of occurrence of this tensor pattern in this model. + optional uint64 count = 2; + + message PercentileTime { + optional double percentile = 1; + optional uint64 time_ps = 2; + } + + // The percentiles of the linearize and delinearize time of this tensor + // pattern in this model. + repeated PercentileTime linearize_delinearize_percentile_time = 3; + } + + repeated TensorPatternResult tensor_pattern_results = 1; +} + +// Aggregated result per batch size. +message PerBatchSizeAggregatedResult { + optional int32 batch_size = 1; + optional RequestDetail aggregated_request_result = 2; + optional BatchDetail aggregated_batch_result = 3; + optional double request_throughput = 4; + optional double batch_throughput = 5; +} + +// Per-model data for inference analysis. +message PerModelInferenceStats { + // A list of requests selected for inference analysis on this model. + // This list is in ascending order of the request duration. + repeated RequestDetail request_details = 1; + + // Aggregated result from all the . 
+ optional RequestDetail aggregated_request_detail = 8; + + // Inference requests per second for this model. + optional double request_throughput = 2; + + // Average latency in microseconds of the requests in this model. + optional double request_average_latency_us = 3; + + // A list of batches selected for inference analysis on this model. + // This list is in ascending order of the batch duration. + repeated BatchDetail batch_details = 4; + + // Aggregated result from all the . + optional BatchDetail aggregated_batch_detail = 9; + + // Batches per second for this model. + optional double batch_throughput = 5; + + // Average latency in microseconds of the batches in this model. + optional double batch_average_latency_us = 6; + + // The aggregated result of tensor transfer in this model. + optional TensorTransferAggregatedResult tensor_transfer_aggregated_result = 7; + + // Aggregated result per batch size. + repeated PerBatchSizeAggregatedResult per_batch_size_aggregated_result = 10; +} + +// Batching parameters collected from TFstreamz. +message BatchingParameters { + // Number of batch threads. + optional int64 num_batch_threads = 1; + + // How long a request can wait before being processed by a batch. + optional int64 batch_timeout_micros = 2; + + // Maximum size of a batch. + optional int64 max_batch_size = 3; + + // Maximum number of enqueued batches. + optional int64 max_enqueued_batches = 4; + + // Sizes that are allowed to form a batch. A list of integers separated by "," + optional string allowed_batch_sizes = 5; +} + +// Model ID database. Unknown model id will be "" and won't be stored here. So +// if model id is not found in the TF-session metadata, ModelIdDatabase will be +// empty. +message ModelIdDatabase { + // Array of model ids. + repeated string ids = 1; + + // Map from id to index. + map id_to_index = 2; + + // Map from id to batching parameters. + map id_to_batching_params = 3; +} + +// Tensor pattern database for all the tensor patterns that occurred during the +// profiling window. +message TensorPatternDatabase { + // A tensor pattern is the string concatenation of all the linearize and + // delinearize events in an inference request. Each event records the tensor + // shape, data type and the layout on device. + repeated string tensor_pattern = 1; +} + +// Proto consumed by inference analysis. +message InferenceStats { + // Map from host-id to the InferenceStats for that host. + map inference_stats_per_host = 3; + + // Map from model-id to the InferenceStats for that model. + map inference_stats_per_model = + 5; + + // A database of model ids. + optional ModelIdDatabase model_id_db = 4; + + // A database of tensor patterns. + optional TensorPatternDatabase tensor_pattern_db = 6; + + reserved 1, 2; // were processing_stats, session_run_times +}
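To make the conversion path concrete, here is a minimal sketch (not part of the patch) that builds an InferenceStats proto for a single model and runs it through the converter added above. Function and field names come from this patch; the model name, the numeric values, and the assumption that the per-model map is keyed by integer model index are illustrative only.

    # Minimal sketch: InferenceStats -> gviz JSON, assuming integer map keys.
    from tensorboard_plugin_profile.convert import inference_stats_proto_to_gviz
    from tensorboard_plugin_profile.protobuf import inference_stats_pb2

    stats = inference_stats_pb2.InferenceStats()

    # Register one model in the model ID database.
    stats.model_id_db.ids.append("toy_model")
    stats.model_id_db.id_to_index["toy_model"] = 0

    # Per-model stats: one aggregated request, 9 ms end-to-end.
    per_model = stats.inference_stats_per_model[0]
    per_model.request_throughput = 125.0
    per_model.request_average_latency_us = 8000.0
    agg = per_model.aggregated_request_detail
    agg.start_time_ps = 0
    agg.end_time_ps = 9_000_000_000  # 9 ms in picoseconds

    # generate_all_chart_tables() returns the metadata table followed by one
    # request table per model (plus batch and tensor-pattern tables when present);
    # to_json() is what raw_to_tool_data.py sends to the frontend.
    json_payload = inference_stats_proto_to_gviz.to_json(stats.SerializeToString())
    print(json_payload[:200])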