From 581d0dbe7af2d8c48b74756feb8396a8d320be6e Mon Sep 17 00:00:00 2001 From: Yin Zhang Date: Wed, 20 Nov 2024 00:55:14 -0800 Subject: [PATCH] Create OpStatsToRooflineModel, in preparation of Roofline Model creation PiperOrigin-RevId: 698292881 --- frontend/app/components/hlo_stats/BUILD | 49 +++ .../components/hlo_stats/hlo_stats.ng.html | 100 ++++++ .../app/components/hlo_stats/hlo_stats.scss | 35 +++ .../app/components/hlo_stats/hlo_stats.ts | 294 ++++++++++++++++++ .../components/hlo_stats/hlo_stats_module.ts | 31 ++ frontend/app/components/main_page/BUILD | 1 + .../components/main_page/main_page_module.ts | 5 + .../tensorboard_plugin_profile/convert/BUILD | 11 + .../convert/hlo_stats_proto_to_gviz.py | 162 ++++++++++ .../convert/raw_to_tool_data.py | 5 + .../tpu/tensorflow/tpu_tf2_keras_test.py | 1 + .../profile_plugin.py | 1 + .../tensorboard_plugin_profile/protobuf/BUILD | 2 + .../protobuf/hlo_stats.proto | 124 ++++++++ 14 files changed, 821 insertions(+) create mode 100644 frontend/app/components/hlo_stats/BUILD create mode 100644 frontend/app/components/hlo_stats/hlo_stats.ng.html create mode 100644 frontend/app/components/hlo_stats/hlo_stats.scss create mode 100644 frontend/app/components/hlo_stats/hlo_stats.ts create mode 100644 frontend/app/components/hlo_stats/hlo_stats_module.ts create mode 100644 plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py create mode 100644 plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto diff --git a/frontend/app/components/hlo_stats/BUILD b/frontend/app/components/hlo_stats/BUILD new file mode 100644 index 00000000..aa85a015 --- /dev/null +++ b/frontend/app/components/hlo_stats/BUILD @@ -0,0 +1,49 @@ +load("@io_bazel_rules_sass//:defs.bzl", "sass_binary") +load("//defs:defs.bzl", "xprof_ng_module") + +package(default_visibility = ["//frontend:internal"]) + +xprof_ng_module( + name = "hlo_stats", + srcs = [ + "hlo_stats.ts", + "hlo_stats_module.ts", + ], + assets = [ + ":hlo_stats_css", + "hlo_stats.ng.html", + ], + deps = [ + "@npm//@angular/core", + "@npm//@angular/router", + "@npm//@ngrx/store", + "@npm//rxjs", + "@org_xprof//frontend/app/common/angular:angular_material_divider", + "@org_xprof//frontend/app/common/angular:angular_material_expansion", + "@org_xprof//frontend/app/common/angular:angular_material_select", + "@org_xprof//frontend/app/common/angular:angular_material_tooltip", + "@org_xprof//frontend/app/common/constants", + "@org_xprof//frontend/app/common/interfaces", + "@org_xprof//frontend/app/common/interfaces:chart", + "@org_xprof//frontend/app/common/utils", + "@org_xprof//frontend/app/components/chart", + "@org_xprof//frontend/app/components/chart:chart_options", + "@org_xprof//frontend/app/components/chart:data_processor", + "@org_xprof//frontend/app/components/chart:default_data_provider", + "@org_xprof//frontend/app/components/chart/dashboard", + "@org_xprof//frontend/app/components/controls/category_filter", + "@org_xprof//frontend/app/components/controls/string_filter", + "@org_xprof//frontend/app/components/framework_op_stats/flop_rate_chart", + "@org_xprof//frontend/app/services/data_service", + "@org_xprof//frontend/app/store", + ], +) + +sass_binary( + name = "hlo_stats_css", + src = "hlo_stats.scss", + sourcemap = False, + deps = [ + "@org_xprof//frontend/app/styles:common", + ], +) diff --git a/frontend/app/components/hlo_stats/hlo_stats.ng.html b/frontend/app/components/hlo_stats/hlo_stats.ng.html new file mode 100644 index 00000000..7ba5d435 --- /dev/null +++ b/frontend/app/components/hlo_stats/hlo_stats.ng.html @@ -0,0 +1,100 @@ +
+
+
+
Notes
+
+
+ (1) You can filter the table by HLO op category, HLO op name or Framework op name (eg. JAX op, TF op). +
+
+ (2) In the charts and table below, "IDLE" represents the portion of the total execution + time on device that is idle.. +
+
+ (3) In the pie charts, the "Other" sector represents the sum of sectors that are too + small to be shown individually. +
+
+
+
+
+ +
+ + + + Chart Section + + +
+
+
Time per HLO category
+ +
+
+
Time per HLO operation
+ +
+
+
Time spent on rematerialization
+ +
+
+
Time per HLO by Replica Groups
+ + + {{op}} + + + +
+
+ +
+
+
Time spent on rematerialization per HLO category
+ +
+
+
Time spent on outside compilation
+ +
+ + + +
+
+
+ +
+
+ + + + + + + + + + +
+
+ +
+ + +
diff --git a/frontend/app/components/hlo_stats/hlo_stats.scss b/frontend/app/components/hlo_stats/hlo_stats.scss new file mode 100644 index 00000000..93a572ca --- /dev/null +++ b/frontend/app/components/hlo_stats/hlo_stats.scss @@ -0,0 +1,35 @@ +/** CSS for an hlo stats component. */ + +@import 'frontend/app/styles/common'; + +:host { + display: block; +} + +.section-container { + padding: 20px; +} + +.header-content { + font-style: italic; + margin-top: -16px; +} + +.row { + display: flex; +} + +.flex-space { + flex: 1; +} + +.pie-chart { + height: fit-content; + margin: 8px 8px 8px 0; + width: 450px; +} + +.table { + border: 1px solid $border-color; + margin: 0 16px; +} diff --git a/frontend/app/components/hlo_stats/hlo_stats.ts b/frontend/app/components/hlo_stats/hlo_stats.ts new file mode 100644 index 00000000..1b3ba676 --- /dev/null +++ b/frontend/app/components/hlo_stats/hlo_stats.ts @@ -0,0 +1,294 @@ +import {Component, OnDestroy} from '@angular/core'; +import {ActivatedRoute} from '@angular/router'; +import {Store} from '@ngrx/store'; +import {OpType} from 'org_xprof/frontend/app/common/constants/enums'; +import {ChartDataInfo} from 'org_xprof/frontend/app/common/interfaces/chart'; +import {SimpleDataTable} from 'org_xprof/frontend/app/common/interfaces/data_table'; +import {NavigationEvent} from 'org_xprof/frontend/app/common/interfaces/navigation_event'; +import {setLoadingState} from 'org_xprof/frontend/app/common/utils/utils'; +import {CategoryTableDataProcessor} from 'org_xprof/frontend/app/components/chart/category_table_data_processor'; +import {PIE_CHART_OPTIONS, TABLE_OPTIONS,} from 'org_xprof/frontend/app/components/chart/chart_options'; +import {Dashboard} from 'org_xprof/frontend/app/components/chart/dashboard/dashboard'; +import {DefaultDataProvider, ReplicaGroupDataProvider,} from 'org_xprof/frontend/app/components/chart/default_data_provider'; +import {DataService} from 'org_xprof/frontend/app/services/data_service/data_service'; +import {setCurrentToolStateAction} from 'org_xprof/frontend/app/store/actions'; +import {ReplaySubject} from 'rxjs'; +import {takeUntil} from 'rxjs/operators'; + +const OP_CATEGORY_ID = 'category'; +const OP_NAME_ID = 'hlo_op_name'; +const PROGRAM_ID = 'program_id'; +const OP_EXPRESSION_ID = 'hlo_op_expression'; +const SELF_TIME_ID = 'total_self_time'; +const HLO_REMAT_ID = 'hlo_rematerialization'; +const OUTSIDE_COMPILATION_ID = 'outside_compilation'; +const MEASURED_FLOP_RATE_ID = 'measured_flop_rate'; + +/** A Hlo Stats component. */ +@Component({ + standalone: false, + selector: 'hlo-stats', + templateUrl: './hlo_stats.ng.html', + styleUrls: ['./hlo_stats.css'], +}) +export class HloStats extends Dashboard implements OnDestroy { + readonly tool = 'hlo_stats'; + /** Handles on-destroy Subject, used to unsubscribe. */ + private readonly destroyed = new ReplaySubject(1); + sessionId = ''; + data: SimpleDataTable | null = null; + hloOpNameSelected = ''; + programIdSelected = ''; + // Flop rate chart properties. + readonly opType = OpType.XLA_HLO; + flopRateChartXColumn = -1; + flopRateChartYColumn = -1; + // Pie charts properties. + pieChartDataProvider = new DefaultDataProvider(); + replicaGroupDataProvider = new ReplicaGroupDataProvider(); + dataInfoCategoryChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + dataInfoOpChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + communicationOps = new Set(); + selectedCommOp = ''; + dataInfoOpReplicaGroupChart: ChartDataInfo = { + data: null, + dataProvider: this.replicaGroupDataProvider, + options: PIE_CHART_OPTIONS, + }; + dataInfoRematerializationChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + dataInfoRematerializationCategoryChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + dataInfoOutsideCompilationChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + // Table properties. + dataInfoForTable: ChartDataInfo = { + data: null, + dataProvider: new DefaultDataProvider(), + filters: [], + options: { + ...TABLE_OPTIONS, + showRowNumber: false, + page: 'enable', + pageSize: 100, + sortAscending: true, + sortColumn: 0, + }, + }; + showChartSection = true; + + constructor( + route: ActivatedRoute, + private readonly dataService: DataService, + private readonly store: Store<{}>, + ) { + super(); + route.params.pipe(takeUntil(this.destroyed)).subscribe((params) => { + this.update(params as NavigationEvent); + }); + this.store.dispatch(setCurrentToolStateAction({currentTool: this.tool})); + } + + update(event: NavigationEvent) { + const run = event.run || ''; + const tag = event.tag || 'hlo_stats'; + const host = event.host || ''; + + setLoadingState(true, this.store, 'Loading hlo data'); + + this.dataService.getData(run, tag, host) + .pipe(takeUntil(this.destroyed)) + .subscribe((data) => { + setLoadingState(false, this.store); + this.data = data as SimpleDataTable | null; + this.process(this.data); + this.onCheckInputParams(); + }); + } + + onCheckInputParams() { + this.hloOpNameSelected = + this.dataService.searchParams?.get('hlo_op_name') || ''; + // Assumption: the program_id is in format like 'main()' + // parsing with a regex to match content in the bracket + const programIdParsed = this.dataService.searchParams + ?.get('program_id') + ?.match(/\((.*)\)/); + this.programIdSelected = + programIdParsed?.length === 2 ? programIdParsed[1] : ''; + } + + // Iterate through the table data + // and inject graph link to the hlo op text cell + addGraphViewerLinkInTableData(data: SimpleDataTable) { + const programIdColumnIdx = + data.cols?.findIndex((col) => col.id === PROGRAM_ID) ?? -1; + const hloOpExpressionColumnIdx = + data.cols?.findIndex((col) => col.id === OP_EXPRESSION_ID) ?? -1; + const hloOpNameColumnIdx = + data.cols?.findIndex((col) => col.id === OP_NAME_ID) ?? -1; + if (programIdColumnIdx === -1 || hloOpExpressionColumnIdx === -1) { + return data; + } + + const updatedData = { + ...data, + rows: data?.rows!.map((row, index) => { + const programId = (row.c![programIdColumnIdx].v as string).trim() || ''; + const hloOpName = (row.c![hloOpNameColumnIdx].v as string).trim() || ''; + const hloOpExpression = + (row.c![hloOpExpressionColumnIdx].v as string) || ''; + const graphViewerLink = `/graph_viewer/${this.sessionId}?program_id=${programId}&node_name=${hloOpName}`; + return { + ...row, + c: [ + ...row.c!.slice(0, hloOpExpressionColumnIdx), + { + ...row.c![hloOpExpressionColumnIdx], + v: `${hloOpExpression}`, + }, + ...row.c!.slice(hloOpExpressionColumnIdx + 1), + ], + }; + }), + }; + return updatedData; + } + + private process(data: SimpleDataTable | null) { + if (!data) return; + + this.parseData(data); + this.drawFlopRateChart(); + this.updateOpReplicaGroupChart(); + + const updatedData = this.addGraphViewerLinkInTableData(data); + this.dataInfoForTable = { + ...this.dataInfoForTable, + data: updatedData, + }; + } + + updateView() { + this.dataInfoForTable = { + ...this.dataInfoForTable, + filters: this.getFilters(), + }; + } + + updateOpReplicaGroupChart() { + if ( + !this.replicaGroupDataProvider.opCategoryIndex || + !this.replicaGroupDataProvider.hloOpNameIndex || + !this.replicaGroupDataProvider.selfTimeIndex + ) { + return; + } + + const filtersForReplicaGroup = [ + { + column: this.replicaGroupDataProvider.opCategoryIndex, + value: this.selectedCommOp, + }, + ]; + + this.dataInfoOpReplicaGroupChart.customChartDataProcessor = + new CategoryTableDataProcessor( + filtersForReplicaGroup, + this.replicaGroupDataProvider.hloOpNameIndex, + this.replicaGroupDataProvider.selfTimeIndex, + ); + + // Since the DataInfo has not been updated, the notifyCharts function is + // called to redraw the graph. + this.replicaGroupDataProvider.notifyCharts(); + } + + parseData(data: SimpleDataTable | null) { + if (!data) return; + // Five charts share one DataProvider. In order to prevent DataTable from + // being created multiple times, it calls DataProvider function directly. + this.pieChartDataProvider.parseData(data); + const dataTable = this.pieChartDataProvider.getDataTable(); + if (!dataTable) return; + + this.dataTable = dataTable; + this.updateView(); + + const hloOpNameIndex = dataTable.getColumnIndex(OP_EXPRESSION_ID); + const opCategoryIndex = dataTable.getColumnIndex(OP_CATEGORY_ID); + const selfTimeIndex = dataTable.getColumnIndex(SELF_TIME_ID); + const hloRematIndex = dataTable.getColumnIndex(HLO_REMAT_ID); + const outsideCompilationIndex = dataTable.getColumnIndex( + OUTSIDE_COMPILATION_ID, + ); + + const filtersForRemat = [{column: hloRematIndex, value: 'Yes'}]; + + this.dataInfoCategoryChart.customChartDataProcessor = + new CategoryTableDataProcessor([], opCategoryIndex, selfTimeIndex); + this.dataInfoOpChart.customChartDataProcessor = + new CategoryTableDataProcessor([], hloOpNameIndex, selfTimeIndex); + this.dataInfoRematerializationChart.customChartDataProcessor = + new CategoryTableDataProcessor([], hloRematIndex, selfTimeIndex, false); + this.dataInfoRematerializationCategoryChart.customChartDataProcessor = + new CategoryTableDataProcessor( + filtersForRemat, + opCategoryIndex, + selfTimeIndex, + ); + this.dataInfoOutsideCompilationChart.customChartDataProcessor = + new CategoryTableDataProcessor( + [], + outsideCompilationIndex, + selfTimeIndex, + false, + ); + + // Since the DataInfo has not been updated, the notifyCharts function is + // called to redraw the graph. + this.pieChartDataProvider.notifyCharts(); + + // Create a DataProvider in which the row string value for hloOpName column + // is truncated to only be the 'replica_groups={{...}}' string. + this.replicaGroupDataProvider.parseData(data); + this.communicationOps = this.replicaGroupDataProvider.communicationOps; + + if (this.communicationOps.size) { + // Set value to the first communication Op in the set. + this.selectedCommOp = this.communicationOps.values().next().value; + } + } + + private drawFlopRateChart() { + if (!this.dataTable || !this.dataTable.getColumnIndex) return; + this.flopRateChartXColumn = this.dataTable.getColumnIndex(OP_EXPRESSION_ID); + this.flopRateChartYColumn = this.dataTable.getColumnIndex( + MEASURED_FLOP_RATE_ID, + ); + } + + ngOnDestroy() { + // Unsubscribes all pending subscriptions. + setLoadingState(false, this.store); + this.destroyed.next(); + this.destroyed.complete(); + } +} diff --git a/frontend/app/components/hlo_stats/hlo_stats_module.ts b/frontend/app/components/hlo_stats/hlo_stats_module.ts new file mode 100644 index 00000000..8cf73e5c --- /dev/null +++ b/frontend/app/components/hlo_stats/hlo_stats_module.ts @@ -0,0 +1,31 @@ +import {CommonModule} from '@angular/common'; +import {NgModule} from '@angular/core'; +import {MatDividerModule} from '@angular/material/divider'; +import {MatExpansionModule} from '@angular/material/expansion'; +import {MatSelectModule} from '@angular/material/select'; +import {MatTooltipModule} from '@angular/material/tooltip'; +import {ChartModule} from 'org_xprof/frontend/app/components/chart/chart'; +import {CategoryFilterModule} from 'org_xprof/frontend/app/components/controls/category_filter/category_filter_module'; +import {StringFilterModule} from 'org_xprof/frontend/app/components/controls/string_filter/string_filter_module'; +import {FlopRateChartModule} from 'org_xprof/frontend/app/components/framework_op_stats/flop_rate_chart/flop_rate_chart_module'; + +import {HloStats} from './hlo_stats'; + +/** An HLO stats module. */ +@NgModule({ + declarations: [HloStats], + imports: [ + ChartModule, + CommonModule, + FlopRateChartModule, + CategoryFilterModule, + StringFilterModule, + MatDividerModule, + MatSelectModule, + MatExpansionModule, + MatTooltipModule, + ], + exports: [HloStats], +}) +export class HloStatsModule { +} diff --git a/frontend/app/components/main_page/BUILD b/frontend/app/components/main_page/BUILD index fa609499..33ed3893 100644 --- a/frontend/app/components/main_page/BUILD +++ b/frontend/app/components/main_page/BUILD @@ -29,6 +29,7 @@ xprof_ng_module( "@org_xprof//frontend/app/components/empty_page", "@org_xprof//frontend/app/components/framework_op_stats:framework_op_stats_adapter", "@org_xprof//frontend/app/components/graph_viewer", + "@org_xprof//frontend/app/components/hlo_stats", "@org_xprof//frontend/app/components/input_pipeline", "@org_xprof//frontend/app/components/kernel_stats:kernel_stats_adapter", "@org_xprof//frontend/app/components/memory_profile", diff --git a/frontend/app/components/main_page/main_page_module.ts b/frontend/app/components/main_page/main_page_module.ts index 6ff34df3..52c49d23 100644 --- a/frontend/app/components/main_page/main_page_module.ts +++ b/frontend/app/components/main_page/main_page_module.ts @@ -12,6 +12,8 @@ import {EmptyPageModule} from 'org_xprof/frontend/app/components/empty_page/empt import {FrameworkOpStatsAdapter, FrameworkOpStatsAdapterModule} from 'org_xprof/frontend/app/components/framework_op_stats/framework_op_stats_adapter'; import {GraphViewer} from 'org_xprof/frontend/app/components/graph_viewer/graph_viewer'; import {GraphViewerModule} from 'org_xprof/frontend/app/components/graph_viewer/graph_viewer_module'; +import {HloStats} from 'org_xprof/frontend/app/components/hlo_stats/hlo_stats'; +import {HloStatsModule} from 'org_xprof/frontend/app/components/hlo_stats/hlo_stats_module'; import {InputPipeline} from 'org_xprof/frontend/app/components/input_pipeline/input_pipeline'; import {InputPipelineModule} from 'org_xprof/frontend/app/components/input_pipeline/input_pipeline_module'; import {KernelStatsAdapter, KernelStatsAdapterModule} from 'org_xprof/frontend/app/components/kernel_stats/kernel_stats_adapter'; @@ -61,6 +63,8 @@ export const routes: Routes = [ {path: 'trace_viewer^', component: TraceViewer}, {path: 'graph_viewer^', component: GraphViewer}, {path: 'dcn_collective_stats^', component: DcnCollectiveStats}, + {path: 'hlo_stats', component: HloStats}, + {path: 'hlo_stats^', component: HloStats}, {path: '**', component: EmptyPage}, ]; @@ -87,6 +91,7 @@ export const routes: Routes = [ TfDataBottleneckAnalysisModule, FrameworkOpStatsAdapterModule, DcnCollectiveStatsModule, + HloStatsModule, RouterModule.forRoot(routes), ], exports: [MainPage] diff --git a/plugin/tensorboard_plugin_profile/convert/BUILD b/plugin/tensorboard_plugin_profile/convert/BUILD index 08665fda..26d574db 100644 --- a/plugin/tensorboard_plugin_profile/convert/BUILD +++ b/plugin/tensorboard_plugin_profile/convert/BUILD @@ -87,6 +87,16 @@ py_library( ], ) +py_library( + name = "hlo_stats_proto_to_gviz", + srcs = ["hlo_stats_proto_to_gviz.py"], + srcs_version = "PY2AND3", + deps = [ + requirement("gviz_api"), + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2", + ], +) + py_test( name = "overview_page_proto_to_gviz_test", size = "small", @@ -236,6 +246,7 @@ py_library( visibility = visibility, deps = [ ":dcn_collective_stats_proto_to_gviz", + ":hlo_stats_proto_to_gviz", ":input_pipeline_proto_to_gviz", ":kernel_stats_proto_to_gviz", ":overview_page_proto_to_gviz", diff --git a/plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py b/plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py new file mode 100644 index 00000000..d0ac3892 --- /dev/null +++ b/plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py @@ -0,0 +1,162 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""For conversion of HloStats protos to GViz DataTables. + +Usage: + gviz_data_tables = generate_hlo_stats_table(hlo_stats_db) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from typing import Any + +import gviz_api + +from tensorboard_plugin_profile.protobuf import hlo_stats_pb2 + + +def get_hlo_op_name_from_expression(hlo_op_expression: str) -> str: + """Returns the hlo op name from the hlo op expression. + + Args: + hlo_op_expression: A string of hlo op expression. + + Returns: + Returns the hlo op name extracted from the hlo op expression. + """ + # The parse logic based on the assumption that the hlo op text is in format of + # '%op_name = ' + parts = hlo_op_expression.split(" = ") + hlo_op_name = parts[0] + if hlo_op_name and hlo_op_name[0] == "%": + hlo_op_name = hlo_op_name[1:] + return hlo_op_name + + +TableColumnDescription = tuple[str, str, str] +TableRow = list[Any] +TableProperties = dict[str, Any] + + +def get_hlo_stats_table_args( + hlo_stats_db: hlo_stats_pb2.HloStatsDatabase, +) -> tuple[list[TableColumnDescription], list[TableRow], TableProperties]: + """Creates hlo op stats table from a hlo stats proto. + + Args: + hlo_stats_db: A HloStatsDatabase proto. + + Returns: + Returns table description(column defubutuibs), data(rows data) and custom + properties for preparing the HloStats gviz table. + """ + + table_description = [ + ("rank", "number", "Rank"), + ("program_id", "string", "Program id"), + ("category", "string", "HLO op category"), + ("hlo_op_name", "string", "HLO op name"), + ("hlo_op_expression", "string", "HLO op text"), + ("tf_op_name", "string", "Framework op name"), + ("occurrences", "number", "#Occurrences"), + ("total_time", "number", "Total time (us)"), + ("avg_time", "number", "Avg. time (us)"), + ("total_self_time", "number", "Total self time (us)"), + ("avg_self_time", "number", "Avg. self time (us)"), + ("total_self_time_percent", "number", "Total self time (%)"), + ( + "cumulative_total_self_time_percent", + "number", + "Cumulative total self time (%)", + ), + ("dma_stall_percent", "number", "%time stalled by DMA"), + ("model_flop_rate", "number", "Model GFLOP/s"), + ("normalized_flop_rate", "number", "Normalized GFLOP/s"), + ("measured_memory_bw", "number", "Measured memory BW (GiB/s)"), + ("hbm_bw", "number", "HBM BW (GiB/s)"), + ("cmem_read_bw", "number", "CMEM Read BW (GiB/s)"), + ("cmem_write_bw", "number", "CMEM Write BW (GiB/s)"), + ("operational_intensity", "number", "Operational intensity (FLOPS/Byte)"), + ("bound_by", "string", "Bound by"), + ("hlo_rematerialization", "string", "Rematerialization"), + ("outside_compilation", "string", "Outside Compilation"), + ("autotuned", "string", "Autotuned"), + ] + + data = [] + for record in hlo_stats_db.hlo_stats_record: + row = [ + record.rank, + str(record.program_id), + record.hlo_category, + get_hlo_op_name_from_expression(record.hlo_expression), + record.hlo_expression, + record.tf_op_name, + record.occurrences, + record.total_time_in_us, + record.avg_time_in_us, + record.total_self_time_in_us, + record.avg_self_time_in_us, + record.total_self_time_as_fraction, + record.cumulative_total_self_time_as_fraction, + record.dma_stall_fraction, + record.model_flop_rate, + record.measured_flop_rate, + record.measured_memory_bw, + record.hbm_bw, + record.cmem_read_bw, + record.cmem_write_bw, + record.operational_intensity, + record.bound_by, + "Yes" if record.rematerialization else "No", + "Yes" if record.outside_compilation else "No", + "Yes" if record.autotuned else "No", + ] + data.append(row) + + return (table_description, data, {}) + + +def generate_hlo_stats_table( + hlo_stats_db: hlo_stats_pb2.HloStatsDatabase, +) -> gviz_api.DataTable: + """Converts a HloStatsDb proto to a GViz DataTable. + + Args: + hlo_stats_db: A HloStatsDatabase proto. + + Returns: + Returns a GViz DataTable for HloStats. + """ + (table_description, data, custom_properties) = get_hlo_stats_table_args( + hlo_stats_db + ) + return gviz_api.DataTable(table_description, data, custom_properties) + + +def to_json(raw_data: bytes) -> str: + """Converts a serialized HloStatsDb bytes to json. + + Args: + raw_data: Bytes of serialized HloStatsDb. + + Returns: + Returns a json string of HloStats Table. + """ + hlo_stats_db = hlo_stats_pb2.HloStatsDatabase() + hlo_stats_db.ParseFromString(raw_data) + return generate_hlo_stats_table(hlo_stats_db).ToJSon() diff --git a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py index c2acb269..095e3540 100644 --- a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py +++ b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py @@ -27,6 +27,7 @@ from tensorflow.python.profiler.internal import _pywrap_profiler_plugin # pylint: disable=g-direct-tensorflow-import from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz +from tensorboard_plugin_profile.convert import hlo_stats_proto_to_gviz from tensorboard_plugin_profile.convert import input_pipeline_proto_to_gviz from tensorboard_plugin_profile.convert import kernel_stats_proto_to_gviz from tensorboard_plugin_profile.convert import overview_page_proto_to_gviz @@ -169,6 +170,10 @@ def xspace_to_tool_data( raw_data, success = xspace_wrapper_func(xspace_paths, tool) if success: data = raw_data + elif tool == 'hlo_stats': + raw_data, success = xspace_wrapper_func(xspace_paths, tool) + if success: + data = hlo_stats_proto_to_gviz.to_json(raw_data) elif tool == 'graph_viewer': options = params.get('graph_viewer_options', {}) raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) diff --git a/plugin/tensorboard_plugin_profile/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py b/plugin/tensorboard_plugin_profile/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py index c291ffec..7e264c52 100644 --- a/plugin/tensorboard_plugin_profile/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py +++ b/plugin/tensorboard_plugin_profile/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py @@ -108,6 +108,7 @@ def test_tools_are_in_list(self): 'op_profile^', 'memory_viewer^', 'graph_viewer^', + 'hlo_stats^', ] expected.sort() self.assertListEqual(expected, result) diff --git a/plugin/tensorboard_plugin_profile/profile_plugin.py b/plugin/tensorboard_plugin_profile/profile_plugin.py index 8a9071ec..832ebe85 100644 --- a/plugin/tensorboard_plugin_profile/profile_plugin.py +++ b/plugin/tensorboard_plugin_profile/profile_plugin.py @@ -110,6 +110,7 @@ 'pod_viewer^', 'tf_data_bottleneck_analysis^', 'op_profile^', + 'hlo_stats^', ] # XPlane generated tools that support all host mode. diff --git a/plugin/tensorboard_plugin_profile/protobuf/BUILD b/plugin/tensorboard_plugin_profile/protobuf/BUILD index df25677e..9d048700 100644 --- a/plugin/tensorboard_plugin_profile/protobuf/BUILD +++ b/plugin/tensorboard_plugin_profile/protobuf/BUILD @@ -12,6 +12,7 @@ proto_library( srcs = [ "dcn_slack_analysis.proto", "diagnostics.proto", + "hlo_stats.proto", "input_pipeline.proto", "kernel_stats.proto", "overview_page.proto", @@ -28,6 +29,7 @@ py_proto_library( srcs = [ "dcn_slack_analysis.proto", "diagnostics.proto", + "hlo_stats.proto", "input_pipeline.proto", "kernel_stats.proto", "overview_page.proto", diff --git a/plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto b/plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto new file mode 100644 index 00000000..19776256 --- /dev/null +++ b/plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto @@ -0,0 +1,124 @@ +// This proto describes the format of the output profile file from +// the HLO stats tool. +syntax = "proto2"; + +package tensorflow.profiler.hlo_stats; + +// A database of HloStats records. +message HloStatsDatabase { + // All HloStats records, one for each HLO operation. + repeated HloStatsRecord hlo_stats_record = 1; +} + +// There is one HloStatsRecord for each HLO operation profiled. +// Next ID: 39 +message HloStatsRecord { + // The rank by self time + optional uint64 rank = 1; + + // program_id for this op + optional uint64 program_id = 30; + + // The HLO category name. + optional string hlo_category = 17; + + // The HLO expression. + optional string hlo_expression = 2; + + // The framework op name (TF Op, JAX Op) + optional string tf_op_name = 21; + + // Number of occurrences of the operation. + optional int64 occurrences = 3; + + // Total "accumulated" time in micro-seconds that the operation + // took. If this operation has any children operations, + // the "accumulated" time includes the time spent inside children. + optional double total_time_in_us = 4; + + // Average "accumulated" time in micro-seconds that each + // occurrence of the operation took. + optional double avg_time_in_us = 5; + + // Total "self" time in micro-seconds that the operation took. + // If this operation has any children operations, the "self" time + // doesn't include the time spent inside children. + optional double total_self_time_in_us = 6; + + // Average "self" time in micro-seconds that the operation took. + optional double avg_self_time_in_us = 7; + optional double total_self_time_as_fraction = 8; + optional double cumulative_total_self_time_as_fraction = 9; + + // Percentage of the total "accumulated" time that was caused by + // DMA stall. + optional double dma_stall_fraction = 10; + + // Total floating-point operations (FLOPs) performed per second normalized to + // the bf16 peak capacity. + optional double measured_flop_rate = 13; + + // Total Floating-point operations for the op per second. + optional double model_flop_rate = 34; + + // Number of total bytes (including both read and write) accessed per second. + optional double measured_memory_bw = 14; + + // Number of bytes accessed from HBM (including both read and write) per + // second. + optional double hbm_bw = 22; + + // Number of bytes read from CMEM per second. + optional double cmem_read_bw = 23; + + // Number of bytes written to CMEM per second. + optional double cmem_write_bw = 24; + + // Number of bytes read from VMEM per second. + optional double vmem_read_bw = 35; + + // Number of bytes written to VMEM per second. + optional double vmem_write_bw = 36; + + // Overall operational intensity in FLOP/Byte. + optional double operational_intensity = 15; + + // Operational intensity based on HBM in FLOP/Byte. + optional double hbm_operational_intensity = 26; + + // Operational intensity based on CMEM read in FLOP/Byte. + optional double cmem_read_operational_intensity = 27; + + // Operational intensity based on CMEM write in FLOP/Byte. + optional double cmem_write_operational_intensity = 28; + + // Operational intensity based on VMEM read in FLOP/Byte. + optional double vmem_read_operational_intensity = 37; + + // Operational intensity based on VMEM write in FLOP/Byte. + optional double vmem_write_operational_intensity = 38; + + // Operational intensity based on the bottleneck resource in FLOP/Byte. + optional double bottleneck_operational_intensity = 29; + + // Whether this operation is "Compute", "HBM", "CMEM Read", "CMEM Write" + // bound, according to the Roofline Model. + optional string bound_by = 16; + + // Whether this operation is for HLO or Framework rematerialization. + optional bool rematerialization = 20; + + // Whether this op is for outside compilation. + optional bool outside_compilation = 25; + + // Whether this op is autotuned. + optional bool autotuned = 31; + + // Flops for the record + optional uint64 flops = 32; + + // Bytes accessed for the record + optional uint64 bytes_accessed = 33; + + reserved 11, 12, 18, 19; +}