Skip to content

Commit 3da39d8

Browse files
subhamsoni-googlecopybara-github
authored andcommitted
Add worker ID to responses and configure gRPC channels for retries and load balancing.
PiperOrigin-RevId: 827391460
1 parent c0ca36b commit 3da39d8

File tree

9 files changed

+127
-6
lines changed

9 files changed

+127
-6
lines changed

plugin/xprof/profile_plugin.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,6 @@ def _get_valid_hosts(
738738
for xplane_path in path.glob(file_pattern):
739739
host_name, _ = _parse_filename(xplane_path.name)
740740
if host_name:
741-
print('host_name: %s', host_name)
742741
all_xplane_files[host_name] = xplane_path
743742
except OSError as e:
744743
logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e)

plugin/xprof/protobuf/worker_service.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,6 @@ message WorkerProfileDataRequest {
4040
message WorkerProfileDataResponse {
4141
// The absolute path to the tool specific output.
4242
string output = 1;
43+
// A unique identifier for the worker that handled the request.
44+
string worker_id = 2;
4345
}

plugin/xprof/worker/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ cc_library(
1414
"@com_github_grpc_grpc//:grpc++",
1515
"@com_google_absl//absl/log",
1616
"@com_google_absl//absl/status:statusor",
17+
"@com_google_absl//absl/strings",
1718
"@org_xprof//plugin/xprof/protobuf:worker_service_cc_grpc_proto",
1819
"@org_xprof//xprof/convert:profile_processor_factory",
1920
"@org_xprof//xprof/convert:tool_options",
@@ -27,6 +28,7 @@ cc_library(
2728
deps = [
2829
":worker_service",
2930
"@com_github_grpc_grpc//:grpc++",
31+
"@com_github_grpc_grpc//:grpc_security_base",
3032
"@com_google_absl//absl/log",
3133
"@com_google_absl//absl/strings",
3234
],
@@ -48,8 +50,10 @@ cc_library(
4850
hdrs = ["stub_factory.h"],
4951
deps = [
5052
"@com_github_grpc_grpc//:grpc++",
53+
"@com_github_grpc_grpc//:grpc_security_base",
5154
"@com_google_absl//absl/base:core_headers",
5255
"@com_google_absl//absl/base:no_destructor",
56+
"@com_google_absl//absl/log",
5357
"@com_google_absl//absl/strings",
5458
"@com_google_absl//absl/synchronization",
5559
"@org_xprof//plugin/xprof/protobuf:worker_service_cc_grpc_proto",

plugin/xprof/worker/grpc_server.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121

2222
#include "absl/log/log.h"
2323
#include "absl/strings/str_cat.h"
24+
#include "grpc/grpc.h"
2425
#include "grpcpp/security/server_credentials.h"
2526
#include "grpcpp/server.h"
2627
#include "grpcpp/server_builder.h"
@@ -39,6 +40,12 @@ void InitializeGrpcServer(int port) {
3940
std::string server_address = absl::StrCat(kServerAddressPrefix, port);
4041
::grpc::ServerBuilder builder;
4142
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
43+
builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_TIME_MS, 20000);
44+
builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 10000);
45+
builder.AddChannelArgument(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0);
46+
builder.AddChannelArgument(GRPC_ARG_HTTP2_MAX_PING_STRIKES, 0);
47+
builder.AddChannelArgument(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1);
48+
builder.AddChannelArgument(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1);
4249
worker_service =
4350
std::make_unique<::xprof::profiler::ProfileWorkerServiceImpl>();
4451
builder.RegisterService(worker_service.get());

plugin/xprof/worker/grpc_utils.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,15 @@ limitations under the License.
1515

1616
#include "plugin/xprof/worker/grpc_utils.h"
1717

18+
#include <cstddef>
1819
#include <string>
1920

21+
#ifdef _WIN32
22+
#include <windows.h>
23+
#else
24+
#include <unistd.h>
25+
#endif
26+
2027
#include "absl/status/status.h"
2128
#include "grpcpp/support/status.h"
2229

@@ -33,5 +40,35 @@ grpc::Status ToGrpcStatus(const absl::Status& absl_status) {
3340
std::string(absl_status.message()));
3441
}
3542

43+
std::string GetHostname() {
44+
// Use a named constant rather than a magic number.
45+
// 256 is generally sufficient for standard hostnames (POSIX HOST_NAME_MAX is
46+
// usually 255, DNS limit is 253).
47+
constexpr size_t kBufferSize = 256;
48+
char buffer[kBufferSize] = {0};
49+
50+
#ifdef _WIN32
51+
DWORD size = kBufferSize;
52+
// Use GetComputerNameExA with ComputerNamePhysicalDnsHostname to get the
53+
// real, long hostname without the 15-char NetBIOS limit.
54+
if (GetComputerNameExA(ComputerNamePhysicalDnsHostname, buffer, &size)) {
55+
return std::string(buffer, size);
56+
}
57+
// Fallback to generic simple name if DNS name fails.
58+
size = kBufferSize;
59+
if (GetComputerNameA(buffer, &size)) {
60+
return std::string(buffer, size);
61+
}
62+
#else
63+
if (gethostname(buffer, sizeof(buffer)) == 0) {
64+
// Ensure null-termination in case of edge-case truncation.
65+
buffer[kBufferSize - 1] = '\0';
66+
return std::string(buffer);
67+
}
68+
#endif
69+
70+
return "unknown_host";
71+
}
72+
3673
} // namespace profiler
3774
} // namespace xprof

plugin/xprof/worker/grpc_utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License.
1616
#ifndef THIRD_PARTY_XPROF_PLUGIN_TENSORBOARD_PLUGIN_PROFILE_WORKER_GRPC_UTILS_H_
1717
#define THIRD_PARTY_XPROF_PLUGIN_TENSORBOARD_PLUGIN_PROFILE_WORKER_GRPC_UTILS_H_
1818

19+
#include <string>
20+
1921
#include "absl/status/status.h"
2022
#include "grpcpp/support/status.h"
2123

@@ -26,6 +28,9 @@ namespace profiler {
2628
absl::Status ToAbslStatus(const grpc::Status& grpc_status);
2729
grpc::Status ToGrpcStatus(const absl::Status& absl_status);
2830

31+
// Returns hostname of machine.
32+
std::string GetHostname();
33+
2934
} // namespace profiler
3035
} // namespace xprof
3136

