Skip to content

Commit

Permalink
chore: add timeout to replication sockets (#3434)
Browse files Browse the repository at this point in the history
* chore: add timeout fo replication sockets

Master will stop the replication flow if writes could not progress for more than K millis.

---------

Signed-off-by: Roman Gershman <[email protected]>
Signed-off-by: Roman Gershman <[email protected]>
Co-authored-by: Shahar Mike <[email protected]>
  • Loading branch information
romange and chakaz authored Aug 7, 2024
1 parent 7c84b8e commit 1cbfcd4
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 41 deletions.
67 changes: 51 additions & 16 deletions src/server/dflycmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
using namespace std;

ABSL_DECLARE_FLAG(bool, info_replication_valkey_compatible);
ABSL_DECLARE_FLAG(uint32_t, replication_timeout);

namespace dfly {

Expand Down Expand Up @@ -119,6 +120,7 @@ void DflyCmd::ReplicaInfo::Cancel() {
}

flow->full_sync_fb.JoinIfNeeded();
flow->conn = nullptr;
});

// Wait for error handler to quit.
Expand Down Expand Up @@ -501,6 +503,7 @@ void DflyCmd::ReplicaOffset(CmdArgList args, ConnectionContext* cntx) {
OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
DCHECK(!flow->full_sync_fb.IsJoinable());
DCHECK(shard);
DCHECK(flow->conn);

// The summary contains the LUA scripts, so make sure at least (and exactly one)
// of the flows also contain them.
Expand All @@ -527,13 +530,10 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha
return OpStatus::CANCELLED;
}

// Shard can be null for io thread.
if (shard != nullptr) {
if (flow->start_partial_sync_at.has_value())
saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
else
saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);
}
if (flow->start_partial_sync_at.has_value())
saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
else
saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);

flow->full_sync_fb = fb2::Fiber("full_sync", &DflyCmd::FullSyncFb, this, flow, cntx);
return OpStatus::OK;
Expand All @@ -555,12 +555,12 @@ void DflyCmd::StopFullSyncInThread(FlowInfo* flow, EngineShard* shard) {

OpStatus DflyCmd::StartStableSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
// Create streamer for shard flows.
DCHECK(shard);
DCHECK(flow->conn);

if (shard != nullptr) {
flow->streamer.reset(new JournalStreamer(sf_->journal(), cntx));
bool send_lsn = flow->version >= DflyVersion::VER4;
flow->streamer->Start(flow->conn->socket(), send_lsn);
}
flow->streamer.reset(new JournalStreamer(sf_->journal(), cntx));
bool send_lsn = flow->version >= DflyVersion::VER4;
flow->streamer->Start(flow->conn->socket(), send_lsn);

// Register cleanup.
flow->cleanup = [flow]() {
Expand All @@ -577,6 +577,8 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
error_code ec;

if (ec = flow->saver->SaveBody(cntx, nullptr); ec) {
if (!flow->conn->socket()->IsOpen())
ec = make_error_code(errc::operation_canceled); // we cancelled the operation.
cntx->ReportError(ec);
return;
}
Expand All @@ -588,8 +590,7 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
}
}

