diff --git a/third_party/tsl/tsl/platform/default/BUILD b/third_party/tsl/tsl/platform/default/BUILD index 8e46929122a023..7fb65e4a501ebf 100644 --- a/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/tsl/tsl/platform/default/BUILD @@ -465,6 +465,7 @@ cc_library( "//tsl/platform:strcat", "//tsl/platform:stringpiece", "//tsl/platform:types", + "//tsl/profiler/backends/cpu:threadpool_listener_state", ], alwayslink = True, ) diff --git a/third_party/tsl/tsl/platform/default/tracing_impl.h b/third_party/tsl/tsl/platform/default/tracing_impl.h index 8e06e4f60e8ae5..317151b169e2a4 100644 --- a/third_party/tsl/tsl/platform/default/tracing_impl.h +++ b/third_party/tsl/tsl/platform/default/tracing_impl.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_TRACING_IMPL_H_ #define TENSORFLOW_TSL_PLATFORM_DEFAULT_TRACING_IMPL_H_ +#include "tsl/profiler/backends/cpu/threadpool_listener_state.h" + // Stub implementations of tracing functionality. // Definitions that do nothing for platforms that don't have underlying thread @@ -33,7 +35,9 @@ limitations under the License. namespace tsl { namespace tracing { -inline bool EventCollector::IsEnabled() { return false; } +inline bool EventCollector::IsEnabled() { + return tsl::profiler::threadpool_listener::IsEnabled(); +} } // namespace tracing } // namespace tsl diff --git a/third_party/tsl/tsl/profiler/backends/cpu/BUILD b/third_party/tsl/tsl/profiler/backends/cpu/BUILD index 4deea648549221..2f6f43a81fa710 100644 --- a/third_party/tsl/tsl/profiler/backends/cpu/BUILD +++ b/third_party/tsl/tsl/profiler/backends/cpu/BUILD @@ -128,3 +128,39 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "threadpool_listener", + srcs = ["threadpool_listener.cc"], + hdrs = ["threadpool_listener.h"], + visibility = internal_visibility([ + "//tensorflow/python:__pkg__", + "//tsl/platform/cloud:__pkg__", + "//tsl/profiler:__pkg__", + "//tsl/profiler:internal", + "//tsl/profiler:xla_internal", + ]), + deps = [ + ":threadpool_listener_state", + ":traceme_recorder", + "//tsl/platform:logging", + "//tsl/platform:tracing", + "//tsl/platform:types", + "//tsl/profiler/lib:context_types_hdrs", + "//tsl/profiler/lib:profiler_interface", + "//tsl/profiler/lib:traceme_encode", + "//tsl/profiler/utils:time_utils", + "//tsl/profiler/utils:xplane_schema", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "threadpool_listener_state", + srcs = ["threadpool_listener_state.cc"], + hdrs = ["threadpool_listener_state.h"], + visibility = internal_visibility([ + "//tsl/platform:__subpackages__", + ]), +) diff --git a/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener.cc b/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener.cc new file mode 100644 index 00000000000000..1774c10b51db71 --- /dev/null +++ b/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener.cc @@ -0,0 +1,99 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tsl/profiler/backends/cpu/threadpool_listener.h" + +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/tracing.h" +#include "tsl/platform/types.h" +#include "tsl/profiler/backends/cpu/threadpool_listener_state.h" +#include "tsl/profiler/backends/cpu/traceme_recorder.h" +#include "tsl/profiler/lib/context_types.h" +#include "tsl/profiler/lib/traceme_encode.h" +#include "tsl/profiler/utils/time_utils.h" +#include "tsl/profiler/utils/xplane_schema.h" + +namespace tsl { +namespace profiler { +namespace { + +void RegisterThreadpoolEventCollector(ThreadpoolEventCollector* collector) { + tracing::SetEventCollector(tracing::EventCategory::kScheduleClosure, + collector); + tracing::SetEventCollector(tracing::EventCategory::kRunClosure, collector); +} + +void UnregisterThreadpoolEventCollector() { + tracing::SetEventCollector(tracing::EventCategory::kScheduleClosure, nullptr); + tracing::SetEventCollector(tracing::EventCategory::kRunClosure, nullptr); +} + +} // namespace + +void ThreadpoolEventCollector::RecordEvent(uint64 arg) const { + int64_t now = GetCurrentTimeNanos(); + TraceMeRecorder::Record( + {TraceMeEncode(kThreadpoolListenerRecord, + {{"_pt", ContextType::kThreadpoolEvent}, {"_p", arg}}), + now, now}); +} +void ThreadpoolEventCollector::StartRegion(uint64 arg) const { + int64_t now = GetCurrentTimeNanos(); + TraceMeRecorder::Record( + {TraceMeEncode(kThreadpoolListenerStartRegion, + {{"_ct", ContextType::kThreadpoolEvent}, {"_c", arg}}), + now, now}); +} +void ThreadpoolEventCollector::StopRegion() const { + int64_t now = GetCurrentTimeNanos(); + TraceMeRecorder::Record( + {TraceMeEncode(kThreadpoolListenerStopRegion, {}), now, now}); +} + +absl::Status ThreadpoolProfilerInterface::Start() { + if (tracing::EventCollector::IsEnabled()) { + LOG(WARNING) << "[ThreadpoolEventCollector] EventCollector is enabled, Not " + "collecting events from ThreadPool."; + status_ = absl::FailedPreconditionError( + "ThreadpoolEventCollector is enabled, Not collecting events from " + "ThreadPool."); + return absl::OkStatus(); + } + event_collector_ = std::make_unique(); + RegisterThreadpoolEventCollector(event_collector_.get()); + threadpool_listener::Activate(); + return absl::OkStatus(); +} + +absl::Status ThreadpoolProfilerInterface::Stop() { + threadpool_listener::Deactivate(); + UnregisterThreadpoolEventCollector(); + return absl::OkStatus(); +} + +absl::Status ThreadpoolProfilerInterface::CollectData( + tensorflow::profiler::XSpace* space) { + if (!status_.ok()) { + *space->add_errors() = status_.ToString(); + } + return absl::OkStatus(); +} + +} // namespace profiler +} // namespace tsl diff --git a/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener.h b/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener.h new file mode 100644 index 00000000000000..900be949c60e6a --- /dev/null +++ b/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_THREADPOOL_LISTENER_H_ +#define TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_THREADPOOL_LISTENER_H_ + +#include "absl/status/status.h" +#include "tsl/platform/tracing.h" +#include "tsl/platform/types.h" +#include "tsl/profiler/backends/cpu/threadpool_listener_state.h" +#include "tsl/profiler/lib/profiler_interface.h" +namespace tsl { +namespace profiler { + +class ThreadpoolEventCollector : public tsl::tracing::EventCollector { + public: + explicit ThreadpoolEventCollector() = default; + + void RecordEvent(uint64 arg) const override; + void StartRegion(uint64 arg) const override; + void StopRegion() const override; + + // Annotates the current thread with a name. + void SetCurrentThreadName(const char* name) {} + // Returns whether event collection is enabled. + static bool IsEnabled() { return threadpool_listener::IsEnabled(); } +}; + +class ThreadpoolProfilerInterface : public ProfilerInterface { + public: + explicit ThreadpoolProfilerInterface() = default; + + absl::Status Start() override; + absl::Status Stop() override; + + absl::Status CollectData(tensorflow::profiler::XSpace* space) override; + + private: + absl::Status status_; + std::unique_ptr event_collector_; +}; + +} // namespace profiler +} // namespace tsl + +#endif // TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_THREADPOOL_LISTENER_H_ diff --git a/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener_state.cc b/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener_state.cc new file mode 100644 index 00000000000000..dbbf47a6d977f1 --- /dev/null +++ b/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener_state.cc @@ -0,0 +1,35 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/profiler/backends/cpu/threadpool_listener_state.h" + +#include + +namespace tsl { +namespace profiler { +namespace threadpool_listener { +namespace { +std::atomic enabled; +} + +bool IsEnabled() { return enabled.load(std::memory_order_acquire); } + +void Activate() { enabled.store(true, std::memory_order_release); } + +void Deactivate() { enabled.store(false, std::memory_order_release); } + +} // namespace threadpool_listener +} // namespace profiler +} // namespace tsl diff --git a/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener_state.h b/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener_state.h new file mode 100644 index 00000000000000..36d19676f48450 --- /dev/null +++ b/third_party/tsl/tsl/profiler/backends/cpu/threadpool_listener_state.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_THREADPOOL_LISTENER_STATE_H_ +#define TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_THREADPOOL_LISTENER_STATE_H_ + +namespace tsl { +namespace profiler { +namespace threadpool_listener { + +// Check if the threadpool listener is enabled. +bool IsEnabled(); + +// Set global state of threadpool listener to enabled. +void Activate(); + +// Set global state of threadpool listener to disabled. +void Deactivate(); + +} // namespace threadpool_listener +} // namespace profiler +} // namespace tsl + +#endif // TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_THREADPOOL_LISTENER_STATE_H_ diff --git a/third_party/tsl/tsl/profiler/lib/BUILD b/third_party/tsl/tsl/profiler/lib/BUILD index 88462a4963215b..87b9c9a877554e 100644 --- a/third_party/tsl/tsl/profiler/lib/BUILD +++ b/third_party/tsl/tsl/profiler/lib/BUILD @@ -332,6 +332,7 @@ cc_library( visibility = internal_visibility([ "@xla//xla/backends/profiler/plugin:__pkg__", "//learning/brain/tfrc/executor/stream_executor:__pkg__", + "@xla//xla/backends/profiler/cpu:__pkg__", ]), deps = [ ":profiler_interface", diff --git a/third_party/tsl/tsl/profiler/lib/context_types.cc b/third_party/tsl/tsl/profiler/lib/context_types.cc index 371631c10ba882..a581f47572f95d 100644 --- a/third_party/tsl/tsl/profiler/lib/context_types.cc +++ b/third_party/tsl/tsl/profiler/lib/context_types.cc @@ -48,6 +48,8 @@ const char* GetContextTypeString(ContextType context_type) { return "pathways_exec"; case ContextType::kPjrtLibraryCall: return "pjrt_library_call"; + case ContextType::kThreadpoolEvent: + return "threadpool_event"; } } diff --git a/third_party/tsl/tsl/profiler/lib/context_types.h b/third_party/tsl/tsl/profiler/lib/context_types.h index 621f35462fdae2..35bf1b8b2755b7 100644 --- a/third_party/tsl/tsl/profiler/lib/context_types.h +++ b/third_party/tsl/tsl/profiler/lib/context_types.h @@ -37,6 +37,7 @@ enum class ContextType : int { kTpuLaunch, kPathwaysExecutor, kPjrtLibraryCall, + kThreadpoolEvent, kLastContextType = ContextType::kTpuLaunch, }; diff --git a/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/tsl/tsl/profiler/utils/BUILD index 3e9041b0ae988e..fc656790987056 100644 --- a/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/tsl/tsl/profiler/utils/BUILD @@ -426,6 +426,7 @@ cc_library( "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc b/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc index 15aedbc9d495b4..7e8e7f3a431c97 100644 --- a/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc +++ b/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc @@ -81,6 +81,7 @@ void MutateXPlane(XPlane* plane, std::vector> CreateMutatorFactories() { std::vector> mutator_factories; + mutator_factories.push_back(ThreadpoolLineMutatorFactory::CreateFactory()); mutator_factories.push_back(XplaneRootEventMutatorFactory::CreateFactory( HostEventType::kProcessBatch, 2)); mutator_factories.push_back(XplaneRootEventMutatorFactory::CreateFactory( diff --git a/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h b/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h index 2cbc96bea37927..2433cd825cc842 100644 --- a/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h +++ b/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h @@ -26,6 +26,7 @@ limitations under the License. #include #include "absl/hash/hash.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" @@ -461,8 +462,101 @@ class TpuModuleLineMutatorFactory : public XplaneEventMutatorFactory { }; }; -// Preprocess the given XSpace to support legacy traces. It converts old context -// events and stats into new ones according to go/xprof-traceme2-semantics. +// Line mutator for threadpool line. +// Threadpool Line Mutator create a kThreadpoolListenerRegion from StartRegion +// to StopRegion events, and propagates the context information from the +// StartRegion to the newly added event. +class ThreadpoolLineMutatorFactory : public XplaneEventMutatorFactory { + public: + static std::unique_ptr CreateFactory() { + return absl::WrapUnique(new ThreadpoolLineMutatorFactory()); + } + + std::vector> CreateMutators( + XPlaneBuilder* xplane) const override { + std::vector> mutators; + mutators.emplace_back(std::make_unique(xplane)); + return mutators; + } + + private: + ThreadpoolLineMutatorFactory() = default; + + class ThreadpoolLineMutator : public XplaneEventMutator { + public: + explicit ThreadpoolLineMutator(XPlaneBuilder* xplane) + : XplaneEventMutator(nullptr), xplane_(xplane) { + start_region_metadata_ = + xplane_->GetEventMetadata(kThreadpoolListenerStartRegion); + stop_region_metadata_ = + xplane_->GetEventMetadata(kThreadpoolListenerStopRegion); + thread_pool_metadata_ = + xplane_->GetOrCreateEventMetadata(kThreadpoolListenerRegion); + consumer_ = xplane_->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kConsumerId)); + consumer_type_ = xplane_->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kConsumerType)); + } + + void Mutate(XEventBuilder* event_builder) override { + CHECK(false); // Crash OK + } + + void MutateEventsInLine(XLineBuilder* line) override { + if (start_region_metadata_ == nullptr || + stop_region_metadata_ == nullptr) { + // Skip mutations for xplanes that do not have region markers. These + // include device_planes, or situations where the threadpool_listeners + // did not start or were not present. + return; + } + int64_t start_region_timestamp_ps = 0; + int64_t region_id; + struct EventMetadata { + int64_t start_region_timestamp_ps; + int64_t region_id; + int64_t end_region_timestamp_ps; + }; + + std::vector event_metadata; + line->ForEachEvent([&](const XEventBuilder& event) { + if (event.MetadataId() == start_region_metadata_->id()) { + auto consumer_id = event.GetStat(*consumer_); + if (!consumer_id) return; + start_region_timestamp_ps = event.TimestampPs(); + region_id = event.IntOrUintValue(*consumer_id); + } else if (event.MetadataId() == stop_region_metadata_->id() && + start_region_timestamp_ps != 0) { + EventMetadata metadata; + metadata.start_region_timestamp_ps = start_region_timestamp_ps; + metadata.region_id = region_id; + metadata.end_region_timestamp_ps = event.TimestampPs(); + event_metadata.emplace_back(metadata); + } + }); + for (const auto& event_metadata : event_metadata) { + XEventBuilder region = line->AddEvent(*thread_pool_metadata_); + region.SetTimestampPs(event_metadata.start_region_timestamp_ps); + region.SetEndTimestampPs(event_metadata.end_region_timestamp_ps); + region.SetOrAddStatValue(*consumer_, event_metadata.region_id); + region.SetOrAddStatValue( + *consumer_type_, + static_cast(ContextType::kThreadpoolEvent)); + } + } + + private: + XStatMetadata* consumer_; + XStatMetadata* consumer_type_; + XPlaneBuilder* xplane_; + XEventMetadata* start_region_metadata_; + XEventMetadata* stop_region_metadata_; + XEventMetadata* thread_pool_metadata_; + }; +}; +// Preprocess the given XSpace to support legacy traces. It converts old +// context events and stats into new ones according to +// go/xprof-traceme2-semantics. void PreprocessXSpace(XSpace* space); void PreprocessXPlane(XPlane* plane); diff --git a/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc b/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc index 08343acf6af1c4..5912c00c397853 100644 --- a/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc +++ b/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tsl/profiler/utils/preprocess_xplane.h" +#include #include #include "absl/container/flat_hash_map.h" @@ -250,6 +251,45 @@ TEST(PreprocessXPlane, HostRunIdPreprocessorTest) { }); } +TEST(PreprocessXPlane, ThreadPoolPreprocessorTest) { + XSpace space; + XPlane* plane = space.add_planes(); + XPlaneBuilder plane_builder(plane); + auto main_line = plane_builder.GetOrCreateLine(0); + CreateXEvent(&plane_builder, &main_line, kThreadpoolListenerRecord, 100, 100, + {{StatType::kProducerType, + static_cast(ContextType::kThreadpoolEvent)}, + {StatType::kProducerId, int64_t{123}}}); + auto thread_pool_line = plane_builder.GetOrCreateLine(1); + CreateXEvent(&plane_builder, &thread_pool_line, + kThreadpoolListenerStartRegion, 200, 0, + {{StatType::kConsumerType, + static_cast(ContextType::kThreadpoolEvent)}, + {StatType::kConsumerId, int64_t{123}}}); + CreateXEvent(&plane_builder, &thread_pool_line, kThreadpoolListenerStopRegion, + 300, 0, + {{StatType::kConsumerType, + static_cast(ContextType::kThreadpoolEvent)}, + {StatType::kConsumerId, int64_t{123}}}); + + bool new_event_added = false; + PreprocessXSpace(&space); + XPlaneVisitor plane_visitor = CreateTfXPlaneVisitor(plane); + plane_visitor.ForEachLine([&](const XLineVisitor& line) { + line.ForEachEvent([&](const XEventVisitor& event) { + if (event.Name() == kThreadpoolListenerRegion) { + new_event_added = true; + EXPECT_EQ(event.DurationPs(), 100); + EXPECT_EQ(event.TimestampPs(), 200); + auto stat = event.GetStat(StatType::kConsumerId); + EXPECT_TRUE(stat.has_value()); + EXPECT_EQ(stat->IntOrUintValue(), 123); + } + }); + }); + EXPECT_TRUE(new_event_added); +} + } // namespace } // namespace profiler } // namespace tsl diff --git a/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index 990ffa1ed4e9df..9dd1f77d618a24 100644 --- a/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -558,5 +558,14 @@ const char kXProfMetadataFlow[] = "flow"; const char kXProfMetadataTransfers[] = "transfers"; const char kXProfMetadataBufferSize[] = "buffer_size"; +// String constants for threadpool_listener +const absl::string_view kThreadpoolListenerRecord = + "ThreadpoolListener::Record"; +const absl::string_view kThreadpoolListenerStartRegion = + "ThreadpoolListener::StartRegion"; +const absl::string_view kThreadpoolListenerStopRegion = + "ThreadpoolListener::StopRegion"; +const absl::string_view kThreadpoolListenerRegion = + "ThreadpoolListener::Region"; } // namespace profiler } // namespace tsl diff --git a/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/tsl/tsl/profiler/utils/xplane_schema.h index c3d0dbddd70d30..73e8b2b6ad63b2 100644 --- a/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/tsl/tsl/profiler/utils/xplane_schema.h @@ -507,6 +507,12 @@ TF_CONST_INIT extern const char kXProfMetadataFlow[]; TF_CONST_INIT extern const char kXProfMetadataTransfers[]; TF_CONST_INIT extern const char kXProfMetadataBufferSize[]; +// String constants for threadpool_listener events +TF_CONST_INIT extern const absl::string_view kThreadpoolListenerRecord; +TF_CONST_INIT extern const absl::string_view kThreadpoolListenerStartRegion; +TF_CONST_INIT extern const absl::string_view kThreadpoolListenerStopRegion; +TF_CONST_INIT extern const absl::string_view kThreadpoolListenerRegion; + } // namespace profiler } // namespace tsl diff --git a/xla/backends/profiler/cpu/BUILD b/xla/backends/profiler/cpu/BUILD index f0154fed4eb802..df755efe7db3f9 100644 --- a/xla/backends/profiler/cpu/BUILD +++ b/xla/backends/profiler/cpu/BUILD @@ -37,7 +37,9 @@ cc_library( "@com_google_absl//absl/status", "@tsl//tsl/platform:errors", "@tsl//tsl/profiler/backends/cpu:host_tracer_utils", + "@tsl//tsl/profiler/backends/cpu:threadpool_listener", "@tsl//tsl/profiler/backends/cpu:traceme_recorder", + "@tsl//tsl/profiler/lib:profiler_collection", "@tsl//tsl/profiler/lib:profiler_interface", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", "@tsl//tsl/profiler/utils:time_utils", @@ -129,12 +131,15 @@ xla_cc_test( "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest_main", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:blocking_counter", "@tsl//tsl/platform:env", "@tsl//tsl/platform:test", "@tsl//tsl/platform:types", "@tsl//tsl/profiler/lib:profiler_interface", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@tsl//tsl/profiler/utils:tf_xplane_visitor", + "@tsl//tsl/profiler/utils:timespan", "@tsl//tsl/profiler/utils:xplane_schema", "@tsl//tsl/profiler/utils:xplane_visitor", ], diff --git a/xla/backends/profiler/cpu/host_tracer.cc b/xla/backends/profiler/cpu/host_tracer.cc index ad79fb1c398cca..7aa97ffe9bad08 100644 --- a/xla/backends/profiler/cpu/host_tracer.cc +++ b/xla/backends/profiler/cpu/host_tracer.cc @@ -23,7 +23,9 @@ limitations under the License. #include "absl/status/status.h" #include "tsl/platform/errors.h" #include "tsl/profiler/backends/cpu/host_tracer_utils.h" +#include "tsl/profiler/backends/cpu/threadpool_listener.h" #include "tsl/profiler/backends/cpu/traceme_recorder.h" +#include "tsl/profiler/lib/profiler_collection.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "tsl/profiler/utils/time_utils.h" @@ -115,7 +117,12 @@ absl::Status HostTracer::CollectData( // TENSORFLOW_STATUS_OK std::unique_ptr CreateHostTracer( const HostTracerOptions& options) { if (options.trace_level == 0) return nullptr; - return std::make_unique(options.trace_level); + std::vector> profilers; + profilers.push_back(std::make_unique(options.trace_level)); + profilers.push_back( + std::make_unique()); + return std::make_unique( + std::move(profilers)); } } // namespace profiler diff --git a/xla/backends/profiler/cpu/host_tracer_test.cc b/xla/backends/profiler/cpu/host_tracer_test.cc index 881f46e50837ff..ef02a826977a25 100644 --- a/xla/backends/profiler/cpu/host_tracer_test.cc +++ b/xla/backends/profiler/cpu/host_tracer_test.cc @@ -14,19 +14,25 @@ limitations under the License. ==============================================================================*/ #include "xla/backends/profiler/cpu/host_tracer.h" +#include #include #include #include #include +#include #include "absl/types/optional.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" #include "tsl/platform/types.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/tf_xplane_visitor.h" +#include "tsl/profiler/utils/timespan.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_visitor.h" @@ -37,8 +43,11 @@ namespace { using ::tsl::Env; using ::tsl::Thread; using ::tsl::ThreadOptions; +using ::tsl::profiler::StatType; +using ::tsl::profiler::Timespan; using ::tsl::profiler::TraceMe; using ::tsl::profiler::XEventVisitor; +using ::tsl::profiler::XLineVisitor; using ::tsl::profiler::XPlaneVisitor; using ::tsl::profiler::XStatVisitor; @@ -153,6 +162,67 @@ TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) { EXPECT_EQ(e6.DisplayName(), "Iterator::ParallelMap"); } +TEST(HostTracerTest, CollectEventsFromThreadPool) { + tsl::thread::ThreadPool thread_pool(/*env=*/Env::Default(), + /*name=*/"HostTracerTest", + /*num_threads=*/1); + tsl::BlockingCounter counter(1); + auto tracer = CreateHostTracer({}); + TF_EXPECT_OK(tracer->Start()); + thread_pool.Schedule([&counter] { + TraceMe traceme("hello"); + counter.DecrementCount(); + }); + counter.Wait(); + TF_EXPECT_OK(tracer->Stop()); + tensorflow::profiler::XSpace space; + TF_EXPECT_OK(tracer->CollectData(&space)); + + EXPECT_THAT(space.planes(), testing::SizeIs(1)); + XPlaneVisitor xplane = tsl::profiler::CreateTfXPlaneVisitor(&space.planes(0)); + + bool has_record_event = false; + bool has_start_region_event = false; + bool has_end_region_event = false; + int64_t record_region_id = 0; + int64_t start_region_id = 0; + + Timespan region_timespan; + Timespan traceme_timespan; + + xplane.ForEachLine([&](const XLineVisitor& line) { + line.ForEachEvent([&](const XEventVisitor& event) { + if (event.Name() == tsl::profiler::kThreadpoolListenerRecord) { + has_record_event = true; + auto stat = event.GetStat(StatType::kProducerId); + EXPECT_TRUE(stat.has_value()); + record_region_id = stat->IntOrUintValue(); + } else if (event.Name() == + tsl::profiler::kThreadpoolListenerStartRegion) { + has_start_region_event = true; + auto stat = event.GetStat(StatType::kConsumerId); + EXPECT_TRUE(stat.has_value()); + start_region_id = stat->IntOrUintValue(); + region_timespan = event.GetTimespan(); + } else if (event.Name() == tsl::profiler::kThreadpoolListenerStopRegion) { + has_end_region_event = true; + region_timespan = Timespan::FromEndPoints(region_timespan.begin_ps(), + event.GetTimespan().end_ps()); + } else if (event.Name() == "hello") { + traceme_timespan = event.GetTimespan(); + } + }); + }); + + EXPECT_TRUE(has_record_event); + EXPECT_TRUE(has_start_region_event); + EXPECT_TRUE(has_end_region_event); + + EXPECT_EQ(record_region_id, start_region_id); + + EXPECT_TRUE(region_timespan.Includes(traceme_timespan)); +} + } // namespace } // namespace profiler } // namespace xla