From 1413a423057f2f00270a83a89e893d5214850d05 Mon Sep 17 00:00:00 2001
From: Clive Verghese
Date: Tue, 3 Dec 2024 13:09:17 -0800
Subject: [PATCH] Add InferenceStats to Gviz conversion.

PiperOrigin-RevId: 702444296
---
 .../tensorboard_plugin_profile/convert/BUILD  |  11 +
 .../convert/inference_stats_proto_to_gviz.py  | 304 ++++++++++++++++++
 .../convert/raw_to_tool_data.py               |   5 +
 .../tensorboard_plugin_profile/protobuf/BUILD |   2 +
 .../protobuf/inference_stats.proto            | 285 ++++++++++++++++
 5 files changed, 607 insertions(+)
 create mode 100644 plugin/tensorboard_plugin_profile/convert/inference_stats_proto_to_gviz.py
 create mode 100644 plugin/tensorboard_plugin_profile/protobuf/inference_stats.proto

diff --git a/plugin/tensorboard_plugin_profile/convert/BUILD b/plugin/tensorboard_plugin_profile/convert/BUILD
index 26d574db..e7b69048 100644
--- a/plugin/tensorboard_plugin_profile/convert/BUILD
+++ b/plugin/tensorboard_plugin_profile/convert/BUILD
@@ -4,6 +4,7 @@ load("@python_deps//:requirements.bzl", "requirement")
 
 # Converter from protobuf to gviz/json formats.
 load("@rules_python//python:defs.bzl", "py_library", "py_test")
+load("//devtools/python/blaze:pytype.bzl", "py_library", "pytype_strict_library")
 
 visibility = ["//plugin:internal"]
 
@@ -247,6 +248,7 @@ py_library(
     deps = [
         ":dcn_collective_stats_proto_to_gviz",
         ":hlo_stats_proto_to_gviz",
+        ":inference_stats_proto_to_gviz",
         ":input_pipeline_proto_to_gviz",
         ":kernel_stats_proto_to_gviz",
         ":overview_page_proto_to_gviz",
@@ -256,3 +258,12 @@ py_library(
         "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2",
     ],
 )
+
+pytype_strict_library(
+    name = "inference_stats_proto_to_gviz",
+    srcs = ["inference_stats_proto_to_gviz.py"],
+    deps = [
+        requirement("gviz_api"),
+        "@org_xprof//plugin/tensorboard_plugin_profile/protobuf:protos_all_py_pb2",
+    ],
+)
diff --git a/plugin/tensorboard_plugin_profile/convert/inference_stats_proto_to_gviz.py b/plugin/tensorboard_plugin_profile/convert/inference_stats_proto_to_gviz.py
new file mode 100644
index 00000000..9950d92e
--- /dev/null
+++ b/plugin/tensorboard_plugin_profile/convert/inference_stats_proto_to_gviz.py
@@ -0,0 +1,304 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""For conversion of InferenceStats proto to gviz tables.
+
+Usage:
+    gviz_data_tables = generate_all_chart_tables(inference_stats)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+
+import gviz_api
+
+from tensorboard_plugin_profile.protobuf import inference_stats_pb2
+
+
+def pico_to_milli(ps: float) -> float:
+  """Converts picoseconds to milliseconds."""
+  return ps / 1e9
+
+
+def _add_request_details(
+    request_detail: inference_stats_pb2.RequestDetail,
+    percentile: str,
+    request_id: str,
+    has_batching: bool,
+    is_tpu: bool,
+    throughput: str,
+):
+  """Generates the request details row."""
+  row = [
+      percentile,
+      request_id,
+      pico_to_milli(request_detail.end_time_ps - request_detail.start_time_ps),
+  ]
+  if has_batching:
+    row.append(request_detail.batching_request_size)
+    row.append(pico_to_milli(request_detail.batching_request_delay_ps))
+  row.append(throughput)
+  if is_tpu:
+    row.append(pico_to_milli(request_detail.host_preprocessing_ps))
+    row.append(pico_to_milli(request_detail.host_runtime_ps))
+    row.append(pico_to_milli(request_detail.write_to_device_time_ps))
+    row.append(pico_to_milli(request_detail.read_from_device_time_ps))
+    row.append(pico_to_milli(request_detail.device_time_ps))
+    row.append(pico_to_milli(request_detail.host_postprocessing_ps))
+    row.append(pico_to_milli(request_detail.idle_time_ps))
+  return row
+
+
+def _create_request_table(
+    per_model_stats: inference_stats_pb2.PerModelInferenceStats,
+    has_batching: bool,
+    is_tpu: bool,
+):
+  """Generates the request table."""
+  columns = [
+      ("percentile", "string", "Percentile"),
+      ("request_id", "string", "Request ID"),
+      ("latency_ms", "number", "Latency (ms)"),
+  ]
+  if has_batching:
+    columns.append(("batching_request_size", "number", "Batching Request Size"))
+    columns.append(
+        ("host_batch_information", "number", "Host Batch Information")
+    )
+  columns.append(("throughput", "string", "Throughput"))
+  if is_tpu:
+    columns.append(("host_preprocessing", "number", "Host Preprocessing"))
+    columns.append(("host_runtime", "number", "Host Runtime"))
+    columns.append(("data_transfer_h2d", "number", "Data transfer H2D"))
+    columns.append(("data_transfer_d2h", "number", "Data transfer D2H"))
+    columns.append(("device_compute", "number", "Device compute"))
+    columns.append(("host_postprocess", "number", "Host Postprocessing"))
+    columns.append(("idle_time", "number", "Idle Time"))
+  data = []
+  for request_detail in per_model_stats.per_batch_size_aggregated_result:
+    data.append(
+        _add_request_details(
+            request_detail.aggregated_request_result,
+            "Batch Size {}".format(request_detail.batch_size),
+            "N/A",
+            has_batching,
+            is_tpu,
+            "{:.1f}".format(request_detail.batch_throughput),
+        )
+    )
+  data.append(
+      _add_request_details(
+          per_model_stats.aggregated_request_detail,
+          "Aggregated",
+          "N/A",
+          has_batching,
+          is_tpu,
+          "{:.1f}".format(per_model_stats.request_throughput),
+      )
+  )
+  custom_properties = {
+      "throughput": "{:.1f}".format(per_model_stats.request_throughput),
+      "averageLatencyMs": "{:.3f}".format(
+          per_model_stats.request_average_latency_us / 1e3
+      ),
+  }
+  return gviz_api.DataTable(columns, data, custom_properties)
+
+
+def _generate_batch_details(
+    batch_detail: inference_stats_pb2.BatchDetail,
+    percentile: str,
+    batch_id: str,
+    throughput: str,
+):
+  """Generates the batch details row."""
+  return [
+      percentile,
+      batch_id,
+      batch_detail.end_time_ps - batch_detail.start_time_ps,
+      batch_detail.padding_amount,
+      batch_detail.batch_size_after_padding,
+      (batch_detail.batch_size_after_padding - batch_detail.padding_amount)
+      / batch_detail.batch_size_after_padding,
+      batch_detail.batch_delay_ps,
+      throughput,
+  ]
+
+
+def _generate_batch_table(
+    per_model_stats: inference_stats_pb2.PerModelInferenceStats,
+    model_id_database: inference_stats_pb2.ModelIdDatabase,
+    model_id: str,
+):
+  """Generates the batch table."""
+  columns = [
+      ("percentile", "string", "Percentile"),
+      ("batch_id", "string", "Batch ID"),
+      ("latency", "number", "Latency (ms)"),
+      ("padding_amount", "number", "Padding Amount"),
+      ("batch_size_after_padding", "number", "Batch Size After Padding"),
+      ("batching_efficiency", "number", "Batch Efficiency"),
+      ("batch_delay_ms", "number", "Batch Delay (ms)"),
+      ("throughput", "string", "Throughput"),
+  ]
+  data = []
+  properties = {}
+  properties["throughput"] = "{:.1f}".format(per_model_stats.batch_throughput)
+  properties["averageLatencyMs"] = "{:.3f}".format(
+      per_model_stats.batch_average_latency_us / 1e3
+  )
+
+  if model_id in model_id_database.id_to_batching_params:
+    params = model_id_database.id_to_batching_params[model_id]
+    properties["hasBatchingParam"] = "true"
+    properties["batchingParamNumBatchThreads"] = str(params.num_batch_threads)
+    properties["batchingParamMaxBatchSize"] = str(params.max_batch_size)
+    properties["batchingParamBatchTimeoutMicros"] = str(
+        params.batch_timeout_micros
+    )
+    properties["batchingParamMaxEnqueuedBatches"] = str(
+        params.max_enqueued_batches
+    )
+    properties["batchingParamAllowedBatchSizes"] = str(
+        params.allowed_batch_sizes
+    )
+  else:
+    properties["hasBatchingParam"] = "false"
+  for batch_detail in per_model_stats.per_batch_size_aggregated_result:
+    data.append(
+        _generate_batch_details(
+            batch_detail.aggregated_batch_result,
+            "Batch Size {}".format(batch_detail.batch_size),
+            "N/A",
+            "{:.1f}".format(batch_detail.batch_throughput),
+        )
+    )
+  data.append(
+      _generate_batch_details(
+          per_model_stats.aggregated_batch_detail,
+          "Aggregated",
+          "N/A",
+          "{:.1f}".format(per_model_stats.batch_throughput),
+      )
+  )
+  return gviz_api.DataTable(columns, data, properties)
+
+
+def _generate_tensor_pattern_table(
+    per_model_inference_stats: inference_stats_pb2.PerModelInferenceStats,
+    tensor_pattern_db: inference_stats_pb2.TensorPatternDatabase,
+):
+  """Generates the tensor pattern table."""
+  table_description = [
+      ("id", "number", "ID"),
+      ("tensor_pattern", "string", "Tensor Pattern"),
+      ("count", "number", "Number of Occurrences"),
+      ("percentile", "string", "Linearize/Delinearize latency"),
+  ]
+  data = []
+  for counter, aggregated_result in enumerate(
+      per_model_inference_stats.tensor_transfer_aggregated_result.tensor_pattern_results
+  ):
+    tensor_pattern = tensor_pattern_db.tensor_pattern[
+        aggregated_result.tensor_pattern_index
+    ]
+    data.append([
+        counter,
+        tensor_pattern,
+        aggregated_result.count,
+        aggregated_result.linearize_delinearize_percentile_time,
+    ])
+  logging.info("here: %s", data)
+  return gviz_api.DataTable(table_description, data)
+
+
+def _generate_per_model_inference_table(
+    inference_stats: inference_stats_pb2.InferenceStats,
+    sorted_model_ids: list[str],
+    has_batching: bool,
+    is_tpu: bool,
+):
+  """Generates the per model inference table."""
+  tables = []
+  for model_id in sorted_model_ids:
+    try:
+      model_index = inference_stats.model_id_db.id_to_index[model_id]
+      per_model_stats = inference_stats.inference_stats_per_model[model_index]
+      tables.append(
+          _create_request_table(per_model_stats, has_batching, is_tpu)
+      )
+      if has_batching:
+        tables.append(
+            _generate_batch_table(
+                per_model_stats, inference_stats.model_id_db, model_id
+            )
+        )
+      if inference_stats.tensor_pattern_db.tensor_pattern:
+        logging.info(
+            "here: %s", inference_stats.tensor_pattern_db.tensor_pattern
+        )
+        tables.append(
+            _generate_tensor_pattern_table(
+                per_model_stats, inference_stats.tensor_pattern_db
+            )
+        )
+    except KeyError:
+      continue
+  return tables
+
+
+def generate_all_chart_tables(
+    inference_stats: inference_stats_pb2.InferenceStats,
+):
+  """Converts an InferenceStats proto to gviz DataTables."""
+  sorted_model_ids = [x for x in inference_stats.model_id_db.ids]
+  sorted_model_ids.sort()
+  has_batching = False
+  for _, per_model_stats in inference_stats.inference_stats_per_model.items():
+    if per_model_stats.batch_details:
+      has_batching = True
+      break
+  is_tpu = True
+  table_properties = {
+      "hasBatching": "{}".format(has_batching).lower(),
+      "hasTensorPattern": "false",
+  }
+  columns = [
+      ("model_name", "string", "Model Name"),
+  ]
+  data = []
+  for model_id in sorted_model_ids:
+    data.append([model_id])
+  logging.info("here: %s", data)
+  return [
+      gviz_api.DataTable(columns, data, table_properties),
+      *_generate_per_model_inference_table(
+          inference_stats,
+          sorted_model_ids,
+          has_batching,
+          is_tpu,
+      ),
+  ]
+
+
+def to_json(raw_data):
+  """Converts a serialized InferenceStats string to json."""
+  inference_stats = inference_stats_pb2.InferenceStats()
+  inference_stats.ParseFromString(raw_data)
+  all_chart_tables = generate_all_chart_tables(inference_stats)
+  json_join = ",".join(x.ToJSon() if x else "{}" for x in all_chart_tables)
+  return "[" + json_join + "]"
diff --git a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py
index 095e3540..57c0f790 100644
--- a/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py
+++ b/plugin/tensorboard_plugin_profile/convert/raw_to_tool_data.py
@@ -28,6 +28,7 @@
 from tensorflow.python.profiler.internal import _pywrap_profiler_plugin  # pylint: disable=g-direct-tensorflow-import
 from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz
 from tensorboard_plugin_profile.convert import hlo_stats_proto_to_gviz
+from tensorboard_plugin_profile.convert import inference_stats_proto_to_gviz
 from tensorboard_plugin_profile.convert import input_pipeline_proto_to_gviz
 from tensorboard_plugin_profile.convert import kernel_stats_proto_to_gviz
 from tensorboard_plugin_profile.convert import overview_page_proto_to_gviz
@@ -204,6 +205,10 @@ def xspace_to_tool_data(
     raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
     if success:
       data = dcn_collective_stats_proto_to_gviz.to_json(raw_data)
+  elif tool == 'inference_profile':
+    raw_data, success = xspace_wrapper_func(xspace_paths, tool)
+    if success:
+      data = inference_stats_proto_to_gviz.to_json(raw_data)
   else:
     logger.warning('%s is not a known xplane tool', tool)
   return data, content_type
diff --git a/plugin/tensorboard_plugin_profile/protobuf/BUILD b/plugin/tensorboard_plugin_profile/protobuf/BUILD
index 9d048700..124bac7e 100644
--- a/plugin/tensorboard_plugin_profile/protobuf/BUILD
+++ b/plugin/tensorboard_plugin_profile/protobuf/BUILD
@@ -13,6 +13,7 @@ proto_library(
         "dcn_slack_analysis.proto",
         "diagnostics.proto",
         "hlo_stats.proto",
+        "inference_stats.proto",
         "input_pipeline.proto",
         "kernel_stats.proto",
         "overview_page.proto",
@@ -30,6 +31,7 @@ py_proto_library(
         "dcn_slack_analysis.proto",
         "diagnostics.proto",
         "hlo_stats.proto",
+        "inference_stats.proto",
         "input_pipeline.proto",
         "kernel_stats.proto",
         "overview_page.proto",
diff --git a/plugin/tensorboard_plugin_profile/protobuf/inference_stats.proto b/plugin/tensorboard_plugin_profile/protobuf/inference_stats.proto
new file mode 100644
index 00000000..47f378db
--- /dev/null
+++ b/plugin/tensorboard_plugin_profile/protobuf/inference_stats.proto
@@ -0,0 +1,285 @@
+// This proto is used for inference-specific analysis.
+syntax = "proto2";
+
+package tensorflow.profiler;
+
+message TensorEventDetail {
+  // The index of the tensor pattern in TensorPatternDatabase.
+  optional int32 tensor_pattern_index = 1;
+
+  // The owner of this TensorEventDetail.
+  enum TensorEventOwner {
+    // Unknown. This should not happen in production code.
+    UNKNOWN = 0;
+
+    // Owned by the request.
+    REQUEST = 1;
+
+    // Owned by the batch.
+    BATCH = 2;
+  }
+
+  // If batching is enabled, the TensorEventDetails in BatchDetail will have
+  // owner = BATCH, and they are counted when calculating statistics like the
+  // number of occurrences for each tensor pattern. The TensorEventDetails in
+  // RequestDetail will have owner = BATCH, which means the tensor events
+  // actually happen in the batch, and they are not counted when calculating
+  // various statistics.
+  // If batching is not enabled, the TensorEventDetail will only appear in
+  // RequestDetail and the owner will only be REQUEST.
+  optional TensorEventOwner owner = 2;
+
+  // Total time in picosecs spent on linearize and delinearize tensors.
+  optional uint64 linearize_delinearize_time_ps = 3;
+}
+
+// Detail of a user-facing request.
+// Next ID: 22
+message RequestDetail {
+  // Request id.
+  optional int64 request_id = 10 [default = -1];
+
+  // An index to the model_id inside InferenceStats below. Storing index
+  // instead of string to save space. It will be -1 if the model id is not
+  // given.
+  optional int32 model_id_index = 8 [default = -1];
+
+  // Start-time of the request in picosecs.
+  optional uint64 start_time_ps = 1 [default = 0];
+
+  // End-time of the request in picosecs.
+  optional uint64 end_time_ps = 2 [default = 0];
+
+  // Total time in picosecs in this request spent on device.
+  optional uint64 device_time_ps = 7 [default = 0];
+
+  // Total time in picosecs in this request spent on writes to device.
+  optional uint64 write_to_device_time_ps = 5 [default = 0];
+
+  // Total time in picosecs in this request spent on reads from device.
+  optional uint64 read_from_device_time_ps = 6 [default = 0];
+
+  // If this inference request is running in batching mode, record the latency
+  // between when a request is scheduled and when it is processed in a batch.
+  // Otherwise, it will always be 0.
+  optional uint64 batching_request_delay_ps = 12 [default = 0];
+
+  // Batch ids related to this request.
+  repeated int64 related_batch_ids = 11;
+
+  // If this inference request is running in batching mode, record the size of
+  // the request. Otherwise, it will always be 0.
+  optional int32 batching_request_size = 13;
+
+  // Detailed breakdown for host side activities of a request.
+  // Total time in picosecs spent on host preprocessing.
+  optional uint64 host_preprocessing_ps = 14;
+
+  // Total time in picosecs spent on host batch formation.
+  optional uint64 host_batch_formation_ps = 15;
+
+  // Total time in picosecs spent on host runtime.
+  optional uint64 host_runtime_ps = 16;
+
+  // Total time in picosecs spent on host postprocessing.
+  optional uint64 host_postprocessing_ps = 17;
+
+  // Tensor event details.
+  // One request can have multiple TensorEventDetails because it might be
+  // split into multiple batches for execution.
+  repeated TensorEventDetail tensor_event_details = 18;
+
+  // Host index for this request.
+  optional int32 host_id = 19;
+
+  // Percentile of this request in all requests in the profile duration.
+  optional double percentile = 20;
+
+  // The time that is not associated with any event. It could be that the
+  // machine was idle or executing some events which were not traced.
+  optional double idle_time_ps = 21;
+
+  // Were device_start_time_ps, device_end_time_ps, session_id
+  reserved 3, 4, 9;
+}
+
+// Detail of a batch.
+// Next ID: 12
+message BatchDetail {
+  // Batch id.
+  optional int64 batch_id = 1 [default = -1];
+
+  // Start time of the batch in picosecs.
+  optional uint64 start_time_ps = 2 [default = 0];
+
+  // End time of the batch in picosecs.
+  optional uint64 end_time_ps = 3 [default = 0];
+
+  // The latency between the start time of the first request in this batch and
+  // the time this batch is processed.
+  optional uint64 batch_delay_ps = 5 [default = 0];
+
+  // Request ids related to this batch.
+  repeated int64 related_request_ids = 4;
+
+  // Size of padding.
+  optional int32 padding_amount = 6;
+
+  // Size of a batch after padding.
+  optional int32 batch_size_after_padding = 7;
+
+  // Model ID of this batch. This is the same model_id as any of the requests
+  // in this batch. All the requests from the same batch must share the same
+  // model_id.
+  optional int32 model_id_index = 8;
+
+  // Tensor event details.
+  optional TensorEventDetail tensor_event_detail = 9;
+
+  // Host index for this batch.
+  optional int32 host_id = 10;
+
+  // Percentile of this batch in all batches in the profile duration.
+  optional double percentile = 11;
+}
+
+// Per-host data for inference analysis.
+message PerHostInferenceStats {
+  // A list of requests selected for inference analysis on this host.
+  // This list is in ascending order of the request duration.
+  repeated RequestDetail request_details = 3;
+
+  // A list of batches selected for inference analysis on this host.
+  // This list is in ascending order of the batch duration.
+  repeated BatchDetail batch_details = 5;
+
+  reserved 1, 2, 4, 6;
+
+  // were session_run_times, sessions_per_second, requests_per_second,
+  // batches_per_second.
+}
+
+// Per-model aggregated result of tensor transfer.
+message TensorTransferAggregatedResult {
+  message TensorPatternResult {
+    // The index of the tensor pattern in TensorPatternDatabase.
+    optional int32 tensor_pattern_index = 1;
+
+    // The number of occurrences of this tensor pattern in this model.
+    optional uint64 count = 2;
+
+    message PercentileTime {
+      optional double percentile = 1;
+      optional uint64 time_ps = 2;
+    }
+
+    // The percentiles of the linearize and delinearize time of this tensor
+    // pattern in this model.
+    repeated PercentileTime linearize_delinearize_percentile_time = 3;
+  }
+
+  repeated TensorPatternResult tensor_pattern_results = 1;
+}
+
+// Aggregated result per batch size.
+message PerBatchSizeAggregatedResult {
+  optional int32 batch_size = 1;
+  optional RequestDetail aggregated_request_result = 2;
+  optional BatchDetail aggregated_batch_result = 3;
+  optional double request_throughput = 4;
+  optional double batch_throughput = 5;
+}
+
+// Per-model data for inference analysis.
+message PerModelInferenceStats {
+  // A list of requests selected for inference analysis on this model.
+  // This list is in ascending order of the request duration.
+  repeated RequestDetail request_details = 1;
+
+  // Aggregated result from all the request_details.
+  optional RequestDetail aggregated_request_detail = 8;
+
+  // Inference requests per second for this model.
+  optional double request_throughput = 2;
+
+  // Average latency in microseconds of the requests in this model.
+  optional double request_average_latency_us = 3;
+
+  // A list of batches selected for inference analysis on this model.
+  // This list is in ascending order of the batch duration.
+  repeated BatchDetail batch_details = 4;
+
+  // Aggregated result from all the batch_details.
+  optional BatchDetail aggregated_batch_detail = 9;
+
+  // Batches per second for this model.
+  optional double batch_throughput = 5;
+
+  // Average latency in microseconds of the batches in this model.
+  optional double batch_average_latency_us = 6;
+
+  // The aggregated result of tensor transfer in this model.
+  optional TensorTransferAggregatedResult tensor_transfer_aggregated_result = 7;
+
+  // Aggregated result per batch size.
+  repeated PerBatchSizeAggregatedResult per_batch_size_aggregated_result = 10;
+}
+
+// Batching parameters collected from TFstreamz.
+message BatchingParameters {
+  // Number of batch threads.
+  optional int64 num_batch_threads = 1;
+
+  // How long a request can wait before being processed by a batch.
+  optional int64 batch_timeout_micros = 2;
+
+  // Maximum size of a batch.
+  optional int64 max_batch_size = 3;
+
+  // Maximum number of enqueued batches.
+  optional int64 max_enqueued_batches = 4;
+
+  // Sizes that are allowed to form a batch. A list of integers separated by ","
+  optional string allowed_batch_sizes = 5;
+}
+
+// Model ID database. Unknown model id will be "" and won't be stored here. So
+// if model id is not found in the TF-session metadata, ModelIdDatabase will be
+// empty.
+message ModelIdDatabase {
+  // Array of model ids.
+  repeated string ids = 1;
+
+  // Map from id to index.
+  map<string, int32> id_to_index = 2;
+
+  // Map from id to batching parameters.
+  map<string, BatchingParameters> id_to_batching_params = 3;
+}
+
+// Tensor pattern database for all the tensor patterns that occurred during the
+// profiling window.
+message TensorPatternDatabase {
+  // A tensor pattern is the string concatenation of all the linearize and
+  // delinearize events in an inference request. Each event records the tensor
+  // shape, data type and the layout on device.
+  repeated string tensor_pattern = 1;
+}
+
+// Proto consumed by inference analysis.
+message InferenceStats {
+  // Map from host-id to the InferenceStats for that host.
+  map<int32, PerHostInferenceStats> inference_stats_per_host = 3;
+
+  // Map from model-id to the InferenceStats for that model.
+  map<int32, PerModelInferenceStats> inference_stats_per_model =
+      5;
+
+  // A database of model ids.
+  optional ModelIdDatabase model_id_db = 4;
+
+  // A database of tensor patterns.
+  optional TensorPatternDatabase tensor_pattern_db = 6;
+
+  reserved 1, 2;  // were processing_stats, session_run_times
+}
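
The converter added above is wired into raw_to_tool_data.py for the 'inference_profile' tool, but it can also be exercised directly. Below is a minimal usage sketch, not part of the patch, assuming a serialized InferenceStats proto is available as bytes; the file name is hypothetical and stands in for the blob that xspace_wrapper_func() normally returns.

    from tensorboard_plugin_profile.convert import inference_stats_proto_to_gviz

    # Hypothetical input file; in the plugin this byte string comes from the
    # profiler's xspace_wrapper_func(xspace_paths, 'inference_profile') call.
    with open("inference_stats.pb", "rb") as f:
      raw_data = f.read()

    # to_json() parses the proto and returns a JSON array of gviz DataTables:
    # a model-name table first, then per-model request, batch, and
    # tensor-pattern tables.
    print(inference_stats_proto_to_gviz.to_json(raw_data))

The relationship between ModelIdDatabase and the per-model map can also be seen by building a tiny InferenceStats by hand and passing it to generate_all_chart_tables(). This sketch is for illustration only; the model name and numbers are made up.

    from tensorboard_plugin_profile.convert import inference_stats_proto_to_gviz
    from tensorboard_plugin_profile.protobuf import inference_stats_pb2

    stats = inference_stats_pb2.InferenceStats()

    # Register one model; id_to_index maps the name to its position in `ids`.
    stats.model_id_db.ids.append("my_model")
    stats.model_id_db.id_to_index["my_model"] = 0

    # Per-model stats are keyed by that index.
    per_model = stats.inference_stats_per_model[0]
    per_model.request_throughput = 10.0
    per_model.request_average_latency_us = 1500.0

    # Times are stored in picoseconds and converted to milliseconds by the
    # converter (pico_to_milli).
    per_model.aggregated_request_detail.start_time_ps = 0
    per_model.aggregated_request_detail.end_time_ps = int(2e9)  # 2 ms

    for table in inference_stats_proto_to_gviz.generate_all_chart_tables(stats):
      print(table.ToJSon())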