auto DflyCmd::CreateSyncSession(ConnectionContext* cntx)
-> std::pair<uint32_t, std::shared_ptr<ReplicaInfo>> {
auto DflyCmd::CreateSyncSession(ConnectionContext* cntx) -> std::pair<uint32_t, unsigned> {
unique_lock lk(mu_);
unsigned sync_id = next_sync_id_++;

Expand All @@ -612,7 +613,7 @@ auto DflyCmd::CreateSyncSession(ConnectionContext* cntx)
auto [it, inserted] = replica_infos_.emplace(sync_id, std::move(replica_ptr));
CHECK(inserted);

return *it;
return {it->first, flow_count};
}

auto DflyCmd::GetReplicaInfoFromConnection(ConnectionContext* cntx)
Expand Down Expand Up @@ -651,6 +652,40 @@ void DflyCmd::StopReplication(uint32_t sync_id) {
replica_infos_.erase(sync_id);
}

void DflyCmd::BreakStalledFlowsInShard() {
unique_lock global_lock(mu_, try_to_lock);

// give up on blocking because we run this function periodically in a background fiber,
// so it will eventually grab the lock.
if (!global_lock.owns_lock())
return;

ShardId sid = EngineShard::tlocal()->shard_id();
vector<uint32_t> deleted;

for (auto [sync_id, replica_ptr] : replica_infos_) {
shared_lock replica_lock = replica_ptr->GetSharedLock();

if (!replica_ptr->flows[sid].saver)
continue;

// If saver is present - we are currently using it for full sync.
int64_t last_write_ns = replica_ptr->flows[sid].saver->GetLastWriteTime();
int64_t timeout_ns = int64_t(absl::GetFlag(FLAGS_replication_timeout)) * 1'000'000LL;
int64_t now = absl::GetCurrentTimeNanos();
if (last_write_ns > 0 && last_write_ns + timeout_ns < now) {
VLOG(1) << "Breaking full sync for sync_id " << sync_id << " last_write_ts: " << last_write_ns
<< ", now: " << now;
deleted.push_back(sync_id);
replica_lock.unlock();
replica_ptr->Cancel();
}
}

for (auto sync_id : deleted)
replica_infos_.erase(sync_id);
}

shared_ptr<DflyCmd::ReplicaInfo> DflyCmd::GetReplicaInfo(uint32_t sync_id) {
lock_guard lk(mu_);

Expand Down Expand Up @@ -807,7 +842,7 @@ void DflyCmd::Shutdown() {
void FlowInfo::TryShutdownSocket() {
// Close socket for clean disconnect.
if (conn->socket()->IsOpen()) {
(void)conn->socket()->Shutdown(SHUT_RDWR);
std::ignore = conn->socket()->Shutdown(SHUT_RDWR);
}
}

Expand Down
7 changes: 5 additions & 2 deletions src/server/dflycmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ class DflyCmd {
// Stop all background processes so we can exit in orderly manner.
void Shutdown();

// Create new sync session.
std::pair<uint32_t, std::shared_ptr<ReplicaInfo>> CreateSyncSession(ConnectionContext* cntx);
// Create new sync session. Returns (session_id, number of flows)
std::pair<uint32_t, unsigned> CreateSyncSession(ConnectionContext* cntx);

// Master side acces method to replication info of that connection.
std::shared_ptr<ReplicaInfo> GetReplicaInfoFromConnection(ConnectionContext* cntx);
Expand All @@ -160,6 +160,9 @@ class DflyCmd {
// Sets metadata.
void SetDflyClientVersion(ConnectionContext* cntx, DflyVersion version);

// Tries to break those flows that stuck on socket write for too long time.
void BreakStalledFlowsInShard();

private:
// JOURNAL [START/STOP]
// Start or stop journaling.
Expand Down
16 changes: 9 additions & 7 deletions src/server/engine_shard_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -682,13 +682,14 @@ void EngineShard::RetireExpiredAndEvict() {
}

void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
std::function<void()> global_handler) {
std::function<void()> shard_handler) {
VLOG(1) << "RunPeriodic with period " << period_ms.count() << "ms";

bool runs_global_periodic = (shard_id() == 0); // Only shard 0 runs global periodic.
unsigned global_count = 0;
int64_t last_stats_time = time(nullptr);
int64_t last_heartbeat_ms = INT64_MAX;
int64_t last_handler_ms = 0;

while (true) {
if (fiber_periodic_done_.WaitFor(period_ms)) {
Expand All @@ -702,6 +703,10 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
}
Heartbeat();
last_heartbeat_ms = fb2::ProactorBase::GetMonotonicTimeNs() / 1000000;
if (shard_handler && last_handler_ms + 100 < last_heartbeat_ms) {
last_handler_ms = last_heartbeat_ms;
shard_handler();
}

if (runs_global_periodic) {
++global_count;
Expand All @@ -727,10 +732,6 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
rss_mem_peak.store(total_rss, memory_order_relaxed);
}
}

if (global_handler) {
global_handler();
}
}
}
}
Expand Down Expand Up @@ -903,7 +904,7 @@ size_t GetTieredFileLimit(size_t threads) {
return max_shard_file_size;
}

void EngineShardSet::Init(uint32_t sz, std::function<void()> global_handler) {
void EngineShardSet::Init(uint32_t sz, std::function<void()> shard_handler) {
CHECK_EQ(0u, size());
shard_queue_.resize(sz);

Expand All @@ -922,7 +923,8 @@ void EngineShardSet::Init(uint32_t sz, std::function<void()> global_handler) {
shard->InitTieredStorage(pb, max_shard_file_size);

// Must be last, as it accesses objects initialized above.
shard->StartPeriodicFiber(pb, global_handler);
// We can not move shard_handler because this code is called multiple times.
shard->StartPeriodicFiber(pb, shard_handler);
}
});
}
Expand Down
6 changes: 3 additions & 3 deletions src/server/engine_shard_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ class EngineShard {
// blocks the calling fiber.
void Shutdown(); // called before destructing EngineShard.

void StartPeriodicFiber(util::ProactorBase* pb, std::function<void()> global_handler);
void StartPeriodicFiber(util::ProactorBase* pb, std::function<void()> shard_handler);

void Heartbeat();
void RetireExpiredAndEvict();

void RunPeriodic(std::chrono::milliseconds period_ms, std::function<void()> global_handler);
void RunPeriodic(std::chrono::milliseconds period_ms, std::function<void()> shard_handler);

void CacheStats();

Expand Down Expand Up @@ -288,7 +288,7 @@ class EngineShardSet {
return pp_;
}

void Init(uint32_t size, std::function<void()> global_handler);
void Init(uint32_t size, std::function<void()> shard_handler);

// Shutdown sequence:
// - EngineShardSet.PreShutDown()
Expand Down
10 changes: 5 additions & 5 deletions src/server/journal/streamer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

using namespace facade;

ABSL_FLAG(uint32_t, replication_stream_timeout, 500,
"Time in milliseconds to wait for the replication output buffer go below "
"the throttle limit.");
ABSL_FLAG(uint32_t, replication_timeout, 10000,
"Time in milliseconds to wait for the replication writes being stuck.");

ABSL_FLAG(uint32_t, replication_stream_output_limit, 64_KB,
"Time to wait for the replication output buffer go below the throttle limit");

Expand Down Expand Up @@ -155,8 +155,8 @@ void JournalStreamer::ThrottleIfNeeded() {
if (IsStopped() || !IsStalled())
return;

auto next = chrono::steady_clock::now() +
chrono::milliseconds(absl::GetFlag(FLAGS_replication_stream_timeout));
auto next =
chrono::steady_clock::now() + chrono::milliseconds(absl::GetFlag(FLAGS_replication_timeout));
size_t inflight_start = in_flight_bytes_;
size_t sent_start = total_sent_;

Expand Down
6 changes: 5 additions & 1 deletion src/server/main_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -887,14 +887,18 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*>
ServerState::tlocal()->UpdateChannelStore(cs);
});

shard_set->Init(shard_num, nullptr);
const auto tcp_disabled = GetFlag(FLAGS_port) == 0u;
// We assume that listeners.front() is the main_listener
// see dfly_main RunEngine
if (!tcp_disabled && !listeners.empty()) {
acl_family_.Init(listeners.front(), &user_registry_);
}

// Initialize shard_set with a global callback running once in a while in the shard threads.
shard_set->Init(shard_num, [this] { server_family_.GetDflyCmd()->BreakStalledFlowsInShard(); });

// Requires that shard_set will be initialized before because server_family_.Init might
// load the snapshot.
server_family_.Init(acceptor, std::move(listeners));
}

