Skip to content

Commit

Permalink
TCPStore: add ping to verify network connectivity on connect (pytorch…
Browse files Browse the repository at this point in the history
…#129985)

This performs a round-trip request as soon as the socket connects. Doing so lets the client detect connection resets and similar failures, and retry, before any non-retryable application requests are sent.

This adds support for PING to both the libuv and legacy backends.

Example error:
```
[trainer85612|12]:W0701 13:41:43.421574  4776 TCPStore.cpp:182] [c10d] recvValue failed on SocketImpl(fd=24, ...): Connection reset by peer
[trainer85612|12]:Exception raised from recvBytes at /mnt/code/pytorch/torch/csrc/distributed/c10d/Utils.hpp:669 (most recent call first):
...
[trainer85612|12]:#9 c10d::TCPStore::incrementValueBy(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, long) from /packages/.../conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so:84809637
[trainer85612|12]:pytorch#10 c10d::TCPStore::waitForWorkers() from /packages/.../conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so:84812868
[trainer85612|12]:pytorch#11 c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) from /packages/.../conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so:84814775
```

Test plan:

```
python test/distributed/test_store.py -v
```

```
tristanr@devvm4382 ~/pytorch (d4l3k/tcpstore_ping)> python ~/pt_tests/tcpstore_large_test.py
starting pool
started 90000
started 30000
started 70000
started 20000
started 80000
started 60000
started 0
[W702 16:16:25.301681870 TCPStore.cpp:343] [c10d] Starting store with 100000 workers but somaxconn is 4096.This might cause instability during bootstrap, consider increasing it.
init 20000
set 20000
init 80000
set 80000
init 70000
set 70000
init 60000
set 60000
init 30000
set 30000
init 90000
set 90000
started 40000
init 40000
set 40000
started 50000
init 50000
set 50000
started 10000
init 10000
set 10000
init 0
set 0
run finished 617.2992351055145
```

Pull Request resolved: pytorch#129985
Approved by: https://github.com/rsdcastro, https://github.com/kurman
  • Loading branch information
d4l3k authored and pytorchmergebot committed Jul 3, 2024
1 parent 91a8376 commit 9ee8c18
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 0 deletions.
41 changes: 41 additions & 0 deletions test/distributed/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import os
import socket
import struct
import sys
import tempfile
import threading
Expand Down Expand Up @@ -1005,6 +1006,46 @@ def _run_test(self):
dist.destroy_process_group()


class TestClientProtocol(TestCase):
    """Checks the raw bytes a TCPStore client sends when it connects."""

    def test_client_connect(self) -> None:
        """A connecting client sends VALIDATE, then a PING carrying its pid.

        A plain socket plays the server: it checks both requests byte for
        byte and echoes the PING nonce back so the client can finish
        connecting.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Release the listening socket even when an assertion below fails.
        self.addCleanup(sock.close)
        sock.bind(("localhost", 0))
        # Start listening before the client exists so its connect attempt
        # cannot race with the server thread starting up.
        sock.listen()
        port = sock.getsockname()[1]

        def serve() -> None:
            conn, _ = sock.accept()
            # Close the connection on every exit path, including a failed
            # assertion, so the client is not left waiting on an open socket.
            with conn:
                # VALIDATE
                # 0x3C85F7CE
                self.assertEqual(conn.recv(5), b"\x00\xce\xf7\x85\x3c")

                # PING
                data = conn.recv(5)
                self.assertEqual(data[0], 13)
                nonce = struct.unpack("i", data[1:])[0]
                self.assertEqual(nonce, os.getpid())

                # send PING nonce response
                conn.sendall(data[1:])

        thread = threading.Thread(target=serve)
        thread.start()

        # Constructing the client drives the handshake; the resulting object
        # is not needed afterwards.
        dist.TCPStore(
            host_name="localhost",
            port=port,
            world_size=2,
            is_master=False,
            timeout=timedelta(seconds=2),
            wait_for_workers=False,
        )

        thread.join()


if __name__ == "__main__":
assert (
not torch.cuda._initialized
Expand Down
16 changes: 16 additions & 0 deletions torch/csrc/distributed/c10d/TCPStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
// client's first query for validation
validate();

// ping to verify network connectivity
ping();

// success
break;
} catch (const c10::DistNetworkError& ex) {
Expand Down Expand Up @@ -453,6 +456,19 @@ void TCPStore::validate() {
buffer.flush();
}

void TCPStore::ping() {
const std::lock_guard<std::mutex> lock(activeOpLock_);
detail::SendBuffer buffer(*client_, detail::QueryType::PING);

uint32_t nonce = getpid();
buffer.appendValue<std::uint32_t>(nonce);
buffer.flush();

uint32_t returnedNonce = client_->receiveValue<std::uint32_t>();
TORCH_INTERNAL_ASSERT(
nonce == returnedNonce, "Ping failed, invalid nonce returned");
}

void TCPStore::_splitSet(
const std::string& key,
const std::vector<uint8_t>& data) {
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/distributed/c10d/TCPStore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class TORCH_API TCPStore : public Store {
private:
int64_t incrementValueBy(const std::string& key, int64_t delta);

void ping();
void validate();

std::vector<uint8_t> doGet(const std::string& key);
Expand Down
11 changes: 11 additions & 0 deletions torch/csrc/distributed/c10d/TCPStoreBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class TCPStoreMasterDaemon : public BackgroundThread {
// The master runs on a single thread so only
// one handler can be executed at a time
void validateHandler(int socket);
void pingHandler(int socket);
void setHandler(int socket);
void compareSetHandler(int socket);
void addHandler(int socket);
Expand Down Expand Up @@ -267,6 +268,10 @@ void TCPStoreMasterDaemon::query(int socket) {
TORCH_CHECK(
false, "Miscellaneous client without VALIDATE query is detected");
}

} else if (qt == QueryType::PING) {
pingHandler(socket);

} else if (qt == QueryType::SET) {
setHandler(socket);

Expand Down Expand Up @@ -334,6 +339,12 @@ void TCPStoreMasterDaemon::validateHandler(int socket) {
}
}

// Handles a PING query: reads one 32-bit value from `socket` and writes the
// same value back on that socket.
void TCPStoreMasterDaemon::pingHandler(int socket) {
  uint32_t echoed{0};
  tcputil::recvBytes<uint32_t>(socket, &echoed, 1);
  tcputil::sendValue<uint32_t>(socket, echoed);
}

void TCPStoreMasterDaemon::setHandler(int socket) {
std::string key = tcputil::recvString(socket);
std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/distributed/c10d/TCPStoreBackend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum class QueryType : uint8_t {
MULTI_GET,
MULTI_SET,
CANCEL_WAIT,
PING,
};

enum class CheckResponseType : uint8_t { READY, NOT_READY };
Expand Down
16 changes: 16 additions & 0 deletions torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,10 @@ class UvClient : public UvTcpSocket {
return;
} else {
switch ((QueryType)command) {
case QueryType::PING:
if (!parse_ping_command())
return;
break;
case QueryType::SET:
if (!parse_set_command())
return;
Expand Down Expand Up @@ -749,6 +753,18 @@ class UvClient : public UvTcpSocket {
return true;
}

// Handles a PING command: reads the 32-bit nonce from the incoming stream and
// writes the same value back to the client.
//
// Returns false, sending nothing, when the nonce cannot be read from the
// stream (presumably because it has not fully arrived yet -- confirm against
// read_value). Returns true once the echo has been handed to the writer.
bool parse_ping_command() {
  // Zero-initialize so the local never holds an indeterminate value.
  uint32_t nonce = 0;
  if (!stream.read_value(nonce)) {
    return false;
  }

  // Echo the nonce back unchanged.
  StreamWriter sw(iptr());
  sw.write_value(nonce);
  sw.send();
  return true;
}

bool parse_set_command() {
std::string key;
if (!stream.read_key(key))
Expand Down

0 comments on commit 9ee8c18

Please sign in to comment.