From 7b5c356ac83077402f05ab184077a02f27e25d8a Mon Sep 17 00:00:00 2001 From: Yin Zhang Date: Wed, 20 Nov 2024 16:08:30 -0800 Subject: [PATCH] Create constants, interface and child analysis components for Roofline Model PiperOrigin-RevId: 698554594 --- frontend/app/common/constants/BUILD | 7 + .../constants/roofline_model_constants.ts | 250 +++++++++++++++ .../app/common/interfaces/roofline_model.ts | 25 ++ frontend/app/components/hlo_stats/BUILD | 49 +++ .../components/hlo_stats/hlo_stats.ng.html | 100 ++++++ .../app/components/hlo_stats/hlo_stats.scss | 35 +++ .../app/components/hlo_stats/hlo_stats.ts | 294 ++++++++++++++++++ .../components/hlo_stats/hlo_stats_module.ts | 31 ++ frontend/app/components/main_page/BUILD | 1 + .../components/main_page/main_page_module.ts | 5 + .../operation_level_analysis/BUILD | 36 +++ .../operation_level_analysis.ng.html | 12 + .../operation_level_analysis.scss | 7 + .../operation_level_analysis.ts | 163 ++++++++++ .../operation_level_analysis_module.ts | 14 + .../program_level_analysis/BUILD | 33 ++ .../program_level_analysis.ng.html | 5 + .../program_level_analysis.scss | 3 + .../program_level_analysis.ts | 108 +++++++ .../program_level_analysis_module.ts | 13 + .../tensorboard_plugin_profile/convert/BUILD | 11 + .../convert/hlo_stats_proto_to_gviz.py | 128 ++++++++ .../convert/raw_to_tool_data.py | 5 + .../profile_plugin.py | 1 + .../tensorboard_plugin_profile/protobuf/BUILD | 2 + .../protobuf/hlo_stats.proto | 124 ++++++++ 26 files changed, 1462 insertions(+) create mode 100644 frontend/app/common/constants/roofline_model_constants.ts create mode 100644 frontend/app/common/interfaces/roofline_model.ts create mode 100644 frontend/app/components/hlo_stats/BUILD create mode 100644 frontend/app/components/hlo_stats/hlo_stats.ng.html create mode 100644 frontend/app/components/hlo_stats/hlo_stats.scss create mode 100644 frontend/app/components/hlo_stats/hlo_stats.ts create mode 100644 frontend/app/components/hlo_stats/hlo_stats_module.ts create mode 100644 frontend/app/components/roofline_model/operation_level_analysis/BUILD create mode 100644 frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.ng.html create mode 100644 frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.scss create mode 100644 frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.ts create mode 100644 frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis_module.ts create mode 100644 frontend/app/components/roofline_model/program_level_analysis/BUILD create mode 100644 frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.ng.html create mode 100644 frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.scss create mode 100644 frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.ts create mode 100644 frontend/app/components/roofline_model/program_level_analysis/program_level_analysis_module.ts create mode 100644 plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py create mode 100644 plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto diff --git a/frontend/app/common/constants/BUILD b/frontend/app/common/constants/BUILD index 3a0092ca..95afc064 100644 --- a/frontend/app/common/constants/BUILD +++ b/frontend/app/common/constants/BUILD @@ -10,6 +10,13 @@ ts_library( ], ) +ts_library( + name = "roofline_model_constants", + srcs = [ + "roofline_model_constants.ts", + ], +) + ts_library( name = "testing", srcs = [ diff --git a/frontend/app/common/constants/roofline_model_constants.ts b/frontend/app/common/constants/roofline_model_constants.ts new file mode 100644 index 00000000..3eb75f4c --- /dev/null +++ b/frontend/app/common/constants/roofline_model_constants.ts @@ -0,0 +1,250 @@ +/** + * color palette for roofline model pie chart + */ +export const PIE_CHART_PALETTE = [ + '#3366CC', '#FF9900', '#109618', '#990099', '#3B3EAC', '#0099C6', '#DD4477', + '#66AA00', '#B82E2E', '#316395', '#994499', '#22AA99', '#AAAA11', '#6633CC', + '#E67300', '#8B0707', '#329262', '#5574A6', '#3B3EAC', +]; + +/** axis boundary for roofline model scatter chart */ +export const SCATTER_CHART_AXIS = { + minX: 0.00001, + maxX: 100000, + minY: 0.00001, + maxY: 1000000, +}; + +/** scatter base options for roofline chart */ +export const SCATTER_CHART_OPTIONS = { + title: 'Roofline Model', + width: 720, + height: 400, + hAxis: { + title: 'FLOP/Byte (log scale)', + scaleType: 'log', + viewWindow: { + min: SCATTER_CHART_AXIS.minX, + max: SCATTER_CHART_AXIS.maxX, + }, + // Ticks have to be explicitly defined for scaling axis evenly. + ticks: [ + 0, + 0.00001, + 0.0001, + 0.001, + 0.01, + 0.1, + 1, + 10, + 100, + 1000, + 10000, + 100000, + ], + }, + vAxis: { + title: 'GFLOP/s (log scale)', + scaleType: 'log', + viewWindow: { + min: SCATTER_CHART_AXIS.minY, + max: SCATTER_CHART_AXIS.maxY, + }, + // Ticks have to be explicitly defined for scaling axis evenly. + ticks: [0, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000, 100000, 1000000], + }, + legend: {position: 'right'}, + tooltip: {isHtml: true}, + // Be mindful that series is not specified here, otherwise the shallow copy in + // components of this object could result in overwritten series styles +}; + +/** roofline plot line styles */ +export const ROOFLINE_STYLES = { + read: { + lineWidth: 1, + color: 'red', + pointsVisible: false, + }, + write: { + lineWidth: 1, + color: 'green', + pointsVisible: false, + }, + hbm: { + lineWidth: 1, + color: 'black', + pointsVisible: false, + }, +}; + +/** roofline model properties configuration */ +export const DEVICE_INFO = [ + { + id: 'device_type', + label: 'Device Type', + type: 'string', + display: true, + }, + { + id: 'megacore', + label: 'Megacore', + type: 'string', + context: '', + display: true, + }, + { + id: 'peak_flop_rate', + label: 'Peak FLOP Rate per TensorCore', + type: 'number', + unit: 'GFLOP/s', + display: true, + }, + { + id: 'peak_hbm_bw', + label: 'Peak HBM Bandwidth per TensorCore', + type: 'number', + unit: 'GiB/s', + context: '', + display: true, + }, + { + id: 'peak_vmem_read_bw', + label: 'Peak VMEM Read Bandwidth per TensorCore', + type: 'number', + unit: 'GiB/s', + display: true, + }, + { + id: 'peak_vmem_write_bw', + label: 'Peak VMEM Write Bandwidth per TensorCore', + type: 'number', + unit: 'GiB/s', + display: true, + }, + { + id: 'peak_cmem_read_bw', + label: 'Peak CMEM Read Bandwidth per TensorCore', + type: 'number', + unit: 'GiB/s', + display: true, + }, + { + id: 'peak_cmem_write_bw', + label: 'Peak CMEM Write Bandwidth per TensorCore', + type: 'number', + unit: 'GiB/s', + display: true, + }, + { + id: 'cmem_write_ridge_point', + label: 'CMEM Write Ridge Point', + type: 'number', + unit: 'Flop/byte', + display: false, + }, + { + id: 'cmem_read_ridge_point', + label: 'CMEM Read Ridge Point', + type: 'number', + unit: 'Flop/byte', + display: false, + }, + { + id: 'vmem_write_ridge_point', + label: 'VMEM Write Ridge Point', + type: 'number', + unit: 'Flop/byte', + display: false, + }, + { + id: 'vmem_read_ridge_point', + label: 'VMEM Read Ridge Point', + type: 'number', + unit: 'Flop/byte', + display: false, + }, + { + id: 'hbm_ridge_point', + label: 'HBM Ridge Point', + type: 'number', + unit: 'Flop/byte', + display: false, + }, +]; + +/** + * numeric data display formatting config + * might be feasible to be passed through api response data + */ +export const NUMERIC_DATA_FORMAT: + {[key: string]: {type: string; digit?: number};} = { + 'total_time_per_core': { + type: 'decimal', + digit: 0, + }, + 'total_time': { + type: 'decimal', + digit: 2, + }, + 'avg_time': { + type: 'decimal', + digit: 2, + }, + 'total_self_time': { + type: 'decimal', + digit: 2, + }, + 'avg_self_time': { + type: 'decimal', + digit: 2, + }, + 'measured_flop_rate': { + type: 'decimal', + digit: 2, + }, + 'measured_memory_bw': { + type: 'decimal', + digit: 2, + }, + 'hbm_bw': { + type: 'decimal', + digit: 2, + }, + 'cmem_read_bw': { + type: 'decimal', + digit: 2, + }, + 'cmem_write_bw': { + type: 'decimal', + digit: 2, + }, + 'operational_intensity': { + type: 'decimal', + digit: 2, + }, + 'total_self_time_percent': { + type: 'percent', + digit: 1, + }, + 'cumulative_total_self_time_percent': { + type: 'percent', + digit: 1, + }, + 'dma_stall_percent': { + type: 'percent', + digit: 1, + }, + 'roofline_efficiency': { + type: 'percent', + digit: 1, + }, + 'compute_efficiency': { + type: 'percent', + digit: 1, + }, + 'max_mem_bw_utilization': { + type: 'percent', + digit: 1, + }, + }; diff --git a/frontend/app/common/interfaces/roofline_model.ts b/frontend/app/common/interfaces/roofline_model.ts new file mode 100644 index 00000000..b7968142 --- /dev/null +++ b/frontend/app/common/interfaces/roofline_model.ts @@ -0,0 +1,25 @@ +import {SimpleDataTable,} from 'org_xprof/frontend/app/common/interfaces/data_table'; + +/** The base interfact for roofline model property. */ +declare interface RooflineModelProperty { + device_type?: string; + megacore?: string; + has_cmem?: string; + has_merged_vmem?: string; + peak_flop_rate?: string; + peak_hbm_bw?: string; + peak_cmem_read_bw?: string; + peak_cmem_write_bw?: string; + peak_vmem_read_bw?: string; + peak_vmem_write_bw?: string; + hbm_ridge_point?: string; + cmem_read_ridge_point?: string; + cmem_write_ridge_point?: string; + vmem_read_ridge_point?: string; + vmem_write_ridge_point?: string; +} + +/** The base interfact for roofline model. */ +export declare interface RooflineModelData extends SimpleDataTable { + p: RooflineModelProperty; +} diff --git a/frontend/app/components/hlo_stats/BUILD b/frontend/app/components/hlo_stats/BUILD new file mode 100644 index 00000000..aa85a015 --- /dev/null +++ b/frontend/app/components/hlo_stats/BUILD @@ -0,0 +1,49 @@ +load("@io_bazel_rules_sass//:defs.bzl", "sass_binary") +load("//defs:defs.bzl", "xprof_ng_module") + +package(default_visibility = ["//frontend:internal"]) + +xprof_ng_module( + name = "hlo_stats", + srcs = [ + "hlo_stats.ts", + "hlo_stats_module.ts", + ], + assets = [ + ":hlo_stats_css", + "hlo_stats.ng.html", + ], + deps = [ + "@npm//@angular/core", + "@npm//@angular/router", + "@npm//@ngrx/store", + "@npm//rxjs", + "@org_xprof//frontend/app/common/angular:angular_material_divider", + "@org_xprof//frontend/app/common/angular:angular_material_expansion", + "@org_xprof//frontend/app/common/angular:angular_material_select", + "@org_xprof//frontend/app/common/angular:angular_material_tooltip", + "@org_xprof//frontend/app/common/constants", + "@org_xprof//frontend/app/common/interfaces", + "@org_xprof//frontend/app/common/interfaces:chart", + "@org_xprof//frontend/app/common/utils", + "@org_xprof//frontend/app/components/chart", + "@org_xprof//frontend/app/components/chart:chart_options", + "@org_xprof//frontend/app/components/chart:data_processor", + "@org_xprof//frontend/app/components/chart:default_data_provider", + "@org_xprof//frontend/app/components/chart/dashboard", + "@org_xprof//frontend/app/components/controls/category_filter", + "@org_xprof//frontend/app/components/controls/string_filter", + "@org_xprof//frontend/app/components/framework_op_stats/flop_rate_chart", + "@org_xprof//frontend/app/services/data_service", + "@org_xprof//frontend/app/store", + ], +) + +sass_binary( + name = "hlo_stats_css", + src = "hlo_stats.scss", + sourcemap = False, + deps = [ + "@org_xprof//frontend/app/styles:common", + ], +) diff --git a/frontend/app/components/hlo_stats/hlo_stats.ng.html b/frontend/app/components/hlo_stats/hlo_stats.ng.html new file mode 100644 index 00000000..7ba5d435 --- /dev/null +++ b/frontend/app/components/hlo_stats/hlo_stats.ng.html @@ -0,0 +1,100 @@ +
+
+
+
Notes
+
+
+ (1) You can filter the table by HLO op category, HLO op name or Framework op name (eg. JAX op, TF op). +
+
+ (2) In the charts and table below, "IDLE" represents the portion of the total execution + time on device that is idle.. +
+
+ (3) In the pie charts, the "Other" sector represents the sum of sectors that are too + small to be shown individually. +
+
+
+
+
+ +
+ + + + Chart Section + + +
+
+
Time per HLO category
+ +
+
+
Time per HLO operation
+ +
+
+
Time spent on rematerialization
+ +
+
+
Time per HLO by Replica Groups
+ + + {{op}} + + + +
+
+ +
+
+
Time spent on rematerialization per HLO category
+ +
+
+
Time spent on outside compilation
+ +
+ + + +
+
+
+ +
+
+ + + + + + + + + + +
+
+ +
+ + +
diff --git a/frontend/app/components/hlo_stats/hlo_stats.scss b/frontend/app/components/hlo_stats/hlo_stats.scss new file mode 100644 index 00000000..93a572ca --- /dev/null +++ b/frontend/app/components/hlo_stats/hlo_stats.scss @@ -0,0 +1,35 @@ +/** CSS for an hlo stats component. */ + +@import 'frontend/app/styles/common'; + +:host { + display: block; +} + +.section-container { + padding: 20px; +} + +.header-content { + font-style: italic; + margin-top: -16px; +} + +.row { + display: flex; +} + +.flex-space { + flex: 1; +} + +.pie-chart { + height: fit-content; + margin: 8px 8px 8px 0; + width: 450px; +} + +.table { + border: 1px solid $border-color; + margin: 0 16px; +} diff --git a/frontend/app/components/hlo_stats/hlo_stats.ts b/frontend/app/components/hlo_stats/hlo_stats.ts new file mode 100644 index 00000000..1b3ba676 --- /dev/null +++ b/frontend/app/components/hlo_stats/hlo_stats.ts @@ -0,0 +1,294 @@ +import {Component, OnDestroy} from '@angular/core'; +import {ActivatedRoute} from '@angular/router'; +import {Store} from '@ngrx/store'; +import {OpType} from 'org_xprof/frontend/app/common/constants/enums'; +import {ChartDataInfo} from 'org_xprof/frontend/app/common/interfaces/chart'; +import {SimpleDataTable} from 'org_xprof/frontend/app/common/interfaces/data_table'; +import {NavigationEvent} from 'org_xprof/frontend/app/common/interfaces/navigation_event'; +import {setLoadingState} from 'org_xprof/frontend/app/common/utils/utils'; +import {CategoryTableDataProcessor} from 'org_xprof/frontend/app/components/chart/category_table_data_processor'; +import {PIE_CHART_OPTIONS, TABLE_OPTIONS,} from 'org_xprof/frontend/app/components/chart/chart_options'; +import {Dashboard} from 'org_xprof/frontend/app/components/chart/dashboard/dashboard'; +import {DefaultDataProvider, ReplicaGroupDataProvider,} from 'org_xprof/frontend/app/components/chart/default_data_provider'; +import {DataService} from 'org_xprof/frontend/app/services/data_service/data_service'; +import {setCurrentToolStateAction} from 'org_xprof/frontend/app/store/actions'; +import {ReplaySubject} from 'rxjs'; +import {takeUntil} from 'rxjs/operators'; + +const OP_CATEGORY_ID = 'category'; +const OP_NAME_ID = 'hlo_op_name'; +const PROGRAM_ID = 'program_id'; +const OP_EXPRESSION_ID = 'hlo_op_expression'; +const SELF_TIME_ID = 'total_self_time'; +const HLO_REMAT_ID = 'hlo_rematerialization'; +const OUTSIDE_COMPILATION_ID = 'outside_compilation'; +const MEASURED_FLOP_RATE_ID = 'measured_flop_rate'; + +/** A Hlo Stats component. */ +@Component({ + standalone: false, + selector: 'hlo-stats', + templateUrl: './hlo_stats.ng.html', + styleUrls: ['./hlo_stats.css'], +}) +export class HloStats extends Dashboard implements OnDestroy { + readonly tool = 'hlo_stats'; + /** Handles on-destroy Subject, used to unsubscribe. */ + private readonly destroyed = new ReplaySubject(1); + sessionId = ''; + data: SimpleDataTable | null = null; + hloOpNameSelected = ''; + programIdSelected = ''; + // Flop rate chart properties. + readonly opType = OpType.XLA_HLO; + flopRateChartXColumn = -1; + flopRateChartYColumn = -1; + // Pie charts properties. + pieChartDataProvider = new DefaultDataProvider(); + replicaGroupDataProvider = new ReplicaGroupDataProvider(); + dataInfoCategoryChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + dataInfoOpChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + communicationOps = new Set(); + selectedCommOp = ''; + dataInfoOpReplicaGroupChart: ChartDataInfo = { + data: null, + dataProvider: this.replicaGroupDataProvider, + options: PIE_CHART_OPTIONS, + }; + dataInfoRematerializationChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + dataInfoRematerializationCategoryChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + dataInfoOutsideCompilationChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: PIE_CHART_OPTIONS, + }; + // Table properties. + dataInfoForTable: ChartDataInfo = { + data: null, + dataProvider: new DefaultDataProvider(), + filters: [], + options: { + ...TABLE_OPTIONS, + showRowNumber: false, + page: 'enable', + pageSize: 100, + sortAscending: true, + sortColumn: 0, + }, + }; + showChartSection = true; + + constructor( + route: ActivatedRoute, + private readonly dataService: DataService, + private readonly store: Store<{}>, + ) { + super(); + route.params.pipe(takeUntil(this.destroyed)).subscribe((params) => { + this.update(params as NavigationEvent); + }); + this.store.dispatch(setCurrentToolStateAction({currentTool: this.tool})); + } + + update(event: NavigationEvent) { + const run = event.run || ''; + const tag = event.tag || 'hlo_stats'; + const host = event.host || ''; + + setLoadingState(true, this.store, 'Loading hlo data'); + + this.dataService.getData(run, tag, host) + .pipe(takeUntil(this.destroyed)) + .subscribe((data) => { + setLoadingState(false, this.store); + this.data = data as SimpleDataTable | null; + this.process(this.data); + this.onCheckInputParams(); + }); + } + + onCheckInputParams() { + this.hloOpNameSelected = + this.dataService.searchParams?.get('hlo_op_name') || ''; + // Assumption: the program_id is in format like 'main()' + // parsing with a regex to match content in the bracket + const programIdParsed = this.dataService.searchParams + ?.get('program_id') + ?.match(/\((.*)\)/); + this.programIdSelected = + programIdParsed?.length === 2 ? programIdParsed[1] : ''; + } + + // Iterate through the table data + // and inject graph link to the hlo op text cell + addGraphViewerLinkInTableData(data: SimpleDataTable) { + const programIdColumnIdx = + data.cols?.findIndex((col) => col.id === PROGRAM_ID) ?? -1; + const hloOpExpressionColumnIdx = + data.cols?.findIndex((col) => col.id === OP_EXPRESSION_ID) ?? -1; + const hloOpNameColumnIdx = + data.cols?.findIndex((col) => col.id === OP_NAME_ID) ?? -1; + if (programIdColumnIdx === -1 || hloOpExpressionColumnIdx === -1) { + return data; + } + + const updatedData = { + ...data, + rows: data?.rows!.map((row, index) => { + const programId = (row.c![programIdColumnIdx].v as string).trim() || ''; + const hloOpName = (row.c![hloOpNameColumnIdx].v as string).trim() || ''; + const hloOpExpression = + (row.c![hloOpExpressionColumnIdx].v as string) || ''; + const graphViewerLink = `/graph_viewer/${this.sessionId}?program_id=${programId}&node_name=${hloOpName}`; + return { + ...row, + c: [ + ...row.c!.slice(0, hloOpExpressionColumnIdx), + { + ...row.c![hloOpExpressionColumnIdx], + v: `${hloOpExpression}`, + }, + ...row.c!.slice(hloOpExpressionColumnIdx + 1), + ], + }; + }), + }; + return updatedData; + } + + private process(data: SimpleDataTable | null) { + if (!data) return; + + this.parseData(data); + this.drawFlopRateChart(); + this.updateOpReplicaGroupChart(); + + const updatedData = this.addGraphViewerLinkInTableData(data); + this.dataInfoForTable = { + ...this.dataInfoForTable, + data: updatedData, + }; + } + + updateView() { + this.dataInfoForTable = { + ...this.dataInfoForTable, + filters: this.getFilters(), + }; + } + + updateOpReplicaGroupChart() { + if ( + !this.replicaGroupDataProvider.opCategoryIndex || + !this.replicaGroupDataProvider.hloOpNameIndex || + !this.replicaGroupDataProvider.selfTimeIndex + ) { + return; + } + + const filtersForReplicaGroup = [ + { + column: this.replicaGroupDataProvider.opCategoryIndex, + value: this.selectedCommOp, + }, + ]; + + this.dataInfoOpReplicaGroupChart.customChartDataProcessor = + new CategoryTableDataProcessor( + filtersForReplicaGroup, + this.replicaGroupDataProvider.hloOpNameIndex, + this.replicaGroupDataProvider.selfTimeIndex, + ); + + // Since the DataInfo has not been updated, the notifyCharts function is + // called to redraw the graph. + this.replicaGroupDataProvider.notifyCharts(); + } + + parseData(data: SimpleDataTable | null) { + if (!data) return; + // Five charts share one DataProvider. In order to prevent DataTable from + // being created multiple times, it calls DataProvider function directly. + this.pieChartDataProvider.parseData(data); + const dataTable = this.pieChartDataProvider.getDataTable(); + if (!dataTable) return; + + this.dataTable = dataTable; + this.updateView(); + + const hloOpNameIndex = dataTable.getColumnIndex(OP_EXPRESSION_ID); + const opCategoryIndex = dataTable.getColumnIndex(OP_CATEGORY_ID); + const selfTimeIndex = dataTable.getColumnIndex(SELF_TIME_ID); + const hloRematIndex = dataTable.getColumnIndex(HLO_REMAT_ID); + const outsideCompilationIndex = dataTable.getColumnIndex( + OUTSIDE_COMPILATION_ID, + ); + + const filtersForRemat = [{column: hloRematIndex, value: 'Yes'}]; + + this.dataInfoCategoryChart.customChartDataProcessor = + new CategoryTableDataProcessor([], opCategoryIndex, selfTimeIndex); + this.dataInfoOpChart.customChartDataProcessor = + new CategoryTableDataProcessor([], hloOpNameIndex, selfTimeIndex); + this.dataInfoRematerializationChart.customChartDataProcessor = + new CategoryTableDataProcessor([], hloRematIndex, selfTimeIndex, false); + this.dataInfoRematerializationCategoryChart.customChartDataProcessor = + new CategoryTableDataProcessor( + filtersForRemat, + opCategoryIndex, + selfTimeIndex, + ); + this.dataInfoOutsideCompilationChart.customChartDataProcessor = + new CategoryTableDataProcessor( + [], + outsideCompilationIndex, + selfTimeIndex, + false, + ); + + // Since the DataInfo has not been updated, the notifyCharts function is + // called to redraw the graph. + this.pieChartDataProvider.notifyCharts(); + + // Create a DataProvider in which the row string value for hloOpName column + // is truncated to only be the 'replica_groups={{...}}' string. + this.replicaGroupDataProvider.parseData(data); + this.communicationOps = this.replicaGroupDataProvider.communicationOps; + + if (this.communicationOps.size) { + // Set value to the first communication Op in the set. + this.selectedCommOp = this.communicationOps.values().next().value; + } + } + + private drawFlopRateChart() { + if (!this.dataTable || !this.dataTable.getColumnIndex) return; + this.flopRateChartXColumn = this.dataTable.getColumnIndex(OP_EXPRESSION_ID); + this.flopRateChartYColumn = this.dataTable.getColumnIndex( + MEASURED_FLOP_RATE_ID, + ); + } + + ngOnDestroy() { + // Unsubscribes all pending subscriptions. + setLoadingState(false, this.store); + this.destroyed.next(); + this.destroyed.complete(); + } +} diff --git a/frontend/app/components/hlo_stats/hlo_stats_module.ts b/frontend/app/components/hlo_stats/hlo_stats_module.ts new file mode 100644 index 00000000..8cf73e5c --- /dev/null +++ b/frontend/app/components/hlo_stats/hlo_stats_module.ts @@ -0,0 +1,31 @@ +import {CommonModule} from '@angular/common'; +import {NgModule} from '@angular/core'; +import {MatDividerModule} from '@angular/material/divider'; +import {MatExpansionModule} from '@angular/material/expansion'; +import {MatSelectModule} from '@angular/material/select'; +import {MatTooltipModule} from '@angular/material/tooltip'; +import {ChartModule} from 'org_xprof/frontend/app/components/chart/chart'; +import {CategoryFilterModule} from 'org_xprof/frontend/app/components/controls/category_filter/category_filter_module'; +import {StringFilterModule} from 'org_xprof/frontend/app/components/controls/string_filter/string_filter_module'; +import {FlopRateChartModule} from 'org_xprof/frontend/app/components/framework_op_stats/flop_rate_chart/flop_rate_chart_module'; + +import {HloStats} from './hlo_stats'; + +/** An HLO stats module. */ +@NgModule({ + declarations: [HloStats], + imports: [ + ChartModule, + CommonModule, + FlopRateChartModule, + CategoryFilterModule, + StringFilterModule, + MatDividerModule, + MatSelectModule, + MatExpansionModule, + MatTooltipModule, + ], + exports: [HloStats], +}) +export class HloStatsModule { +} diff --git a/frontend/app/components/main_page/BUILD b/frontend/app/components/main_page/BUILD index fa609499..33ed3893 100644 --- a/frontend/app/components/main_page/BUILD +++ b/frontend/app/components/main_page/BUILD @@ -29,6 +29,7 @@ xprof_ng_module( "@org_xprof//frontend/app/components/empty_page", "@org_xprof//frontend/app/components/framework_op_stats:framework_op_stats_adapter", "@org_xprof//frontend/app/components/graph_viewer", + "@org_xprof//frontend/app/components/hlo_stats", "@org_xprof//frontend/app/components/input_pipeline", "@org_xprof//frontend/app/components/kernel_stats:kernel_stats_adapter", "@org_xprof//frontend/app/components/memory_profile", diff --git a/frontend/app/components/main_page/main_page_module.ts b/frontend/app/components/main_page/main_page_module.ts index 6ff34df3..52c49d23 100644 --- a/frontend/app/components/main_page/main_page_module.ts +++ b/frontend/app/components/main_page/main_page_module.ts @@ -12,6 +12,8 @@ import {EmptyPageModule} from 'org_xprof/frontend/app/components/empty_page/empt import {FrameworkOpStatsAdapter, FrameworkOpStatsAdapterModule} from 'org_xprof/frontend/app/components/framework_op_stats/framework_op_stats_adapter'; import {GraphViewer} from 'org_xprof/frontend/app/components/graph_viewer/graph_viewer'; import {GraphViewerModule} from 'org_xprof/frontend/app/components/graph_viewer/graph_viewer_module'; +import {HloStats} from 'org_xprof/frontend/app/components/hlo_stats/hlo_stats'; +import {HloStatsModule} from 'org_xprof/frontend/app/components/hlo_stats/hlo_stats_module'; import {InputPipeline} from 'org_xprof/frontend/app/components/input_pipeline/input_pipeline'; import {InputPipelineModule} from 'org_xprof/frontend/app/components/input_pipeline/input_pipeline_module'; import {KernelStatsAdapter, KernelStatsAdapterModule} from 'org_xprof/frontend/app/components/kernel_stats/kernel_stats_adapter'; @@ -61,6 +63,8 @@ export const routes: Routes = [ {path: 'trace_viewer^', component: TraceViewer}, {path: 'graph_viewer^', component: GraphViewer}, {path: 'dcn_collective_stats^', component: DcnCollectiveStats}, + {path: 'hlo_stats', component: HloStats}, + {path: 'hlo_stats^', component: HloStats}, {path: '**', component: EmptyPage}, ]; @@ -87,6 +91,7 @@ export const routes: Routes = [ TfDataBottleneckAnalysisModule, FrameworkOpStatsAdapterModule, DcnCollectiveStatsModule, + HloStatsModule, RouterModule.forRoot(routes), ], exports: [MainPage] diff --git a/frontend/app/components/roofline_model/operation_level_analysis/BUILD b/frontend/app/components/roofline_model/operation_level_analysis/BUILD new file mode 100644 index 00000000..7f5942e6 --- /dev/null +++ b/frontend/app/components/roofline_model/operation_level_analysis/BUILD @@ -0,0 +1,36 @@ +load("@io_bazel_rules_sass//:defs.bzl", "sass_binary") +load("//defs:defs.bzl", "xprof_ng_module") + +package(default_visibility = ["//frontend:internal"]) + +xprof_ng_module( + name = "operation_level_analysis", + srcs = [ + "operation_level_analysis.ts", + "operation_level_analysis_module.ts", + ], + assets = [ + ":operation_level_analysis_css", + "operation_level_analysis.ng.html", + ], + deps = [ + "@npm//@angular/core", + "@org_xprof//frontend/app/common/constants:roofline_model_constants", + "@org_xprof//frontend/app/common/interfaces", + "@org_xprof//frontend/app/common/interfaces:chart", + "@org_xprof//frontend/app/components/chart", + "@org_xprof//frontend/app/components/chart:chart_options", + "@org_xprof//frontend/app/components/chart:data_processor", + "@org_xprof//frontend/app/components/chart:default_data_provider", + "@org_xprof//frontend/app/components/chart/dashboard", + "@org_xprof//frontend/app/components/chart/table", + "@org_xprof//frontend/app/components/controls/category_filter", + "@org_xprof//frontend/app/components/controls/string_filter", + ], +) + +sass_binary( + name = "operation_level_analysis_css", + src = "operation_level_analysis.scss", + sourcemap = False, +) diff --git a/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.ng.html b/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.ng.html new file mode 100644 index 00000000..1f5b0200 --- /dev/null +++ b/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.ng.html @@ -0,0 +1,12 @@ +
+ + + + +
+
+ +
+ +
+
diff --git a/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.scss b/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.scss new file mode 100644 index 00000000..504b56ae --- /dev/null +++ b/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.scss @@ -0,0 +1,7 @@ +.row { + display: flex; +} + +.flex-space { + flex: 0; +} diff --git a/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.ts b/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.ts new file mode 100644 index 00000000..8df79b09 --- /dev/null +++ b/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis.ts @@ -0,0 +1,163 @@ +import { + Component, + EventEmitter, + Input, + OnChanges, + OnInit, + Output, + SimpleChanges, +} from '@angular/core'; +import {PIE_CHART_PALETTE} from 'org_xprof/frontend/app/common/constants/roofline_model_constants'; +import {ChartDataInfo} from 'org_xprof/frontend/app/common/interfaces/chart'; +import {SimpleDataTable} from 'org_xprof/frontend/app/common/interfaces/data_table'; +import {CategoryTableDataProcessor} from 'org_xprof/frontend/app/components/chart/category_table_data_processor'; +import { + PIE_CHART_OPTIONS, + SCATTER_CHART_OPTIONS, +} from 'org_xprof/frontend/app/components/chart/chart_options'; +import {Dashboard} from 'org_xprof/frontend/app/components/chart/dashboard/dashboard'; +import {DefaultDataProvider} from 'org_xprof/frontend/app/components/chart/default_data_provider'; + +type ColumnIdxArr = Array; + +/** + * An operation level analysis table view component (step appregation: total). + */ +@Component({ + standalone: false, + selector: 'operation-level-analysis', + templateUrl: './operation_level_analysis.ng.html', + styleUrls: ['./operation_level_analysis.scss'], +}) +export class OperationLevelAnalysis + extends Dashboard + implements OnInit, OnChanges +{ + /** The roofline model data, original dataset */ + // used for table chart and pie chart + @Input() rooflineModelData?: google.visualization.DataTable | null = null; + @Input() viewColumns: ColumnIdxArr = []; + // data for scatter chart, heavey data preprocessing handled in parent + @Input() rooflineSeriesData?: google.visualization.DataTable | null = null; + @Input() scatterChartOptions: google.visualization.ScatterChartOptions = {}; + // Op name prepopulated from url + @Input() selectedOp = ''; + + @Output() + readonly filterUpdated = new EventEmitter< + google.visualization.DataTableCellFilter[] + >(); + + pieChartDataProvider = new DefaultDataProvider(); + scatterChartDataProvider = new DefaultDataProvider(); + dataInfoCategoryPieChart: ChartDataInfo = { + data: null, + dataProvider: this.pieChartDataProvider, + options: { + ...PIE_CHART_OPTIONS, + width: 400, + height: 400, + chartArea: { + width: '70%', + height: '70%', + }, + title: 'Percentage of self time per HLO op category', + colors: PIE_CHART_PALETTE, + sliceVisibilityThreshold: 0.01, + }, + }; + dataInfoRooflineScatterChart: ChartDataInfo = { + data: null, + dataProvider: this.scatterChartDataProvider, + options: SCATTER_CHART_OPTIONS, + }; + + constructor() { + super(); + } + + ngOnInit() { + this.update(); + } + + ngOnChanges(changes: SimpleChanges) { + this.update(); + } + + update() { + this.parseData(); + // call inheried method to update table chart view + this.updateView(); + } + + parseData() { + // base data already preprocessed in parent component + if (!this.rooflineModelData) { + return; + } + + // process data for table chart + // columns are used in parent logic to set the dataView + this.columns = this.viewColumns; + this.dataTable = this.rooflineModelData; + + // process data for pie chart + this.pieChartDataProvider.parseData( + JSON.parse(this.dataTable.toJSON()) as SimpleDataTable, + ); + this.updateAndDrawPieCharts(); + + // process data for roofline scatter chart + if (this.rooflineSeriesData) { + this.scatterChartDataProvider.parseData( + JSON.parse(this.rooflineSeriesData.toJSON()) as SimpleDataTable, + ); + this.updateAndDrawScatterChart(); + } + } + + /** + * Triggered when filter update event is emited + * this is a temp solutino to make other charts view updated as well as the + * table chart when filters are changed + * TODO: remove this function when the Dashboard generalization is done + * building dashboard with multiple charts + */ + onUpdateFilters(filter: google.visualization.DataTableCellFilter) { + this.updateFilters(filter); + this.updateAndDrawPieCharts(); + this.updateAndDrawScatterChart(); + this.filterUpdated.emit(this.getFilters()); + } + + /** + * Helper functiont to update data for pie chart and refresh view + * TODO: update either chart component or Dashboard base class to generalize + * building dashboard with multiple charts this is a temp solutino to make + */ + updateAndDrawPieCharts() { + if (!this.dataTable) return; + const opCategoryIndex = this.dataTable.getColumnIndex('category'); + const opTotalSelfTimeIndex = + this.dataTable.getColumnIndex('total_self_time'); + this.dataInfoCategoryPieChart.customChartDataProcessor = + new CategoryTableDataProcessor( + this.getFilters(), + opCategoryIndex, + opTotalSelfTimeIndex, + ); + this.pieChartDataProvider.notifyCharts(); + } + + updateAndDrawScatterChart() { + if (!this.rooflineSeriesData) return; + this.dataInfoRooflineScatterChart = { + ...this.dataInfoRooflineScatterChart, + options: { + ...this.dataInfoRooflineScatterChart.options, + ...this.scatterChartOptions, + }, + }; + this.scatterChartDataProvider.notifyCharts(); + } +} diff --git a/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis_module.ts b/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis_module.ts new file mode 100644 index 00000000..a44864cd --- /dev/null +++ b/frontend/app/components/roofline_model/operation_level_analysis/operation_level_analysis_module.ts @@ -0,0 +1,14 @@ +import {NgModule} from '@angular/core'; +import {ChartModule} from 'org_xprof/frontend/app/components/chart/chart'; +import {TableModule} from 'org_xprof/frontend/app/components/chart/table/table_module'; +import {CategoryFilterModule} from 'org_xprof/frontend/app/components/controls/category_filter/category_filter_module'; +import {StringFilterModule} from 'org_xprof/frontend/app/components/controls/string_filter/string_filter_module'; + +import {OperationLevelAnalysis} from './operation_level_analysis'; + +@NgModule({ + declarations: [OperationLevelAnalysis], + imports: [CategoryFilterModule, TableModule, StringFilterModule, ChartModule], + exports: [OperationLevelAnalysis], +}) +export class OperationLevelAnalysisModule {} diff --git a/frontend/app/components/roofline_model/program_level_analysis/BUILD b/frontend/app/components/roofline_model/program_level_analysis/BUILD new file mode 100644 index 00000000..7cac3f91 --- /dev/null +++ b/frontend/app/components/roofline_model/program_level_analysis/BUILD @@ -0,0 +1,33 @@ +load("@io_bazel_rules_sass//:defs.bzl", "sass_binary") +load("//defs:defs.bzl", "xprof_ng_module") + +package(default_visibility = ["//frontend:internal"]) + +xprof_ng_module( + name = "program_level_analysis", + srcs = [ + "program_level_analysis.ts", + "program_level_analysis_module.ts", + ], + assets = [ + ":program_level_analysis_css", + "program_level_analysis.ng.html", + ], + deps = [ + "@npm//@angular/core", + "@org_xprof//frontend/app/common/interfaces", + "@org_xprof//frontend/app/common/interfaces:chart", + "@org_xprof//frontend/app/components/chart", + "@org_xprof//frontend/app/components/chart:chart_options", + "@org_xprof//frontend/app/components/chart:default_data_provider", + "@org_xprof//frontend/app/components/chart/dashboard", + "@org_xprof//frontend/app/components/chart/table", + "@org_xprof//frontend/app/components/controls/category_filter", + ], +) + +sass_binary( + name = "program_level_analysis_css", + src = "program_level_analysis.scss", + sourcemap = False, +) diff --git a/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.ng.html b/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.ng.html new file mode 100644 index 00000000..92da4a42 --- /dev/null +++ b/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.ng.html @@ -0,0 +1,5 @@ +
+ +
+ +
diff --git a/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.scss b/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.scss new file mode 100644 index 00000000..4a2a5d88 --- /dev/null +++ b/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.scss @@ -0,0 +1,3 @@ +.row { + display: flex; +} diff --git a/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.ts b/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.ts new file mode 100644 index 00000000..0ffa40cd --- /dev/null +++ b/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis.ts @@ -0,0 +1,108 @@ +import { + Component, + EventEmitter, + Input, + OnChanges, + OnInit, + Output, + SimpleChanges, +} from '@angular/core'; +import {ChartDataInfo} from 'org_xprof/frontend/app/common/interfaces/chart'; +import {SimpleDataTable} from 'org_xprof/frontend/app/common/interfaces/data_table'; +import {SCATTER_CHART_OPTIONS} from 'org_xprof/frontend/app/components/chart/chart_options'; +import {Dashboard} from 'org_xprof/frontend/app/components/chart/dashboard/dashboard'; +import {DefaultDataProvider} from 'org_xprof/frontend/app/components/chart/default_data_provider'; + +type ColumnIdxArr = Array; + +/** An program level analysis table view component. */ +@Component({ + standalone: false, + selector: 'program-level-analysis', + templateUrl: './program_level_analysis.ng.html', + styleUrls: ['./program_level_analysis.scss'], +}) +export class ProgramLevelAnalysis + extends Dashboard + implements OnInit, OnChanges +{ + /** The roofline model data */ + @Input() rooflineModelData?: google.visualization.DataTable | null = null; + @Input() viewColumns: ColumnIdxArr = []; + // data for scatter chart, heavey data preprocessing handled in parent + @Input() rooflineSeriesData?: google.visualization.DataTable | null = null; + @Input() scatterChartOptions: google.visualization.ScatterChartOptions = {}; + + @Output() + readonly filterUpdated = new EventEmitter< + google.visualization.DataTableCellFilter[] + >(); + + scatterChartDataProvider = new DefaultDataProvider(); + dataInfoRooflineScatterChart: ChartDataInfo = { + data: null, + dataProvider: this.scatterChartDataProvider, + options: {...SCATTER_CHART_OPTIONS, width: 800}, + }; + + constructor() { + super(); + } + + ngOnInit() { + this.update(); + } + + ngOnChanges(changes: SimpleChanges) { + this.update(); + } + + update() { + this.parseData(); + this.updateView(); + } + + parseData() { + // base data already preprocessed in parent component + if (!this.rooflineModelData) { + return; + } + + // process data for table chart + this.columns = this.viewColumns; + this.dataTable = this.rooflineModelData; + + // process data for roofline scatter chart + if (this.rooflineSeriesData) { + this.scatterChartDataProvider.parseData( + JSON.parse(this.rooflineSeriesData.toJSON()) as SimpleDataTable, + ); + this.updateAndDrawScatterChart(); + } + } + + /** + * Triggered when filter update event is emited + * this is a temp solutino to make other charts view updated as well as the + * table chart when filters are changed + * TODO: remove this function when the Dashboard generalization is done + * building dashboard with multiple charts + */ + onUpdateFilters(filter: google.visualization.DataTableCellFilter) { + this.updateFilters(filter); + this.updateAndDrawScatterChart(); + this.filterUpdated.emit(this.getFilters()); + } + + updateAndDrawScatterChart() { + if (!this.rooflineSeriesData) return; + this.dataInfoRooflineScatterChart = { + ...this.dataInfoRooflineScatterChart, + options: { + ...this.dataInfoRooflineScatterChart.options, + ...this.scatterChartOptions, + }, + }; + this.scatterChartDataProvider.notifyCharts(); + } +} diff --git a/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis_module.ts b/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis_module.ts new file mode 100644 index 00000000..dd80dc1a --- /dev/null +++ b/frontend/app/components/roofline_model/program_level_analysis/program_level_analysis_module.ts @@ -0,0 +1,13 @@ +import {NgModule} from '@angular/core'; +import {ChartModule} from 'org_xprof/frontend/app/components/chart/chart'; +import {TableModule} from 'org_xprof/frontend/app/components/chart/table/table_module'; +import {CategoryFilterModule} from 'org_xprof/frontend/app/components/controls/category_filter/category_filter_module'; + +import {ProgramLevelAnalysis} from './program_level_analysis'; + +@NgModule({ + declarations: [ProgramLevelAnalysis], + imports: [CategoryFilterModule, TableModule, ChartModule], + exports: [ProgramLevelAnalysis], +}) +export class ProgramLevelAnalysisModule {} diff --git a/plugin/tensorboard_plugin_profile/convert/BUILD b/plugin/tensorboard_plugin_profile/convert/BUILD index 08665fda..26d574db 100644 --- a/plugin/tensorboard_plugin_profile/convert/BUILD +++ b/plugin/tensorboard_plugin_profile/convert/BUILD @@ -87,6 +87,16 @@ py_library( ], ) +py_library( + name = "hlo_stats_proto_to_gviz", + srcs = ["hlo_stats_proto_to_gviz.py"], + srcs_version = "PY2AND3", + deps = [ + requirement("gviz_api"), + "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2", + ], +) + py_test( name = "overview_page_proto_to_gviz_test", size = "small", @@ -236,6 +246,7 @@ py_library( visibility = visibility, deps = [ ":dcn_collective_stats_proto_to_gviz", + ":hlo_stats_proto_to_gviz", ":input_pipeline_proto_to_gviz", ":kernel_stats_proto_to_gviz", ":overview_page_proto_to_gviz", diff --git a/plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py b/plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py new file mode 100644 index 00000000..ede95e0b --- /dev/null +++ b/plugin/tensorboard_plugin_profile/convert/hlo_stats_proto_to_gviz.py @@ -0,0 +1,128 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""For conversion of HloStats protos to GViz DataTables. + +Usage: + gviz_data_tables = generate_all_chart_tables(hlo_stats_db) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gviz_api + +from tensorboard_plugin_profile.protobuf import hlo_stats_pb2 + + +def get_hlo_op_name_from_expression(hlo_op_expression): + """Returns the hlo op name from the hlo op expression.""" + # The parse logic based on the assumption that the hlo op text is in format of + # '%op_name = ' + parts = hlo_op_expression.split(" = ") + hlo_op_name = parts[0] + if hlo_op_name and hlo_op_name[0] == "%": + hlo_op_name = hlo_op_name[1:] + return hlo_op_name + + +def get_hlo_stats_table_args(hlo_stats_db): + """Creates hlo op stats table from a hlo stats proto. + + Args: + hlo_stats_db: A HloStatsDatabase proto. + + Returns: + Returns a gviz_api.DataTable + """ + + table_description = [ + ("rank", "number", "Rank"), + ("program_id", "string", "Program id"), + ("category", "string", "HLO op category"), + ("hlo_op_name", "string", "HLO op name"), + ("hlo_op_expression", "string", "HLO op text"), + ("tf_op_name", "string", "Framework op name"), + ("occurrences", "number", "#Occurrences"), + ("total_time", "number", "Total time (us)"), + ("avg_time", "number", "Avg. time (us)"), + ("total_self_time", "number", "Total self time (us)"), + ("avg_self_time", "number", "Avg. self time (us)"), + ("total_self_time_percent", "number", "Total self time (%)"), + ( + "cumulative_total_self_time_percent", + "number", + "Cumulative total self time (%)", + ), + ("dma_stall_percent", "number", "%time stalled by DMA"), + ("model_flop_rate", "number", "Model GFLOP/s"), + ("normalized_flop_rate", "number", "Normalized GFLOP/s"), + ("measured_memory_bw", "number", "Measured memory BW (GiB/s)"), + ("hbm_bw", "number", "HBM BW (GiB/s)"), + ("cmem_read_bw", "number", "CMEM Read BW (GiB/s)"), + ("cmem_write_bw", "number", "CMEM Write BW (GiB/s)"), + ("operational_intensity", "number", "Operational intensity (FLOPS/Byte)"), + ("bound_by", "string", "Bound by"), + ("hlo_rematerialization", "string", "Rematerialization"), + ("outside_compilation", "string", "Outside Compilation"), + ("autotuned", "string", "Autotuned"), + ] + + data = [] + for record in hlo_stats_db.hlo_stats_record: + row = [ + record.rank, + str(record.program_id), + record.hlo_category, + get_hlo_op_name_from_expression(record.hlo_expression), + record.hlo_expression, + record.tf_op_name, + record.occurrences, + record.total_time_in_us, + record.avg_time_in_us, + record.total_self_time_in_us, + record.avg_self_time_in_us, + record.total_self_time_as_fraction, + record.cumulative_total_self_time_as_fraction, + record.dma_stall_fraction, + record.model_flop_rate, + record.measured_flop_rate, + record.measured_memory_bw, + record.hbm_bw, + record.cmem_read_bw, + record.cmem_write_bw, + record.operational_intensity, + record.bound_by, + "Yes" if record.rematerialization else "No", + "Yes" if record.outside_compilation else "No", + "Yes" if record.autotuned else "No", + ] + data.append(row) + + return (table_description, data, {}) + + +def generate_hlo_stats_table(hlo_stats_db): + (table_description, data, custom_properties) = get_hlo_stats_table_args( + hlo_stats_db + ) + return gviz_api.DataTable(table_description, data, custom_properties) + + +def to_json(raw_data): + """Converts a serialized HloStatsDb string to json.""" + hlo_stats_db = hlo_stats_pb2.HloStatsDatabase() + hlo_stats_db.ParseFromString(raw_data) + return generate_hlo_stats_table(hlo_stats_db).ToJSon() diff --git a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py index c2acb269..095e3540 100644 --- a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py +++ b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py @@ -27,6 +27,7 @@ from tensorflow.python.profiler.internal import _pywrap_profiler_plugin # pylint: disable=g-direct-tensorflow-import from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz +from tensorboard_plugin_profile.convert import hlo_stats_proto_to_gviz from tensorboard_plugin_profile.convert import input_pipeline_proto_to_gviz from tensorboard_plugin_profile.convert import kernel_stats_proto_to_gviz from tensorboard_plugin_profile.convert import overview_page_proto_to_gviz @@ -169,6 +170,10 @@ def xspace_to_tool_data( raw_data, success = xspace_wrapper_func(xspace_paths, tool) if success: data = raw_data + elif tool == 'hlo_stats': + raw_data, success = xspace_wrapper_func(xspace_paths, tool) + if success: + data = hlo_stats_proto_to_gviz.to_json(raw_data) elif tool == 'graph_viewer': options = params.get('graph_viewer_options', {}) raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) diff --git a/plugin/tensorboard_plugin_profile/profile_plugin.py b/plugin/tensorboard_plugin_profile/profile_plugin.py index 8a9071ec..832ebe85 100644 --- a/plugin/tensorboard_plugin_profile/profile_plugin.py +++ b/plugin/tensorboard_plugin_profile/profile_plugin.py @@ -110,6 +110,7 @@ 'pod_viewer^', 'tf_data_bottleneck_analysis^', 'op_profile^', + 'hlo_stats^', ] # XPlane generated tools that support all host mode. diff --git a/plugin/tensorboard_plugin_profile/protobuf/BUILD b/plugin/tensorboard_plugin_profile/protobuf/BUILD index df25677e..9d048700 100644 --- a/plugin/tensorboard_plugin_profile/protobuf/BUILD +++ b/plugin/tensorboard_plugin_profile/protobuf/BUILD @@ -12,6 +12,7 @@ proto_library( srcs = [ "dcn_slack_analysis.proto", "diagnostics.proto", + "hlo_stats.proto", "input_pipeline.proto", "kernel_stats.proto", "overview_page.proto", @@ -28,6 +29,7 @@ py_proto_library( srcs = [ "dcn_slack_analysis.proto", "diagnostics.proto", + "hlo_stats.proto", "input_pipeline.proto", "kernel_stats.proto", "overview_page.proto", diff --git a/plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto b/plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto new file mode 100644 index 00000000..19776256 --- /dev/null +++ b/plugin/tensorboard_plugin_profile/protobuf/hlo_stats.proto @@ -0,0 +1,124 @@ +// This proto describes the format of the output profile file from +// the HLO stats tool. +syntax = "proto2"; + +package tensorflow.profiler.hlo_stats; + +// A database of HloStats records. +message HloStatsDatabase { + // All HloStats records, one for each HLO operation. + repeated HloStatsRecord hlo_stats_record = 1; +} + +// There is one HloStatsRecord for each HLO operation profiled. +// Next ID: 39 +message HloStatsRecord { + // The rank by self time + optional uint64 rank = 1; + + // program_id for this op + optional uint64 program_id = 30; + + // The HLO category name. + optional string hlo_category = 17; + + // The HLO expression. + optional string hlo_expression = 2; + + // The framework op name (TF Op, JAX Op) + optional string tf_op_name = 21; + + // Number of occurrences of the operation. + optional int64 occurrences = 3; + + // Total "accumulated" time in micro-seconds that the operation + // took. If this operation has any children operations, + // the "accumulated" time includes the time spent inside children. + optional double total_time_in_us = 4; + + // Average "accumulated" time in micro-seconds that each + // occurrence of the operation took. + optional double avg_time_in_us = 5; + + // Total "self" time in micro-seconds that the operation took. + // If this operation has any children operations, the "self" time + // doesn't include the time spent inside children. + optional double total_self_time_in_us = 6; + + // Average "self" time in micro-seconds that the operation took. + optional double avg_self_time_in_us = 7; + optional double total_self_time_as_fraction = 8; + optional double cumulative_total_self_time_as_fraction = 9; + + // Percentage of the total "accumulated" time that was caused by + // DMA stall. + optional double dma_stall_fraction = 10; + + // Total floating-point operations (FLOPs) performed per second normalized to + // the bf16 peak capacity. + optional double measured_flop_rate = 13; + + // Total Floating-point operations for the op per second. + optional double model_flop_rate = 34; + + // Number of total bytes (including both read and write) accessed per second. + optional double measured_memory_bw = 14; + + // Number of bytes accessed from HBM (including both read and write) per + // second. + optional double hbm_bw = 22; + + // Number of bytes read from CMEM per second. + optional double cmem_read_bw = 23; + + // Number of bytes written to CMEM per second. + optional double cmem_write_bw = 24; + + // Number of bytes read from VMEM per second. + optional double vmem_read_bw = 35; + + // Number of bytes written to VMEM per second. + optional double vmem_write_bw = 36; + + // Overall operational intensity in FLOP/Byte. + optional double operational_intensity = 15; + + // Operational intensity based on HBM in FLOP/Byte. + optional double hbm_operational_intensity = 26; + + // Operational intensity based on CMEM read in FLOP/Byte. + optional double cmem_read_operational_intensity = 27; + + // Operational intensity based on CMEM write in FLOP/Byte. + optional double cmem_write_operational_intensity = 28; + + // Operational intensity based on VMEM read in FLOP/Byte. + optional double vmem_read_operational_intensity = 37; + + // Operational intensity based on VMEM write in FLOP/Byte. + optional double vmem_write_operational_intensity = 38; + + // Operational intensity based on the bottleneck resource in FLOP/Byte. + optional double bottleneck_operational_intensity = 29; + + // Whether this operation is "Compute", "HBM", "CMEM Read", "CMEM Write" + // bound, according to the Roofline Model. + optional string bound_by = 16; + + // Whether this operation is for HLO or Framework rematerialization. + optional bool rematerialization = 20; + + // Whether this op is for outside compilation. + optional bool outside_compilation = 25; + + // Whether this op is autotuned. + optional bool autotuned = 31; + + // Flops for the record + optional uint64 flops = 32; + + // Bytes accessed for the record + optional uint64 bytes_accessed = 33; + + reserved 11, 12, 18, 19; +}