diff --git a/frontend/app/components/hlo_stats/BUILD b/frontend/app/components/hlo_stats/BUILD
new file mode 100644
index 00000000..aa85a015
--- /dev/null
+++ b/frontend/app/components/hlo_stats/BUILD
@@ -0,0 +1,49 @@
+load("@io_bazel_rules_sass//:defs.bzl", "sass_binary")
+load("//defs:defs.bzl", "xprof_ng_module")
+
+package(default_visibility = ["//frontend:internal"])
+
+xprof_ng_module(
+ name = "hlo_stats",
+ srcs = [
+ "hlo_stats.ts",
+ "hlo_stats_module.ts",
+ ],
+ assets = [
+ ":hlo_stats_css",
+ "hlo_stats.ng.html",
+ ],
+ deps = [
+ "@npm//@angular/core",
+ "@npm//@angular/router",
+ "@npm//@ngrx/store",
+ "@npm//rxjs",
+ "@org_xprof//frontend/app/common/angular:angular_material_divider",
+ "@org_xprof//frontend/app/common/angular:angular_material_expansion",
+ "@org_xprof//frontend/app/common/angular:angular_material_select",
+ "@org_xprof//frontend/app/common/angular:angular_material_tooltip",
+ "@org_xprof//frontend/app/common/constants",
+ "@org_xprof//frontend/app/common/interfaces",
+ "@org_xprof//frontend/app/common/interfaces:chart",
+ "@org_xprof//frontend/app/common/utils",
+ "@org_xprof//frontend/app/components/chart",
+ "@org_xprof//frontend/app/components/chart:chart_options",
+ "@org_xprof//frontend/app/components/chart:data_processor",
+ "@org_xprof//frontend/app/components/chart:default_data_provider",
+ "@org_xprof//frontend/app/components/chart/dashboard",
+ "@org_xprof//frontend/app/components/controls/category_filter",
+ "@org_xprof//frontend/app/components/controls/string_filter",
+ "@org_xprof//frontend/app/components/framework_op_stats/flop_rate_chart",
+ "@org_xprof//frontend/app/services/data_service",
+ "@org_xprof//frontend/app/store",
+ ],
+)
+
+sass_binary(
+ name = "hlo_stats_css",
+ src = "hlo_stats.scss",
+ sourcemap = False,
+ deps = [
+ "@org_xprof//frontend/app/styles:common",
+ ],
+)
diff --git a/frontend/app/components/hlo_stats/hlo_stats.ng.html b/frontend/app/components/hlo_stats/hlo_stats.ng.html
new file mode 100644
index 00000000..7ba5d435
--- /dev/null
+++ b/frontend/app/components/hlo_stats/hlo_stats.ng.html
@@ -0,0 +1,100 @@
+
+
+
+
+
+
+
+ Chart Section
+
+
+
+
+
Time per HLO category
+
+
+
+
Time per HLO operation
+
+
+
+
Time spent on rematerialization
+
+
+
+
Time per HLO by Replica Groups
+
+
+ {{op}}
+
+
+
+
+
+
+
+
+
Time spent on rematerialization per HLO category
+
+
+
+
Time spent on outside compilation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/app/components/hlo_stats/hlo_stats.scss b/frontend/app/components/hlo_stats/hlo_stats.scss
new file mode 100644
index 00000000..93a572ca
--- /dev/null
+++ b/frontend/app/components/hlo_stats/hlo_stats.scss
@@ -0,0 +1,35 @@
+/** CSS for an hlo stats component. */
+
+@import 'frontend/app/styles/common';
+
+:host {
+ display: block;
+}
+
+.section-container {
+ padding: 20px;
+}
+
+.header-content {
+ font-style: italic;
+ margin-top: -16px;
+}
+
+.row {
+ display: flex;
+}
+
+.flex-space {
+ flex: 1;
+}
+
+.pie-chart {
+ height: fit-content;
+ margin: 8px 8px 8px 0;
+ width: 450px;
+}
+
+.table {
+ border: 1px solid $border-color;
+ margin: 0 16px;
+}
diff --git a/frontend/app/components/hlo_stats/hlo_stats.ts b/frontend/app/components/hlo_stats/hlo_stats.ts
new file mode 100644
index 00000000..1b3ba676
--- /dev/null
+++ b/frontend/app/components/hlo_stats/hlo_stats.ts
@@ -0,0 +1,294 @@
+import {Component, OnDestroy} from '@angular/core';
+import {ActivatedRoute} from '@angular/router';
+import {Store} from '@ngrx/store';
+import {OpType} from 'org_xprof/frontend/app/common/constants/enums';
+import {ChartDataInfo} from 'org_xprof/frontend/app/common/interfaces/chart';
+import {SimpleDataTable} from 'org_xprof/frontend/app/common/interfaces/data_table';
+import {NavigationEvent} from 'org_xprof/frontend/app/common/interfaces/navigation_event';
+import {setLoadingState} from 'org_xprof/frontend/app/common/utils/utils';
+import {CategoryTableDataProcessor} from 'org_xprof/frontend/app/components/chart/category_table_data_processor';
+import {PIE_CHART_OPTIONS, TABLE_OPTIONS,} from 'org_xprof/frontend/app/components/chart/chart_options';
+import {Dashboard} from 'org_xprof/frontend/app/components/chart/dashboard/dashboard';
+import {DefaultDataProvider, ReplicaGroupDataProvider,} from 'org_xprof/frontend/app/components/chart/default_data_provider';
+import {DataService} from 'org_xprof/frontend/app/services/data_service/data_service';
+import {setCurrentToolStateAction} from 'org_xprof/frontend/app/store/actions';
+import {ReplaySubject} from 'rxjs';
+import {takeUntil} from 'rxjs/operators';
+
+const OP_CATEGORY_ID = 'category';
+const OP_NAME_ID = 'hlo_op_name';
+const PROGRAM_ID = 'program_id';
+const OP_EXPRESSION_ID = 'hlo_op_expression';
+const SELF_TIME_ID = 'total_self_time';
+const HLO_REMAT_ID = 'hlo_rematerialization';
+const OUTSIDE_COMPILATION_ID = 'outside_compilation';
+const MEASURED_FLOP_RATE_ID = 'measured_flop_rate';
+
+/** A Hlo Stats component. */
+@Component({
+ standalone: false,
+ selector: 'hlo-stats',
+ templateUrl: './hlo_stats.ng.html',
+ styleUrls: ['./hlo_stats.css'],
+})
+export class HloStats extends Dashboard implements OnDestroy {
+ readonly tool = 'hlo_stats';
+ /** Handles on-destroy Subject, used to unsubscribe. */
+ private readonly destroyed = new ReplaySubject(1);
+ sessionId = '';
+ data: SimpleDataTable | null = null;
+ hloOpNameSelected = '';
+ programIdSelected = '';
+ // Flop rate chart properties.
+ readonly opType = OpType.XLA_HLO;
+ flopRateChartXColumn = -1;
+ flopRateChartYColumn = -1;
+ // Pie charts properties.
+ pieChartDataProvider = new DefaultDataProvider();
+ replicaGroupDataProvider = new ReplicaGroupDataProvider();
+ dataInfoCategoryChart: ChartDataInfo = {
+ data: null,
+ dataProvider: this.pieChartDataProvider,
+ options: PIE_CHART_OPTIONS,
+ };
+ dataInfoOpChart: ChartDataInfo = {
+ data: null,
+ dataProvider: this.pieChartDataProvider,
+ options: PIE_CHART_OPTIONS,
+ };
+ communicationOps = new Set();
+ selectedCommOp = '';
+ dataInfoOpReplicaGroupChart: ChartDataInfo = {
+ data: null,
+ dataProvider: this.replicaGroupDataProvider,
+ options: PIE_CHART_OPTIONS,
+ };
+ dataInfoRematerializationChart: ChartDataInfo = {
+ data: null,
+ dataProvider: this.pieChartDataProvider,
+ options: PIE_CHART_OPTIONS,
+ };
+ dataInfoRematerializationCategoryChart: ChartDataInfo = {
+ data: null,
+ dataProvider: this.pieChartDataProvider,
+ options: PIE_CHART_OPTIONS,
+ };
+ dataInfoOutsideCompilationChart: ChartDataInfo = {
+ data: null,
+ dataProvider: this.pieChartDataProvider,
+ options: PIE_CHART_OPTIONS,
+ };
+ // Table properties.
+ dataInfoForTable: ChartDataInfo = {
+ data: null,
+ dataProvider: new DefaultDataProvider(),
+ filters: [],
+ options: {
+ ...TABLE_OPTIONS,
+ showRowNumber: false,
+ page: 'enable',
+ pageSize: 100,
+ sortAscending: true,
+ sortColumn: 0,
+ },
+ };
+ showChartSection = true;
+
+ constructor(
+ route: ActivatedRoute,
+ private readonly dataService: DataService,
+ private readonly store: Store<{}>,
+ ) {
+ super();
+ route.params.pipe(takeUntil(this.destroyed)).subscribe((params) => {
+ this.update(params as NavigationEvent);
+ });
+ this.store.dispatch(setCurrentToolStateAction({currentTool: this.tool}));
+ }
+
+ update(event: NavigationEvent) {
+ const run = event.run || '';
+ const tag = event.tag || 'hlo_stats';
+ const host = event.host || '';
+
+ setLoadingState(true, this.store, 'Loading hlo data');
+
+ this.dataService.getData(run, tag, host)
+ .pipe(takeUntil(this.destroyed))
+ .subscribe((data) => {
+ setLoadingState(false, this.store);
+ this.data = data as SimpleDataTable | null;
+ this.process(this.data);
+ this.onCheckInputParams();
+ });
+ }
+
+ onCheckInputParams() {
+ this.hloOpNameSelected =
+ this.dataService.searchParams?.get('hlo_op_name') || '';
+ // Assumption: the program_id is in format like 'main()'
+ // parsing with a regex to match content in the bracket
+ const programIdParsed = this.dataService.searchParams
+ ?.get('program_id')
+ ?.match(/\((.*)\)/);
+ this.programIdSelected =
+ programIdParsed?.length === 2 ? programIdParsed[1] : '';
+ }
+
+ // Iterate through the table data
+ // and inject graph link to the hlo op text cell
+ addGraphViewerLinkInTableData(data: SimpleDataTable) {
+ const programIdColumnIdx =
+ data.cols?.findIndex((col) => col.id === PROGRAM_ID) ?? -1;
+ const hloOpExpressionColumnIdx =
+ data.cols?.findIndex((col) => col.id === OP_EXPRESSION_ID) ?? -1;
+ const hloOpNameColumnIdx =
+ data.cols?.findIndex((col) => col.id === OP_NAME_ID) ?? -1;
+ if (programIdColumnIdx === -1 || hloOpExpressionColumnIdx === -1) {
+ return data;
+ }
+
+ const updatedData = {
+ ...data,
+ rows: data?.rows!.map((row, index) => {
+ const programId = (row.c![programIdColumnIdx].v as string).trim() || '';
+ const hloOpName = (row.c![hloOpNameColumnIdx].v as string).trim() || '';
+ const hloOpExpression =
+ (row.c![hloOpExpressionColumnIdx].v as string) || '';
+ const graphViewerLink = `/graph_viewer/${this.sessionId}?program_id=${programId}&node_name=${hloOpName}`;
+ return {
+ ...row,
+ c: [
+ ...row.c!.slice(0, hloOpExpressionColumnIdx),
+ {
+ ...row.c![hloOpExpressionColumnIdx],
+ v: `${hloOpExpression}`,
+ },
+ ...row.c!.slice(hloOpExpressionColumnIdx + 1),
+ ],
+ };
+ }),
+ };
+ return updatedData;
+ }
+
+ private process(data: SimpleDataTable | null) {
+ if (!data) return;
+
+ this.parseData(data);
+ this.drawFlopRateChart();
+ this.updateOpReplicaGroupChart();
+
+ const updatedData = this.addGraphViewerLinkInTableData(data);
+ this.dataInfoForTable = {
+ ...this.dataInfoForTable,
+ data: updatedData,
+ };
+ }
+
+ updateView() {
+ this.dataInfoForTable = {
+ ...this.dataInfoForTable,
+ filters: this.getFilters(),
+ };
+ }
+
+ updateOpReplicaGroupChart() {
+ if (
+ !this.replicaGroupDataProvider.opCategoryIndex ||
+ !this.replicaGroupDataProvider.hloOpNameIndex ||
+ !this.replicaGroupDataProvider.selfTimeIndex
+ ) {
+ return;
+ }
+
+ const filtersForReplicaGroup = [
+ {
+ column: this.replicaGroupDataProvider.opCategoryIndex,
+ value: this.selectedCommOp,
+ },
+ ];
+
+ this.dataInfoOpReplicaGroupChart.customChartDataProcessor =
+ new CategoryTableDataProcessor(
+ filtersForReplicaGroup,
+ this.replicaGroupDataProvider.hloOpNameIndex,
+ this.replicaGroupDataProvider.selfTimeIndex,
+ );
+
+ // Since the DataInfo has not been updated, the notifyCharts function is
+ // called to redraw the graph.
+ this.replicaGroupDataProvider.notifyCharts();
+ }
+
+ parseData(data: SimpleDataTable | null) {
+ if (!data) return;
+ // Five charts share one DataProvider. In order to prevent DataTable from
+ // being created multiple times, it calls DataProvider function directly.
+ this.pieChartDataProvider.parseData(data);
+ const dataTable = this.pieChartDataProvider.getDataTable();
+ if (!dataTable) return;
+
+ this.dataTable = dataTable;
+ this.updateView();
+
+ const hloOpNameIndex = dataTable.getColumnIndex(OP_EXPRESSION_ID);
+ const opCategoryIndex = dataTable.getColumnIndex(OP_CATEGORY_ID);
+ const selfTimeIndex = dataTable.getColumnIndex(SELF_TIME_ID);
+ const hloRematIndex = dataTable.getColumnIndex(HLO_REMAT_ID);
+ const outsideCompilationIndex = dataTable.getColumnIndex(
+ OUTSIDE_COMPILATION_ID,
+ );
+
+ const filtersForRemat = [{column: hloRematIndex, value: 'Yes'}];
+
+ this.dataInfoCategoryChart.customChartDataProcessor =
+ new CategoryTableDataProcessor([], opCategoryIndex, selfTimeIndex);
+ this.dataInfoOpChart.customChartDataProcessor =
+ new CategoryTableDataProcessor([], hloOpNameIndex, selfTimeIndex);
+ this.dataInfoRematerializationChart.customChartDataProcessor =
+ new CategoryTableDataProcessor([], hloRematIndex, selfTimeIndex, false);
+ this.dataInfoRematerializationCategoryChart.customChartDataProcessor =
+ new CategoryTableDataProcessor(
+ filtersForRemat,
+ opCategoryIndex,
+ selfTimeIndex,
+ );
+ this.dataInfoOutsideCompilationChart.customChartDataProcessor =
+ new CategoryTableDataProcessor(
+ [],
+ outsideCompilationIndex,
+ selfTimeIndex,
+ false,
+ );
+
+ // Since the DataInfo has not been updated, the notifyCharts function is
+ // called to redraw the graph.
+ this.pieChartDataProvider.notifyCharts();
+
+ // Create a DataProvider in which the row string value for hloOpName column
+ // is truncated to only be the 'replica_groups={{...}}' string.
+ this.replicaGroupDataProvider.parseData(data);
+ this.communicationOps = this.replicaGroupDataProvider.communicationOps;
+
+ if (this.communicationOps.size) {
+ // Set value to the first communication Op in the set.
+ this.selectedCommOp = this.communicationOps.values().next().value;
+ }
+ }
+
+ private drawFlopRateChart() {
+ if (!this.dataTable || !this.dataTable.getColumnIndex) return;
+ this.flopRateChartXColumn = this.dataTable.getColumnIndex(OP_EXPRESSION_ID);
+ this.flopRateChartYColumn = this.dataTable.getColumnIndex(
+ MEASURED_FLOP_RATE_ID,
+ );
+ }
+
+ ngOnDestroy() {
+ // Unsubscribes all pending subscriptions.
+ setLoadingState(false, this.store);
+ this.destroyed.next();
+ this.destroyed.complete();
+ }
+}
diff --git a/frontend/app/components/hlo_stats/hlo_stats_module.ts b/frontend/app/components/hlo_stats/hlo_stats_module.ts
new file mode 100644
index 00000000..8cf73e5c
--- /dev/null
+++ b/frontend/app/components/hlo_stats/hlo_stats_module.ts
@@ -0,0 +1,31 @@
+import {CommonModule} from '@angular/common';
+import {NgModule} from '@angular/core';
+import {MatDividerModule} from '@angular/material/divider';
+import {MatExpansionModule} from '@angular/material/expansion';
+import {MatSelectModule} from '@angular/material/select';
+import {MatTooltipModule} from '@angular/material/tooltip';
+import {ChartModule} from 'org_xprof/frontend/app/components/chart/chart';
+import {CategoryFilterModule} from 'org_xprof/frontend/app/components/controls/category_filter/category_filter_module';
+import {StringFilterModule} from 'org_xprof/frontend/app/components/controls/string_filter/string_filter_module';
+import {FlopRateChartModule} from 'org_xprof/frontend/app/components/framework_op_stats/flop_rate_chart/flop_rate_chart_module';
+
+import {HloStats} from './hlo_stats';
+
+/** An HLO stats module. */
+@NgModule({
+ declarations: [HloStats],
+ imports: [
+ ChartModule,
+ CommonModule,
+ FlopRateChartModule,
+ CategoryFilterModule,
+ StringFilterModule,
+ MatDividerModule,
+ MatSelectModule,
+ MatExpansionModule,
+ MatTooltipModule,
+ ],
+ exports: [HloStats],
+})
+export class HloStatsModule {
+}
diff --git a/frontend/app/components/main_page/BUILD b/frontend/app/components/main_page/BUILD
index fa609499..33ed3893 100644
--- a/frontend/app/components/main_page/BUILD
+++ b/frontend/app/components/main_page/BUILD
@@ -29,6 +29,7 @@ xprof_ng_module(
"@org_xprof//frontend/app/components/empty_page",
"@org_xprof//frontend/app/components/framework_op_stats:framework_op_stats_adapter",
"@org_xprof//frontend/app/components/graph_viewer",
+ "@org_xprof//frontend/app/components/hlo_stats",
"@org_xprof//frontend/app/components/input_pipeline",
"@org_xprof//frontend/app/components/kernel_stats:kernel_stats_adapter",
"@org_xprof//frontend/app/components/memory_profile",
diff --git a/frontend/app/components/main_page/main_page_module.ts b/frontend/app/components/main_page/main_page_module.ts
index 6ff34df3..52c49d23 100644
--- a/frontend/app/components/main_page/main_page_module.ts
+++ b/frontend/app/components/main_page/main_page_module.ts
@@ -12,6 +12,8 @@ import {EmptyPageModule} from 'org_xprof/frontend/app/components/empty_page/empt
import {FrameworkOpStatsAdapter, FrameworkOpStatsAdapterModule} from 'org_xprof/frontend/app/components/framework_op_stats/framework_op_stats_adapter';
import {GraphViewer} from 'org_xprof/frontend/app/components/graph_viewer/graph_viewer';
import {GraphViewerModule} from 'org_xprof/frontend/app/components/graph_viewer/graph_viewer_module';
+import {HloStats} from 'org_xprof/frontend/app/components/hlo_stats/hlo_stats';
+import {HloStatsModule} from 'org_xprof/frontend/app/components/hlo_stats/hlo_stats_module';
import {InputPipeline} from 'org_xprof/frontend/app/components/input_pipeline/input_pipeline';
import {InputPipelineModule} from 'org_xprof/frontend/app/components/input_pipeline/input_pipeline_module';
import {KernelStatsAdapter, KernelStatsAdapterModule} from 'org_xprof/frontend/app/components/kernel_stats/kernel_stats_adapter';
@@ -61,6 +63,8 @@ export const routes: Routes = [
{path: 'trace_viewer^', component: TraceViewer},
{path: 'graph_viewer^', component: GraphViewer},
{path: 'dcn_collective_stats^', component: DcnCollectiveStats},
+ {path: 'hlo_stats', component: HloStats},
+ {path: 'hlo_stats^', component: HloStats},
{path: '**', component: EmptyPage},
];
@@ -87,6 +91,7 @@ export const routes: Routes = [
TfDataBottleneckAnalysisModule,
FrameworkOpStatsAdapterModule,
DcnCollectiveStatsModule,
+ HloStatsModule,
RouterModule.forRoot(routes),
],
exports: [MainPage]
diff --git a/plugin/tensorboard_plugin_profile/convert/BUILD b/plugin/tensorboard_plugin_profile/convert/BUILD
index 08665fda..26d574db 100644
--- a/plugin/tensorboard_plugin_profile/convert/BUILD
+++ b/plugin/tensorboard_plugin_profile/convert/BUILD
@@ -87,6 +87,16 @@ py_library(
],
)
+py_library(
+ name = "hlo_stats_proto_to_gviz",
+ srcs = ["hlo_stats_proto_to_gviz.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ requirement("gviz_api"),
+ "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2",
+ ],
+)
+
py_test(
name = "overview_page_proto_to_gviz_test",
size = "small",
@@ -236,6 +246,7 @@ py_library(
visibility = visibility,
deps = [
":dcn_collective_stats_proto_to_gviz",
+ ":hlo_stats_proto_to_gviz",
":input_pipeline_proto_to_gviz",
":kernel_stats_proto_to_gviz",
":overview_page_proto_to_gviz",
diff --git a/plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py b/plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py
new file mode 100644
index 00000000..d0ac3892
--- /dev/null
+++ b/plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py
@@ -0,0 +1,162 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""For conversion of HloStats protos to GViz DataTables.
+
+Usage:
+ gviz_data_tables = generate_hlo_stats_table(hlo_stats_db)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from typing import Any
+
+import gviz_api
+
+from tensorboard_plugin_profile.protobuf import hlo_stats_pb2
+
+
+def get_hlo_op_name_from_expression(hlo_op_expression: str) -> str:
+ """Returns the hlo op name from the hlo op expression.
+
+ Args:
+ hlo_op_expression: A string of hlo op expression.
+
+ Returns:
+ Returns the hlo op name extracted from the hlo op expression.
+ """
+ # The parse logic based on the assumption that the hlo op text is in format of
+ # '%op_name = '
+ parts = hlo_op_expression.split(" = ")
+ hlo_op_name = parts[0]
+ if hlo_op_name and hlo_op_name[0] == "%":
+ hlo_op_name = hlo_op_name[1:]
+ return hlo_op_name
+
+
+TableColumnDescription = tuple[str, str, str]
+TableRow = list[Any]
+TableProperties = dict[str, Any]
+
+
+def get_hlo_stats_table_args(
+ hlo_stats_db: hlo_stats_pb2.HloStatsDatabase,
+) -> tuple[list[TableColumnDescription], list[TableRow], TableProperties]:
+ """Creates hlo op stats table from a hlo stats proto.
+
+ Args:
+ hlo_stats_db: A HloStatsDatabase proto.
+
+ Returns:
+ Returns table description(column defubutuibs), data(rows data) and custom
+ properties for preparing the HloStats gviz table.
+ """
+
+ table_description = [
+ ("rank", "number", "Rank"),
+ ("program_id", "string", "Program id"),
+ ("category", "string", "HLO op category"),
+ ("hlo_op_name", "string", "HLO op name"),
+ ("hlo_op_expression", "string", "HLO op text"),
+ ("tf_op_name", "string", "Framework op name"),
+ ("occurrences", "number", "#Occurrences"),
+ ("total_time", "number", "Total time (us)"),
+ ("avg_time", "number", "Avg. time (us)"),
+ ("total_self_time", "number", "Total self time (us)"),
+ ("avg_self_time", "number", "Avg. self time (us)"),
+ ("total_self_time_percent", "number", "Total self time (%)"),
+ (
+ "cumulative_total_self_time_percent",
+ "number",
+ "Cumulative total self time (%)",
+ ),
+ ("dma_stall_percent", "number", "%time stalled by DMA"),
+ ("model_flop_rate", "number", "Model GFLOP/s"),
+ ("normalized_flop_rate", "number", "Normalized GFLOP/s"),
+ ("measured_memory_bw", "number", "Measured memory BW (GiB/s)"),
+ ("hbm_bw", "number", "HBM BW (GiB/s)"),
+ ("cmem_read_bw", "number", "CMEM Read BW (GiB/s)"),
+ ("cmem_write_bw", "number", "CMEM Write BW (GiB/s)"),
+ ("operational_intensity", "number", "Operational intensity (FLOPS/Byte)"),
+ ("bound_by", "string", "Bound by"),
+ ("hlo_rematerialization", "string", "Rematerialization"),
+ ("outside_compilation", "string", "Outside Compilation"),
+ ("autotuned", "string", "Autotuned"),
+ ]
+
+ data = []
+ for record in hlo_stats_db.hlo_stats_record:
+ row = [
+ record.rank,
+ str(record.program_id),
+ record.hlo_category,
+ get_hlo_op_name_from_expression(record.hlo_expression),
+ record.hlo_expression,
+ record.tf_op_name,
+ record.occurrences,
+ record.total_time_in_us,
+ record.avg_time_in_us,
+ record.total_self_time_in_us,
+ record.avg_self_time_in_us,
+ record.total_self_time_as_fraction,
+ record.cumulative_total_self_time_as_fraction,
+ record.dma_stall_fraction,
+ record.model_flop_rate,
+ record.measured_flop_rate,
+ record.measured_memory_bw,
+ record.hbm_bw,
+ record.cmem_read_bw,
+ record.cmem_write_bw,
+ record.operational_intensity,
+ record.bound_by,
+ "Yes" if record.rematerialization else "No",
+ "Yes" if record.outside_compilation else "No",
+ "Yes" if record.autotuned else "No",
+ ]
+ data.append(row)
+
+ return (table_description, data, {})
+
+
+def generate_hlo_stats_table(
+ hlo_stats_db: hlo_stats_pb2.HloStatsDatabase,
+) -> gviz_api.DataTable:
+ """Converts a HloStatsDb proto to a GViz DataTable.
+
+ Args:
+ hlo_stats_db: A HloStatsDatabase proto.
+
+ Returns:
+ Returns a GViz DataTable for HloStats.
+ """
+ (table_description, data, custom_properties) = get_hlo_stats_table_args(
+ hlo_stats_db
+ )
+ return gviz_api.DataTable(table_description, data, custom_properties)
+
+
+def to_json(raw_data: bytes) -> str:
+ """Converts a serialized HloStatsDb bytes to json.
+
+ Args:
+ raw_data: Bytes of serialized HloStatsDb.
+
+ Returns:
+ Returns a json string of HloStats Table.
+ """
+ hlo_stats_db = hlo_stats_pb2.HloStatsDatabase()
+ hlo_stats_db.ParseFromString(raw_data)
+ return generate_hlo_stats_table(hlo_stats_db).ToJSon()
diff --git a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py
index c2acb269..095e3540 100644
--- a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py
+++ b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py
@@ -27,6 +27,7 @@
from tensorflow.python.profiler.internal import _pywrap_profiler_plugin # pylint: disable=g-direct-tensorflow-import
from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz
+from tensorboard_plugin_profile.convert import hlo_stats_proto_to_gviz
from tensorboard_plugin_profile.convert import input_pipeline_proto_to_gviz
from tensorboard_plugin_profile.convert import kernel_stats_proto_to_gviz
from tensorboard_plugin_profile.convert import overview_page_proto_to_gviz
@@ -169,6 +170,10 @@ def xspace_to_tool_data(
raw_data, success = xspace_wrapper_func(xspace_paths, tool)
if success:
data = raw_data
+ elif tool == 'hlo_stats':
+ raw_data, success = xspace_wrapper_func(xspace_paths, tool)
+ if success:
+ data = hlo_stats_proto_to_gviz.to_json(raw_data)
elif tool == 'graph_viewer':
options = params.get('graph_viewer_options', {})
raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
diff --git a/plugin/tensorboard_plugin_profile/profile_plugin.py b/plugin/tensorboard_plugin_profile/profile_plugin.py
index 8a9071ec..832ebe85 100644
--- a/plugin/tensorboard_plugin_profile/profile_plugin.py
+++ b/plugin/tensorboard_plugin_profile/profile_plugin.py
@@ -110,6 +110,7 @@
'pod_viewer^',
'tf_data_bottleneck_analysis^',
'op_profile^',
+ 'hlo_stats^',
]
# XPlane generated tools that support all host mode.
diff --git a/plugin/tensorboard_plugin_profile/protobuf/BUILD b/plugin/tensorboard_plugin_profile/protobuf/BUILD
index df25677e..9d048700 100644
--- a/plugin/tensorboard_plugin_profile/protobuf/BUILD
+++ b/plugin/tensorboard_plugin_profile/protobuf/BUILD
@@ -12,6 +12,7 @@ proto_library(
srcs = [
"dcn_slack_analysis.proto",
"diagnostics.proto",
+ "hlo_stats.proto",
"input_pipeline.proto",
"kernel_stats.proto",
"overview_page.proto",
@@ -28,6 +29,7 @@ py_proto_library(
srcs = [
"dcn_slack_analysis.proto",
"diagnostics.proto",
+ "hlo_stats.proto",
"input_pipeline.proto",
"kernel_stats.proto",
"overview_page.proto",
diff --git a/plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto b/plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto
new file mode 100644
index 00000000..749a3f1e
--- /dev/null
+++ b/plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto
@@ -0,0 +1,124 @@
+// This proto describes the format of the output profile file from
+// the HLO stats tool.
+syntax = "proto2";
+
+package tensorflow.profiler.plugin.hlo_stats;
+
+// A database of HloStats records.
+message HloStatsDatabase {
+ // All HloStats records, one for each HLO operation.
+ repeated HloStatsRecord hlo_stats_record = 1;
+}
+
+// There is one HloStatsRecord for each HLO operation profiled.
+// Next ID: 39
+message HloStatsRecord {
+ // The rank by self time
+ optional uint64 rank = 1;
+
+ // program_id for this op
+ optional uint64 program_id = 30;
+
+ // The HLO category name.
+ optional string hlo_category = 17;
+
+ // The HLO expression.
+ optional string hlo_expression = 2;
+
+ // The framework op name (TF Op, JAX Op)
+ optional string tf_op_name = 21;
+
+ // Number of occurrences of the operation.
+ optional int64 occurrences = 3;
+
+ // Total "accumulated" time in micro-seconds that the operation
+ // took. If this operation has any children operations,
+ // the "accumulated" time includes the time spent inside children.
+ optional double total_time_in_us = 4;
+
+ // Average "accumulated" time in micro-seconds that each
+ // occurrence of the operation took.
+ optional double avg_time_in_us = 5;
+
+ // Total "self" time in micro-seconds that the operation took.
+ // If this operation has any children operations, the "self" time
+ // doesn't include the time spent inside children.
+ optional double total_self_time_in_us = 6;
+
+ // Average "self" time in micro-seconds that the operation took.
+ optional double avg_self_time_in_us = 7;
+ optional double total_self_time_as_fraction = 8;
+ optional double cumulative_total_self_time_as_fraction = 9;
+
+ // Percentage of the total "accumulated" time that was caused by
+ // DMA stall.
+ optional double dma_stall_fraction = 10;
+
+ // Total floating-point operations (FLOPs) performed per second normalized to
+ // the bf16 peak capacity.
+ optional double measured_flop_rate = 13;
+
+ // Total Floating-point operations for the op per second.
+ optional double model_flop_rate = 34;
+
+ // Number of total bytes (including both read and write) accessed per second.
+ optional double measured_memory_bw = 14;
+
+ // Number of bytes accessed from HBM (including both read and write) per
+ // second.
+ optional double hbm_bw = 22;
+
+ // Number of bytes read from CMEM per second.
+ optional double cmem_read_bw = 23;
+
+ // Number of bytes written to CMEM per second.
+ optional double cmem_write_bw = 24;
+
+ // Number of bytes read from VMEM per second.
+ optional double vmem_read_bw = 35;
+
+ // Number of bytes written to VMEM per second.
+ optional double vmem_write_bw = 36;
+
+ // Overall operational intensity in FLOP/Byte.
+ optional double operational_intensity = 15;
+
+ // Operational intensity based on HBM in FLOP/Byte.
+ optional double hbm_operational_intensity = 26;
+
+ // Operational intensity based on CMEM read in FLOP/Byte.
+ optional double cmem_read_operational_intensity = 27;
+
+ // Operational intensity based on CMEM write in FLOP/Byte.
+ optional double cmem_write_operational_intensity = 28;
+
+ // Operational intensity based on VMEM read in FLOP/Byte.
+ optional double vmem_read_operational_intensity = 37;
+
+ // Operational intensity based on VMEM write in FLOP/Byte.
+ optional double vmem_write_operational_intensity = 38;
+
+ // Operational intensity based on the bottleneck resource in FLOP/Byte.
+ optional double bottleneck_operational_intensity = 29;
+
+ // Whether this operation is "Compute", "HBM", "CMEM Read", "CMEM Write"
+ // bound, according to the Roofline Model.
+ optional string bound_by = 16;
+
+ // Whether this operation is for HLO or Framework rematerialization.
+ optional bool rematerialization = 20;
+
+ // Whether this op is for outside compilation.
+ optional bool outside_compilation = 25;
+
+ // Whether this op is autotuned.
+ optional bool autotuned = 31;
+
+ // Flops for the record
+ optional uint64 flops = 32;
+
+ // Bytes accessed for the record
+ optional uint64 bytes_accessed = 33;
+
+ reserved 11, 12, 18, 19;
+}