plugin/xprof/worker/stub_factory.cc

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@ limitations under the License.
2525
#include "absl/base/const_init.h"
2626
#include "absl/base/no_destructor.h"
2727
#include "absl/base/thread_annotations.h"
28+
#include "absl/log/log.h"
2829
#include "absl/strings/str_split.h"
2930
#include "absl/synchronization/mutex.h"
31+
#include "grpc/grpc.h"
3032
#include "grpcpp/channel.h"
3133
#include "grpcpp/create_channel.h"
3234
#include "grpcpp/security/credentials.h"
35+
#include "grpcpp/support/channel_arguments.h"
3336
#include "plugin/xprof/protobuf/worker_service.grpc.pb.h"
3437

3538
namespace xprof {
@@ -56,6 +59,59 @@ static absl::NoDestructor<
5659
static std::atomic<size_t> gCurrentStubIndex = 0;
5760
static std::atomic<bool> gStubsInitialized = false;
5861

62+
// Creates a gRPC channel for a given worker address. This channel is
63+
// configured with a service config that enables a robust retry policy for
64+
// transient errors and sets the client-side load balancing policy to
65+
// round-robin.
66+
std::shared_ptr<::grpc::Channel> CreateWorkerChannelForAddress(
67+
const std::string& address) {
68+
grpc::ChannelArguments args;
69+
// Set a service config for the channel that enables retries.
70+
// This config will be applied to all methods of the service.
71+
// Service Config: 10-minute timeout + conservative retries + LB
72+
const char* kServiceConfigJson = R"pb(
73+
{
74+
"methodConfig":
75+
[ {
76+
"name":
77+
[ {}],
78+
"timeout": "600s",
79+
"retryPolicy": {
80+
"maxAttempts": 4,
81+
"initialBackoff": "2s",
82+
"maxBackoff": "120s",
83+
"backoffMultiplier": 2.0,
84+
"retryableStatusCodes": [
85+
"UNAVAILABLE",
86+
"RESOURCE_EXHAUSTED",
87+
"INTERNAL",
88+
"ABORTED",
89+
"NOT_FOUND"
90+
]
91+
}
92+
}],
93+
"loadBalancingConfig":
94+
[ { "round_robin": {} }]
95+
})pb";
96+
args.SetServiceConfigJSON(kServiceConfigJson);
97+
args.SetLoadBalancingPolicyName("round_robin");
98+
args.SetInt(GRPC_ARG_DNS_MIN_TIME_BETWEEN_RESOLUTIONS_MS, 5000);
99+
args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, 20000);
100+
args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 10000);
101+
args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1);
102+
args.SetInt(GRPC_ARG_ENABLE_RETRIES, 1);
103+
args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1);
104+
args.SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1);
105+
106+
// Create the channel with insecure credentials. This is acceptable because
107+
// the communication between the aggregator and workers happens within a
108+
// trusted, internal network environment.
109+
std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateCustomChannel(
110+
address, ::grpc::InsecureChannelCredentials(), args); // NOLINT
111+
LOG(INFO) << "Created gRPC channel for address: " << address;
112+
return channel;
113+
}
114+
59115
void InitializeStubs(const std::string& worker_service_addresses) {
60116
absl::MutexLock lock(&gStubsMutex);
61117
if (gStubsInitialized.load(std::memory_order_acquire)) {
@@ -66,8 +122,8 @@ void InitializeStubs(const std::string& worker_service_addresses) {
66122
absl::StrSplit(worker_service_addresses, kAddressDelimiter);
67123
for (const std::string& address : addresses) {
68124
if (address.empty()) continue;
69-
std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateChannel(
70-
address, ::grpc::InsecureChannelCredentials()); // NOLINT
125+
std::shared_ptr<::grpc::Channel> channel =
126+
CreateWorkerChannelForAddress(address);
71127
gStubs->push_back(XprofAnalysisWorkerService::NewStub(channel));
72128
}
73129
gStubsInitialized.store(true, std::memory_order_release);

plugin/xprof/worker/worker_service.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919

2020
#include "absl/log/log.h"
2121
#include "absl/status/statusor.h"
22+
#include "absl/strings/str_join.h"
2223
#include "grpcpp/server_context.h"
2324
#include "grpcpp/support/status.h"
2425
#include "xprof/convert/profile_processor_factory.h"
@@ -32,9 +33,12 @@ ::grpc::Status ProfileWorkerServiceImpl::GetProfileData(
3233
::grpc::ServerContext* context,
3334
const ::xprof::pywrap::WorkerProfileDataRequest* request,
3435
::xprof::pywrap::WorkerProfileDataResponse* response) {
35-
LOG(INFO) << "ProfileWorkerServiceImpl::GetProfileData called with request: "
36-
<< request->DebugString();
3736
const auto& origin_request = request->origin_request();
37+
LOG(INFO) << "GetProfileData tool:" << origin_request.tool_name()
38+
<< " session:" << origin_request.session_id() << " params:{"
39+
<< absl::StrJoin(origin_request.parameters(), ",",
40+
absl::PairFormatter("="))
41+
<< "}";
3842
tensorflow::profiler::ToolOptions tool_options;
3943
for (const auto& [key, value] : origin_request.parameters()) {
4044
tool_options[key] = value;
@@ -52,8 +56,12 @@ ::grpc::Status ProfileWorkerServiceImpl::GetProfileData(
5256
return ToGrpcStatus(map_output_file.status());
5357
}
5458
response->set_output(*map_output_file);
59+
response->set_worker_id(GetHostname());
60+
5561
LOG(INFO)
56-
<< "ProfileWorkerServiceImpl::GetProfileData finished successfully.";
62+
<< "ProfileWorkerServiceImpl::GetProfileData finished successfully by "
63+
"worker: "
64+
<< response->worker_id();
5765
return ::grpc::Status::OK;
5866
}
5967

xprof/convert/xplane_to_tools_data.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,9 @@ absl::StatusOr<std::string> CallWorkerService(const std::string& xspace_path,
465465
if (!grpc_status.ok()) {
466466
return ::xprof::profiler::ToAbslStatus(grpc_status);
467467
}
468+
LOG(INFO) << "gRPC response: tool=" << tool_name
469+
<< ", session=" << xspace_path
470+
<< ", worker_id=" << response.worker_id();
468471
return response.output();
469472
}
470473

0 commit comments

Comments
 (0)