diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc index 83cbdb616c31..139a01b35996 100644 --- a/src/server/rdb_save.cc +++ b/src/server/rdb_save.cc @@ -534,7 +534,9 @@ error_code RdbSerializer::SaveStreamObject(const PrimeValue& pv) { stream* s = (stream*)pv.RObjPtr(); rax* rax = s->rax_tree; - RETURN_ON_ERR(SaveLen(raxSize(rax))); + const size_t rax_size = raxSize(rax); + + RETURN_ON_ERR(SaveLen(rax_size)); /* Serialize all the listpacks inside the radix tree as they are, * when loading back, we'll use the first entry of each listpack @@ -542,22 +544,22 @@ error_code RdbSerializer::SaveStreamObject(const PrimeValue& pv) { raxIterator ri; raxStart(&ri, rax); raxSeek(&ri, "^", NULL, 0); - while (raxNext(&ri)) { + + auto stop_listpacks_rax = absl::MakeCleanup([&] { raxStop(&ri); }); + + for (size_t i = 0; raxNext(&ri); i++) { uint8_t* lp = (uint8_t*)ri.data; size_t lp_bytes = lpBytes(lp); - error_code ec = SaveString((uint8_t*)ri.key, ri.key_len); - if (ec) { - raxStop(&ri); - return ec; - } - ec = SaveString(lp, lp_bytes); - if (ec) { - raxStop(&ri); - return ec; - } + RETURN_ON_ERR(SaveString((uint8_t*)ri.key, ri.key_len)); + RETURN_ON_ERR(SaveString(lp, lp_bytes)); + + const FlushState flush_state = + (i + 1 < rax_size) ? FlushState::kFlushMidEntry : FlushState::kFlushEndEntry; + FlushIfNeeded(flush_state); } - raxStop(&ri); + + std::move(stop_listpacks_rax).Invoke(); /* Save the number of elements inside the stream. We cannot obtain * this easily later, since our macro nodes should be checked for @@ -597,7 +599,7 @@ error_code RdbSerializer::SaveStreamObject(const PrimeValue& pv) { raxStart(&ri, s->cgroups); raxSeek(&ri, "^", NULL, 0); - auto cleanup = absl::MakeCleanup([&] { raxStop(&ri); }); + auto stop_cgroups_rax = absl::MakeCleanup([&] { raxStop(&ri); }); while (raxNext(&ri)) { streamCG* cg = (streamCG*)ri.data; diff --git a/tests/dragonfly/snapshot_test.py b/tests/dragonfly/snapshot_test.py index 81fd7f7fc00a..cb3f658fc369 100644 --- a/tests/dragonfly/snapshot_test.py +++ b/tests/dragonfly/snapshot_test.py @@ -574,6 +574,7 @@ async def test_tiered_entries_throttle(async_client: aioredis.Redis): ("SET"), ("ZSET"), ("LIST"), + ("STREAM"), ], ) @pytest.mark.slow