
chore: implement sharded pub/sub #4518

Open · wants to merge 3 commits into main
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
@@ -172,7 +172,10 @@ jobs:
cd ${GITHUB_WORKSPACE}/build
echo Run ctest -V -L DFLY
GLOG_alsologtostderr=1 GLOG_vmodule=rdb_load=1,rdb_save=1,snapshot=1 \
FLAGS_fiber_safety_margin=4096 FLAGS_list_experimental_v2=true timeout 20m ctest -V -L DFLY
FLAGS_fiber_safety_margin=4096 FLAGS_list_experimental_v2=true timeout 20m ctest -V -L DFLY -E allocation_tracker_test

# Run allocation tracker test separately. It generates a TON of logs
Contributor Author: The allocation tracker test generates a gazillion logs, so I silenced it by excluding it from the alsologtostderr run (ctest -E allocation_tracker_test) and running it separately below.

FLAGS_fiber_safety_margin=4096 FLAGS_force_epoll=true GLOG_vmodule=rdb_load=1,rdb_save=1,snapshot=1 timeout 5m ./allocation_tracker_test

echo "Running tests with --force_epoll"

10 changes: 8 additions & 2 deletions src/facade/command_id.h
@@ -97,7 +97,12 @@ class CommandId {

// PSUBSCRIBE/PUNSUBSCRIBE variant
bool IsPSub() const {
return is_p_sub_;
return is_p_pub_sub_;
}

// SSUBSCRIBE/SUNSUBSCRIBE variant
bool IsShardedPSub() const {
return is_sharded_pub_sub_;
}

protected:
@@ -118,7 +123,8 @@
bool restricted_ = false;

bool is_pub_sub_ = false;
bool is_p_sub_ = false;
bool is_sharded_pub_sub_ = false;
bool is_p_pub_sub_ = false;
};

} // namespace facade
4 changes: 3 additions & 1 deletion src/facade/facade.cc
@@ -137,7 +137,9 @@ CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first
if (name_ == "PUBLISH" || name_ == "SUBSCRIBE" || name_ == "UNSUBSCRIBE") {
is_pub_sub_ = true;
} else if (name_ == "PSUBSCRIBE" || name_ == "PUNSUBSCRIBE") {
is_p_sub_ = true;
is_p_pub_sub_ = true;
} else if (name_ == "SPUBLISH" || name_ == "SSUBSCRIBE" || name_ == "SUNSUBSCRIBE") {
is_sharded_pub_sub_ = true;
}
}
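Taken together with the new accessor in command_id.h, the name-based classification above can be exercised stand-alone. A minimal sketch (MiniCommandId is a hypothetical stand-in; the real CommandId constructor also parses arity, key positions, and ACL masks):

```cpp
#include <cassert>
#include <string>

// Simplified reproduction of the name-based pub/sub classification above.
struct MiniCommandId {
  bool is_pub_sub_ = false;
  bool is_p_pub_sub_ = false;
  bool is_sharded_pub_sub_ = false;

  explicit MiniCommandId(const std::string& name) {
    if (name == "PUBLISH" || name == "SUBSCRIBE" || name == "UNSUBSCRIBE")
      is_pub_sub_ = true;
    else if (name == "PSUBSCRIBE" || name == "PUNSUBSCRIBE")
      is_p_pub_sub_ = true;
    else if (name == "SPUBLISH" || name == "SSUBSCRIBE" || name == "SUNSUBSCRIBE")
      is_sharded_pub_sub_ = true;
  }
};

int main() {
  assert(MiniCommandId("SSUBSCRIBE").is_sharded_pub_sub_);  // IsShardedPSub() == true
  assert(MiniCommandId("PSUBSCRIBE").is_p_pub_sub_);        // IsPSub() == true
  assert(!MiniCommandId("SPUBLISH").is_pub_sub_);           // sharded, not plain pub/sub
  return 0;
}
```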

2 changes: 1 addition & 1 deletion src/server/acl/validator.cc
@@ -25,7 +25,7 @@ namespace dfly::acl {

std::pair<bool, AclLog::Reason> auth_res;

if (id.IsPubSub()) {
if (id.IsPubSub() || id.IsShardedPSub()) {
auth_res = IsPubSubCommandAuthorized(false, cntx.acl_commands, cntx.pub_sub, tail_args, id);
} else if (id.IsPSub()) {
auth_res = IsPubSubCommandAuthorized(true, cntx.acl_commands, cntx.pub_sub, tail_args, id);
62 changes: 54 additions & 8 deletions src/server/main_service.cc
@@ -919,18 +919,33 @@ void Service::Shutdown() {
facade::Connection::Shutdown();
}

OpResult<KeyIndex> DetermineClusterKeys(const CommandId* cid, CmdArgList args) {
if (!cid->IsShardedPSub()) {
return DetermineKeys(cid, args);
}

// Sharded pub sub
// Command form: SPUBLISH shardchannel message
if (cid->name() == "SPUBLISH") {
Contributor Author: We need to bypass DetermineKeys only for this specific case (see the sketch after this hunk).

return {KeyIndex(0, 1)};
}

return {KeyIndex(0, args.size())};
}

optional<ErrorReply> Service::CheckKeysOwnership(const CommandId* cid, CmdArgList args,
const ConnectionContext& dfly_cntx) {
if (dfly_cntx.is_replicating) {
// Always allow commands on the replication port, as it might be for future-owned keys.
return nullopt;
}

if (cid->first_key_pos() == 0) {
if (cid->first_key_pos() == 0 && !cid->IsShardedPSub()) {
return nullopt; // No key command.
}

OpResult<KeyIndex> key_index_res = DetermineKeys(cid, args);
OpResult<KeyIndex> key_index_res = DetermineClusterKeys(cid, args);

if (!key_index_res) {
return ErrorReply{key_index_res.status()};
}
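To make the bypass the comment above refers to concrete, here is a minimal self-contained sketch of the key-index selection (KeyIndexLike, ShardedPubSubKeys, and the argument values are illustrative stand-ins, not Dragonfly's real types):

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for Dragonfly's KeyIndex: a half-open [start, end)
// range over the tail arguments (the command name itself is excluded).
struct KeyIndexLike {
  std::size_t start, end;
};

// Mirrors the dispatch in DetermineClusterKeys: SPUBLISH slot-hashes only the
// channel (arg 0) and skips the message payload, while SSUBSCRIBE/SUNSUBSCRIBE
// treat every argument as a shard channel that this node must own.
KeyIndexLike ShardedPubSubKeys(const std::string& cmd,
                               const std::vector<std::string>& args) {
  if (cmd == "SPUBLISH")
    return {0, 1};            // args = {channel, message}
  return {0, args.size()};    // args = {channel, channel, ...}
}

int main() {
  KeyIndexLike pub = ShardedPubSubKeys("SPUBLISH", {"kostas", "hello"});
  assert(pub.start == 0 && pub.end == 1);  // only "kostas" is hashed to a slot

  KeyIndexLike sub = ShardedPubSubKeys("SSUBSCRIBE", {"kostas", "maria"});
  assert(sub.end == 2);  // both channels must map to slots owned by this node
  return 0;
}
```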
@@ -2258,8 +2273,9 @@ void Service::Exec(CmdArgList args, const CommandContext& cmd_cntx) {
VLOG(2) << "Exec completed";
}

void Service::Publish(CmdArgList args, const CommandContext& cmd_cntx) {
if (IsClusterEnabled()) {
namespace {
void PublishImpl(bool reject_cluster, CmdArgList args, const CommandContext& cmd_cntx) {
if (reject_cluster && IsClusterEnabled()) {
return cmd_cntx.rb->SendError("PUBLISH is not supported in cluster mode yet");
}
string_view channel = ArgS(args, 0);
@@ -2269,17 +2285,17 @@ void Service::Publish(CmdArgList args, const CommandContext& cmd_cntx) {
cmd_cntx.rb->SendLong(cs->SendMessages(channel, messages));
}

void Service::Subscribe(CmdArgList args, const CommandContext& cmd_cntx) {
if (IsClusterEnabled()) {
void SubscribeImpl(bool reject_cluster, CmdArgList args, const CommandContext& cmd_cntx) {
if (reject_cluster && IsClusterEnabled()) {
return cmd_cntx.rb->SendError("SUBSCRIBE is not supported in cluster mode yet");
}
cmd_cntx.conn_cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args),
static_cast<RedisReplyBuilder*>(cmd_cntx.rb));
}

void Service::Unsubscribe(CmdArgList args, const CommandContext& cmd_cntx) {
void UnSubscribeImpl(bool reject_cluster, CmdArgList args, const CommandContext& cmd_cntx) {
auto* rb = static_cast<RedisReplyBuilder*>(cmd_cntx.rb);
if (IsClusterEnabled()) {
if (reject_cluster && IsClusterEnabled()) {
return cmd_cntx.rb->SendError("UNSUBSCRIBE is not supported in cluster mode yet");
}

@@ -2290,6 +2306,32 @@ void Service::Unsubscribe(CmdArgList args, const CommandContext& cmd_cntx) {
}
}

} // namespace

void Service::Publish(CmdArgList args, const CommandContext& cmd_cntx) {
PublishImpl(true, args, cmd_cntx);
}

void Service::SPublish(CmdArgList args, const CommandContext& cmd_cntx) {
PublishImpl(false, args, cmd_cntx);
}

void Service::Subscribe(CmdArgList args, const CommandContext& cmd_cntx) {
SubscribeImpl(true, args, cmd_cntx);
}

void Service::SSubscribe(CmdArgList args, const CommandContext& cmd_cntx) {
SubscribeImpl(false, args, cmd_cntx);
}

void Service::Unsubscribe(CmdArgList args, const CommandContext& cmd_cntx) {
UnSubscribeImpl(true, args, cmd_cntx);
}

void Service::SUnsubscribe(CmdArgList args, const CommandContext& cmd_cntx) {
UnSubscribeImpl(false, args, cmd_cntx);
}

void Service::PSubscribe(CmdArgList args, const CommandContext& cmd_cntx) {
auto* rb = static_cast<RedisReplyBuilder*>(cmd_cntx.rb);

@@ -2648,9 +2690,13 @@ void Service::Register(CommandRegistry* registry) {
.SetValidator(&EvalValidator)
<< CI{"EXEC", CO::LOADING | CO::NOSCRIPT, 1, 0, 0, acl::kExec}.MFUNC(Exec)
<< CI{"PUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, acl::kPublish}.MFUNC(Publish)
<< CI{"SPUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, acl::kPublish}.MFUNC(SPublish)
<< CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, acl::kSubscribe}.MFUNC(Subscribe)
<< CI{"SSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, acl::kSubscribe}.MFUNC(SSubscribe)
<< CI{"UNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -1, 0, 0, acl::kUnsubscribe}.MFUNC(
Unsubscribe)
<< CI{"SUNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -1, 0, 0, acl::kUnsubscribe}.MFUNC(
SUnsubscribe)
<< CI{"PSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, acl::kPSubscribe}.MFUNC(PSubscribe)
<< CI{"PUNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -1, 0, 0, acl::kPUnsubsribe}.MFUNC(
PUnsubscribe)
3 changes: 3 additions & 0 deletions src/server/main_service.h
@@ -136,8 +136,11 @@ class Service : public facade::ServiceInterface {
void EvalShaRo(CmdArgList args, const CommandContext& cmd_cntx);
void Exec(CmdArgList args, const CommandContext& cmd_cntx);
void Publish(CmdArgList args, const CommandContext& cmd_cntx);
void SPublish(CmdArgList args, const CommandContext& cmd_cntx);
void Subscribe(CmdArgList args, const CommandContext& cmd_cntx);
void SSubscribe(CmdArgList args, const CommandContext& cmd_cntx);
void Unsubscribe(CmdArgList args, const CommandContext& cmd_cntx);
void SUnsubscribe(CmdArgList args, const CommandContext& cmd_cntx);
void PSubscribe(CmdArgList args, const CommandContext& cmd_cntx);
void PUnsubscribe(CmdArgList args, const CommandContext& cmd_cntx);
void Function(CmdArgList args, const CommandContext& cmd_cntx);
40 changes: 40 additions & 0 deletions tests/dragonfly/cluster_test.py
@@ -2861,3 +2861,43 @@ async def do_migration(index):
logging.debug("stop seeding")
seeder.stop()
await seed


@dfly_args({"proactor_threads": 2, "cluster_mode": "yes"})
async def test_cluster_sharded_pub_sub(df_factory: DflyInstanceFactory):
nodes = [df_factory.create(port=next(next_port)) for i in range(2)]
df_factory.start_all(nodes)

c_nodes = [node.client() for node in nodes]

nodes_info = [(await create_node_info(instance)) for instance in nodes]
nodes_info[0].slots = [(0, 16383)]
nodes_info[1].slots = []

await push_config(json.dumps(generate_config(nodes_info)), [node.client for node in nodes_info])
# The channel name "kostas" hashes to slot 2833, which is owned by the first node.
with pytest.raises(redis.exceptions.ResponseError) as moved_error:
await c_nodes[1].execute_command("SSUBSCRIBE kostas")

assert str(moved_error.value) == f"MOVED 2833 127.0.0.1:{nodes[0].port}"

node_a = ClusterNode("localhost", nodes[0].port)
node_b = ClusterNode("localhost", nodes[1].port)

consumer_client = RedisCluster(startup_nodes=[node_a, node_b])
consumer = consumer_client.pubsub()
consumer.ssubscribe("kostas")

await c_nodes[0].execute_command("SPUBLISH kostas hello")

# Consume the subscription confirmation from the ssubscribe() above
message = consumer.get_sharded_message(target_node=node_a)
assert message == {"type": "subscribe", "pattern": None, "channel": b"kostas", "data": 1}

message = consumer.get_sharded_message(target_node=node_a)
assert message == {"type": "message", "pattern": None, "channel": b"kostas", "data": b"hello"}

consumer.sunsubscribe("kostas")
await c_nodes[0].execute_command("SPUBLISH kostas new_message")
message = consumer.get_sharded_message(target_node=node_a)
assert message == {"type": "unsubscribe", "pattern": None, "channel": b"kostas", "data": 0}
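The 2833 in the MOVED assertion above is simply the Redis cluster slot of the channel name. A minimal sketch of that calculation (CRC16-XMODEM mod 16384; hash-tag handling is omitted since "kostas" contains no braces):

```cpp
#include <cstdint>
#include <cstdio>
#include <string_view>

// CRC16-XMODEM (poly 0x1021, init 0x0000), the checksum Redis-compatible
// clusters use to map keys and shard channels to one of 16384 hash slots.
uint16_t Crc16(std::string_view s) {
  uint16_t crc = 0;
  for (unsigned char c : s) {
    crc ^= static_cast<uint16_t>(c) << 8;
    for (int i = 0; i < 8; ++i)
      crc = (crc & 0x8000) ? static_cast<uint16_t>((crc << 1) ^ 0x1021)
                           : static_cast<uint16_t>(crc << 1);
  }
  return crc;
}

int main() {
  // "kostas" -> CRC16 0x4B11 (19217); 19217 % 16384 == 2833, owned by node 0.
  std::printf("%u\n", Crc16("kostas") % 16384u);
  return 0;
}
```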