Expand Down
21 changes: 17 additions & 4 deletions src/server/rdb_save.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1135,10 +1135,15 @@ class RdbSaver::Impl {
return &meta_serializer_;
}

int64_t last_write_ts() const {
return last_write_time_ns_;
}

private:
unique_ptr<SliceSnapshot>& GetSnapshot(EngineShard* shard);

io::Sink* sink_;
int64_t last_write_time_ns_ = -1; // last write call.
vector<unique_ptr<SliceSnapshot>> shard_snapshots_;
// used for serializing non-body components in the calling fiber.
RdbSerializer meta_serializer_;
Expand Down Expand Up @@ -1263,10 +1268,12 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
continue;

DVLOG(2) << "Pulled " << record->id;
auto before = absl::GetCurrentTimeNanos();
last_write_time_ns_ = absl::GetCurrentTimeNanos();
io_error = sink_->Write(io::Buffer(record->value));
stats.rdb_save_usec += (absl::GetCurrentTimeNanos() - before) / 1'000;

stats.rdb_save_usec += (absl::GetCurrentTimeNanos() - last_write_time_ns_) / 1'000;
stats.rdb_save_count++;
last_write_time_ns_ = -1;
if (io_error) {
VLOG(1) << "Error writing to sink " << io_error.message();
break;
Expand Down Expand Up @@ -1369,7 +1376,10 @@ RdbSaver::SnapshotStats RdbSaver::Impl::GetCurrentSnapshotProgress() const {
}

error_code RdbSaver::Impl::FlushSerializer() {
return serializer()->FlushToSink(sink_, SerializerBase::FlushState::kFlushMidEntry);
last_write_time_ns_ = absl::GetCurrentTimeNanos();
auto ec = serializer()->FlushToSink(sink_, SerializerBase::FlushState::kFlushMidEntry);
last_write_time_ns_ = -1;
return ec;
}

RdbSaver::GlobalData RdbSaver::GetGlobalData(const Service* service) {
Expand Down Expand Up @@ -1482,7 +1492,6 @@ error_code RdbSaver::SaveBody(Context* cntx, RdbTypeFreqMap* freq_map) {
VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();
error_code io_error = impl_->ConsumeChannel(cntx->GetCancellation());
if (io_error) {
LOG(ERROR) << "io error " << io_error;
return io_error;
}
if (cntx->GetError()) {
Expand Down Expand Up @@ -1572,6 +1581,10 @@ RdbSaver::SnapshotStats RdbSaver::GetCurrentSnapshotProgress() const {
return impl_->GetCurrentSnapshotProgress();
}

int64_t RdbSaver::GetLastWriteTime() const {
return impl_->last_write_ts();
}

void SerializerBase::AllocateCompressorOnce() {
if (compressor_impl_) {
return;
Expand Down
4 changes: 4 additions & 0 deletions src/server/rdb_save.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class RdbSaver {
// Fetch global data to be serialized in summary part of a snapshot / full sync.
static GlobalData GetGlobalData(const Service* service);

// Returns time in nanos of start of the last pending write interaction.
// Returns -1 if no write operations are currently pending.
int64_t GetLastWriteTime() const;

private:
class Impl;

Expand Down
4 changes: 2 additions & 2 deletions src/server/server_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2665,7 +2665,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
std::string_view arg = ArgS(args, i + 1);
if (cmd == "CAPA") {
if (arg == "dragonfly" && args.size() == 2 && i == 0) {
auto [sid, replica_info] = dfly_cmd_->CreateSyncSession(cntx);
auto [sid, flow_count] = dfly_cmd_->CreateSyncSession(cntx);
cntx->conn()->SetName(absl::StrCat("repl_ctrl_", sid));

string sync_id = absl::StrCat("SYNC", sid);
Expand All @@ -2681,7 +2681,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
rb->StartArray(4);
rb->SendSimpleString(master_replid_);
rb->SendSimpleString(sync_id);
rb->SendLong(replica_info->flows.size());
rb->SendLong(flow_count);
rb->SendLong(unsigned(DflyVersion::CURRENT_VER));
return;
}
Expand Down
Loading

0 comments on commit 1cbfcd4

Please sign in to comment.