From ae250766895a2cb0508701e82d6666d4b5c8a5fe Mon Sep 17 00:00:00 2001 From: hdt3213 Date: Sun, 20 Jun 2021 11:48:18 +0800 Subject: [PATCH] support watch command --- README.md | 4 +- README_CN.md | 6 +-- cluster/multi.go | 84 ++++++++++++++++++++++++++++++++++----- cluster/multi_test.go | 71 +++++++++++++++++++++++++-------- cluster/router.go | 2 + db.go | 28 +++++++++++-- exec.go | 5 +++ exec_helper.go | 1 + interface/redis/client.go | 1 + multi.go | 47 ++++++++++++++++++++-- multi_test.go | 44 ++++++++++++++++++++ redis/connection/conn.go | 12 ++++++ util_test.go | 7 ++-- 13 files changed, 272 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 50df8e65..57454e64 100644 --- a/README.md +++ b/README.md @@ -114,8 +114,8 @@ MSET (10 keys): 65487.89 requests per second ## Todo List -+ [ ] `Multi` Command -+ [ ] `Watch` Command and CAS support ++ [x] `Multi` Command ++ [x] `Watch` Command and CAS support + [ ] Stream support + [ ] RDB file loader + [ ] Master-Slave mode diff --git a/README_CN.md b/README_CN.md index 7c83203a..e095a6d6 100644 --- a/README_CN.md +++ b/README_CN.md @@ -14,7 +14,7 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试 - 自动过期功能(TTL) - 发布订阅 - 地理位置 -- AOF 持久化及AOF重写 +- AOF 持久化及 AOF 重写 - Multi 命令开启的事务具有`原子性`和`隔离性`. 若在执行过程中遇到错误, godis 会回滚已执行的命令 - 内置集群模式. 集群对客户端是透明的, 您可以像使用单机版 redis 一样使用 godis 集群 - `MSET`, `DEL` 命令在集群模式下原子性执行 @@ -105,8 +105,8 @@ MSET (10 keys): 65487.89 requests per second ## 开发计划 -+ [ ] `Multi` 命令 -+ [ ] `Watch` 命令和 CAS 支持 ++ [x] `Multi` 命令 ++ [x] `Watch` 命令和 CAS 支持 + [ ] Stream 队列 + [ ] 加载 RDB 文件 + [ ] 主从模式 diff --git a/cluster/multi.go b/cluster/multi.go index c9e8ce88..1905187d 100644 --- a/cluster/multi.go +++ b/cluster/multi.go @@ -3,10 +3,13 @@ package cluster import ( "github.com/hdt3213/godis" "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/reply" + "strconv" ) const relayMulti = "_multi" +const innerWatch = "_watch" var relayMultiBytes = []byte(relayMulti) @@ -25,9 +28,15 @@ func execMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.R keys = append(keys, wKeys...) keys = append(keys, rKeys...) } + watching := conn.GetWatching() + watchingKeys := make([]string, 0, len(watching)) + for key := range watching { + watchingKeys = append(watchingKeys, key) + } + keys = append(keys, watchingKeys...) if len(keys) == 0 { // empty transaction or only `PING`s - return godis.ExecMulti(cluster.db, cmdLines) + return godis.ExecMulti(cluster.db, conn, watching, cmdLines) } groupMap := cluster.groupBy(keys) if len(groupMap) > 1 { @@ -41,12 +50,12 @@ func execMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.R // out parser not support reply.MultiRawReply, so we have to encode it if peer == cluster.self { - return godis.ExecMulti(cluster.db, cmdLines) + return godis.ExecMulti(cluster.db, conn, watching, cmdLines) } - return execMultiOnOtherNode(cluster, conn, peer, cmdLines) + return execMultiOnOtherNode(cluster, conn, peer, watching, cmdLines) } -func execMultiOnOtherNode(cluster *Cluster, conn redis.Connection, peer string, cmdLines []CmdLine) redis.Reply { +func execMultiOnOtherNode(cluster *Cluster, conn redis.Connection, peer string, watching map[string]uint32, cmdLines []CmdLine) redis.Reply { defer func() { conn.ClearQueuedCmds() conn.SetMultiState(false) @@ -54,11 +63,28 @@ func execMultiOnOtherNode(cluster *Cluster, conn redis.Connection, peer string, relayCmdLine := [][]byte{ // relay it to executing node relayMultiBytes, } + // watching commands + var watchingCmdLine = utils.ToCmdLine(innerWatch) + for key, ver := range watching { + verStr := strconv.FormatUint(uint64(ver), 10) + watchingCmdLine = append(watchingCmdLine, []byte(key), []byte(verStr)) + } + relayCmdLine = append(relayCmdLine, encodeCmdLine([]CmdLine{watchingCmdLine})...) relayCmdLine = append(relayCmdLine, encodeCmdLine(cmdLines)...) - rawRelayResult := cluster.relay(peer, conn, relayCmdLine) + var rawRelayResult redis.Reply + if peer == cluster.self { + // this branch just for testing + rawRelayResult = execRelayedMulti(cluster, nil, relayCmdLine) + } else { + rawRelayResult = cluster.relay(peer, conn, relayCmdLine) + } if reply.IsErrorReply(rawRelayResult) { return rawRelayResult } + _, ok := rawRelayResult.(*reply.EmptyMultiBulkReply) + if ok { + return rawRelayResult + } relayResult, ok := rawRelayResult.(*reply.MultiBulkReply) if !ok { return reply.MakeErrReply("execute failed") @@ -71,25 +97,65 @@ func execMultiOnOtherNode(cluster *Cluster, conn redis.Connection, peer string, } // execRelayedMulti execute relayed multi commands transaction -// cmdLine format: _multi base64ed-cmdLine +// cmdLine format: _multi watch-cmdLine base64ed-cmdLine // result format: base64ed-reply list func execRelayedMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.Reply { + if len(cmdLine) < 2 { + return reply.MakeArgNumErrReply("_exec") + } decoded, err := parseEncodedMultiRawReply(cmdLine[1:]) if err != nil { return reply.MakeErrReply(err.Error()) } - var cmdLines []CmdLine + var txCmdLines []CmdLine for _, rep := range decoded.Replies { mbr, ok := rep.(*reply.MultiBulkReply) if !ok { return reply.MakeErrReply("exec failed") } - cmdLines = append(cmdLines, mbr.Args) + txCmdLines = append(txCmdLines, mbr.Args) + } + watching := make(map[string]uint32) + watchCmdLine := txCmdLines[0] // format: _watch key1 ver1 key2 ver2... + for i := 2; i < len(watchCmdLine); i += 2 { + key := string(watchCmdLine[i-1]) + verStr := string(watchCmdLine[i]) + ver, err := strconv.ParseUint(verStr, 10, 64) + if err != nil { + return reply.MakeErrReply("watching command line failed") + } + watching[key] = uint32(ver) + } + rawResult := godis.ExecMulti(cluster.db, conn, watching, txCmdLines[1:]) + _, ok := rawResult.(*reply.EmptyMultiBulkReply) + if ok { + return rawResult } - rawResult := godis.ExecMulti(cluster.db, cmdLines) resultMBR, ok := rawResult.(*reply.MultiRawReply) if !ok { return reply.MakeErrReply("exec failed") } return encodeMultiRawReply(resultMBR) } + +func execWatch(cluster *Cluster, conn redis.Connection, args [][]byte) redis.Reply { + if len(args) < 2 { + return reply.MakeArgNumErrReply("watch") + } + args = args[1:] + watching := conn.GetWatching() + for _, bkey := range args { + key := string(bkey) + peer := cluster.peerPicker.PickNode(key) + result := cluster.relay(peer, conn, utils.ToCmdLine("GetVer", key)) + if reply.IsErrorReply(result) { + return result + } + intResult, ok := result.(*reply.IntReply) + if !ok { + return reply.MakeErrReply("get version failed") + } + watching[key] = uint32(intResult.Code) + } + return reply.MakeOkReply() +} diff --git a/cluster/multi_test.go b/cluster/multi_test.go index 92155b32..e6d92d83 100644 --- a/cluster/multi_test.go +++ b/cluster/multi_test.go @@ -50,24 +50,63 @@ func TestMultiExecOnOthers(t *testing.T) { testCluster.Exec(conn, utils.ToCmdLine("lrange", key, "0", "-1")) cmdLines := conn.GetQueuedCmdLine() - relayCmdLine := [][]byte{ // relay it to executing node - relayMultiBytes, - } - relayCmdLine = append(relayCmdLine, encodeCmdLine(cmdLines)...) - rawRelayResult := execRelayedMulti(testCluster, conn, relayCmdLine) - if reply.IsErrorReply(rawRelayResult) { - t.Error() - } - relayResult, ok := rawRelayResult.(*reply.MultiBulkReply) - if !ok { - t.Error() - } - rep, err := parseEncodedMultiRawReply(relayResult.Args) - if err != nil { - t.Error() - } + rawResp := execMultiOnOtherNode(testCluster, conn, testCluster.self, nil, cmdLines) + rep := rawResp.(*reply.MultiRawReply) if len(rep.Replies) != 2 { t.Errorf("expect 2 replies actual %d", len(rep.Replies)) } asserts.AssertMultiBulkReply(t, rep.Replies[1], []string{value}) } + +func TestWatch(t *testing.T) { + testCluster.db.Flush() + conn := new(connection.FakeConn) + key := utils.RandString(10) + value := utils.RandString(10) + testCluster.Exec(conn, utils.ToCmdLine("watch", key)) + testCluster.Exec(conn, utils.ToCmdLine("set", key, value)) + result := testCluster.Exec(conn, toArgs("MULTI")) + asserts.AssertNotError(t, result) + key2 := utils.RandString(10) + value2 := utils.RandString(10) + testCluster.Exec(conn, utils.ToCmdLine("set", key2, value2)) + result = testCluster.Exec(conn, utils.ToCmdLine("exec")) + asserts.AssertNotError(t, result) + result = testCluster.Exec(conn, utils.ToCmdLine("get", key2)) + asserts.AssertNullBulk(t, result) + + testCluster.Exec(conn, utils.ToCmdLine("watch", key)) + result = testCluster.Exec(conn, toArgs("MULTI")) + asserts.AssertNotError(t, result) + testCluster.Exec(conn, utils.ToCmdLine("set", key2, value2)) + result = testCluster.Exec(conn, utils.ToCmdLine("exec")) + asserts.AssertNotError(t, result) + result = testCluster.Exec(conn, utils.ToCmdLine("get", key2)) + asserts.AssertBulkReply(t, result, value2) +} + +func TestWatch2(t *testing.T) { + testCluster.db.Flush() + conn := new(connection.FakeConn) + key := utils.RandString(10) + value := utils.RandString(10) + testCluster.Exec(conn, utils.ToCmdLine("watch", key)) + testCluster.Exec(conn, utils.ToCmdLine("set", key, value)) + result := testCluster.Exec(conn, toArgs("MULTI")) + asserts.AssertNotError(t, result) + key2 := utils.RandString(10) + value2 := utils.RandString(10) + testCluster.Exec(conn, utils.ToCmdLine("set", key2, value2)) + cmdLines := conn.GetQueuedCmdLine() + execMultiOnOtherNode(testCluster, conn, testCluster.self, conn.GetWatching(), cmdLines) + result = testCluster.Exec(conn, utils.ToCmdLine("get", key2)) + asserts.AssertNullBulk(t, result) + + testCluster.Exec(conn, utils.ToCmdLine("watch", key)) + result = testCluster.Exec(conn, toArgs("MULTI")) + asserts.AssertNotError(t, result) + testCluster.Exec(conn, utils.ToCmdLine("set", key2, value2)) + execMultiOnOtherNode(testCluster, conn, testCluster.self, conn.GetWatching(), cmdLines) + result = testCluster.Exec(conn, utils.ToCmdLine("get", key2)) + asserts.AssertBulkReply(t, result, value2) +} diff --git a/cluster/router.go b/cluster/router.go index 918a812a..87383a98 100644 --- a/cluster/router.go +++ b/cluster/router.go @@ -110,6 +110,8 @@ func makeRouter() map[string]CmdFunc { routerMap["flushdb"] = FlushDB routerMap["flushall"] = FlushAll routerMap[relayMulti] = execRelayedMulti + routerMap["getver"] = defaultFunc + routerMap["watch"] = execWatch return routerMap } diff --git a/db.go b/db.go index 0f88be6e..1bc70d39 100644 --- a/db.go +++ b/db.go @@ -28,6 +28,8 @@ type DB struct { data dict.Dict // key -> expireTime (time.Time) ttlMap dict.Dict + // key -> version(uint32) + versionMap dict.Dict // dict.Dict will ensure concurrent-safety of its method // use this mutex for complicated command only, eg. rpush, incr ... @@ -72,10 +74,11 @@ type UndoFunc func(db *DB, args [][]byte) []CmdLine // MakeDB create DB instance and start it func MakeDB() *DB { db := &DB{ - data: dict.MakeConcurrent(dataDictSize), - ttlMap: dict.MakeConcurrent(ttlDictSize), - locker: lock.Make(lockerSize), - hub: pubsub.MakeHub(), + data: dict.MakeConcurrent(dataDictSize), + ttlMap: dict.MakeConcurrent(ttlDictSize), + versionMap: dict.MakeConcurrent(dataDictSize), + locker: lock.Make(lockerSize), + hub: pubsub.MakeHub(), } // aof @@ -249,6 +252,23 @@ func (db *DB) IsExpired(key string) bool { return expired } +/* --- add version --- */ + +func (db *DB) addVersion(keys ...string) { + for _, key := range keys { + versionCode := db.GetVersion(key) + db.versionMap.Put(key, versionCode+1) + } +} + +func (db *DB) GetVersion(key string) uint32 { + entity, ok := db.versionMap.Get(key) + if !ok { + return 0 + } + return entity.(uint32) +} + /* ---- Subscribe Functions ---- */ // AfterClientClose does some clean after client close connection diff --git a/exec.go b/exec.go index 0298e49e..e7243682 100644 --- a/exec.go +++ b/exec.go @@ -71,6 +71,11 @@ func execSpecialCmd(c redis.Connection, cmdLine [][]byte, cmdName string, db *DB return reply.MakeArgNumErrReply(cmdName), true } return execMulti(db, c), true + } else if cmdName == "watch" { + if !validateArity(-2, cmdLine) { + return reply.MakeArgNumErrReply(cmdName), true + } + return Watch(db, c, cmdLine[1:]), true } return nil, false } diff --git a/exec_helper.go b/exec_helper.go index e092ae24..88536813 100644 --- a/exec_helper.go +++ b/exec_helper.go @@ -18,6 +18,7 @@ func execNormalCommand(db *DB, cmdArgs [][]byte) redis.Reply { prepare := cmd.prepare write, read := prepare(cmdArgs[1:]) + db.addVersion(write...) db.RWLocks(write, read) defer db.RWUnLocks(write, read) fun := cmd.executor diff --git a/interface/redis/client.go b/interface/redis/client.go index c91f23f9..3ed36af0 100644 --- a/interface/redis/client.go +++ b/interface/redis/client.go @@ -18,4 +18,5 @@ type Connection interface { GetQueuedCmdLine() [][][]byte EnqueueCmd([][]byte) ClearQueuedCmds() + GetWatching() map[string]uint32 } diff --git a/multi.go b/multi.go index 1fb07165..42137457 100644 --- a/multi.go +++ b/multi.go @@ -12,6 +12,37 @@ var forbiddenInMulti = set.Make( "flushall", ) +// Watch set watching keys +func Watch(db *DB, conn redis.Connection, args [][]byte) redis.Reply { + watching := conn.GetWatching() + for _, bkey := range args { + key := string(bkey) + watching[key] = db.GetVersion(key) + } + return reply.MakeOkReply() +} + +func execGetVersion(db *DB, args [][]byte) redis.Reply { + key := string(args[0]) + ver := db.GetVersion(key) + return reply.MakeIntReply(int64(ver)) +} + +func init() { + RegisterCommand("GetVer", execGetVersion, readAllKeys, nil, 2) +} + +// invoker should lock watching keys +func isWatchingChanged(db *DB, watching map[string]uint32) bool { + for key, ver := range watching { + currentVersion := db.GetVersion(key) + if ver != currentVersion { + return true + } + } + return false +} + // StartMulti starts multi-command-transaction func StartMulti(db *DB, conn redis.Connection) redis.Reply { if conn.InMultiState() { @@ -48,11 +79,11 @@ func execMulti(db *DB, conn redis.Connection) redis.Reply { } defer conn.SetMultiState(false) cmdLines := conn.GetQueuedCmdLine() - return ExecMulti(db, cmdLines) + return ExecMulti(db, conn, conn.GetWatching(), cmdLines) } // ExecMulti executes multi commands transaction Atomically and Isolated -func ExecMulti(db *DB, cmdLines []CmdLine) redis.Reply { +func ExecMulti(db *DB, conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply { // prepare writeKeys := make([]string, 0) // may contains duplicate readKeys := make([]string, 0) @@ -64,9 +95,18 @@ func ExecMulti(db *DB, cmdLines []CmdLine) redis.Reply { writeKeys = append(writeKeys, write...) readKeys = append(readKeys, read...) } + // set watch + watchingKeys := make([]string, 0, len(watching)) + for key := range watching { + watchingKeys = append(watchingKeys, key) + } + readKeys = append(readKeys, watchingKeys...) db.RWLocks(writeKeys, readKeys) defer db.RWUnLocks(writeKeys, readKeys) + if isWatchingChanged(db, watching) { // watching keys changed, abort + return reply.MakeEmptyMultiBulkReply() + } // execute results := make([]redis.Reply, 0, len(cmdLines)) aborted := false @@ -82,7 +122,8 @@ func ExecMulti(db *DB, cmdLines []CmdLine) redis.Reply { } results = append(results, result) } - if !aborted { + if !aborted { //success + db.addVersion(writeKeys...) return reply.MakeMultiRawReply(results) } // undo if aborted diff --git a/multi_test.go b/multi_test.go index 14421d84..4b445542 100644 --- a/multi_test.go +++ b/multi_test.go @@ -23,6 +23,12 @@ func TestMulti(t *testing.T) { asserts.AssertBulkReply(t, result, value) result = testDB.Exec(conn, utils.ToCmdLine("lrange", key2, "0", "-1")) asserts.AssertMultiBulkReply(t, result, []string{value}) + if len(conn.GetWatching()) > 0 { + t.Error("watching map should be reset") + } + if len(conn.GetQueuedCmdLine()) > 0 { + t.Error("queue should be reset") + } } func TestRollback(t *testing.T) { @@ -38,6 +44,12 @@ func TestRollback(t *testing.T) { asserts.AssertErrReply(t, result, "EXECABORT Transaction discarded because of previous errors.") result = testDB.Exec(conn, utils.ToCmdLine("type", key)) asserts.AssertStatusReply(t, result, "none") + if len(conn.GetWatching()) > 0 { + t.Error("watching map should be reset") + } + if len(conn.GetQueuedCmdLine()) > 0 { + t.Error("queue should be reset") + } } func TestDiscard(t *testing.T) { @@ -56,4 +68,36 @@ func TestDiscard(t *testing.T) { asserts.AssertNullBulk(t, result) result = testDB.Exec(conn, utils.ToCmdLine("lrange", key2, "0", "-1")) asserts.AssertMultiBulkReplySize(t, result, 0) + if len(conn.GetWatching()) > 0 { + t.Error("watching map should be reset") + } + if len(conn.GetQueuedCmdLine()) > 0 { + t.Error("queue should be reset") + } +} + +func TestWatch(t *testing.T) { + testDB.Flush() + conn := new(connection.FakeConn) + for i := 0; i < 3; i++ { + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(conn, utils.ToCmdLine("watch", key)) + testDB.Exec(conn, utils.ToCmdLine("set", key, value)) + result := testDB.Exec(conn, utils.ToCmdLine("multi")) + asserts.AssertNotError(t, result) + key2 := utils.RandString(10) + value2 := utils.RandString(10) + testDB.Exec(conn, utils.ToCmdLine("set", key2, value2)) + result = testDB.Exec(conn, utils.ToCmdLine("exec")) + asserts.AssertNotError(t, result) + result = testDB.Exec(conn, utils.ToCmdLine("get", key2)) + asserts.AssertNullBulk(t, result) + if len(conn.GetWatching()) > 0 { + t.Error("watching map should be reset") + } + if len(conn.GetQueuedCmdLine()) > 0 { + t.Error("queue should be reset") + } + } } diff --git a/redis/connection/conn.go b/redis/connection/conn.go index 3d101720..14bc3f7d 100644 --- a/redis/connection/conn.go +++ b/redis/connection/conn.go @@ -27,6 +27,7 @@ type Connection struct { // queued commands for `multi` multiState bool queue [][][]byte + watching map[string]uint32 } // RemoteAddr returns the remote network address @@ -120,6 +121,10 @@ func (c *Connection) InMultiState() bool { } func (c *Connection) SetMultiState(state bool) { + if !state { // reset data when cancel multi + c.watching = nil + c.queue = nil + } c.multiState = state } @@ -135,6 +140,13 @@ func (c *Connection) ClearQueuedCmds() { c.queue = nil } +func (c *Connection) GetWatching() map[string]uint32 { + if c.watching == nil { + c.watching = make(map[string]uint32) + } + return c.watching +} + // FakeConn implements redis.Connection for test type FakeConn struct { Connection diff --git a/util_test.go b/util_test.go index a93a051e..b02a37b1 100644 --- a/util_test.go +++ b/util_test.go @@ -7,8 +7,9 @@ import ( func makeTestDB() *DB { return &DB{ - data: dict.MakeConcurrent(1), - ttlMap: dict.MakeConcurrent(ttlDictSize), - locker: lock.Make(lockerSize), + data: dict.MakeConcurrent(dataDictSize), + versionMap: dict.MakeConcurrent(dataDictSize), + ttlMap: dict.MakeConcurrent(ttlDictSize), + locker: lock.Make(lockerSize), } }