From b5e07279944bcf5e1d7a23af9d9820d4f771e852 Mon Sep 17 00:00:00 2001 From: Devansh Singh Date: Thu, 8 Feb 2024 12:56:49 +0530 Subject: [PATCH] Add argument error handling to server commands Signed-off-by: Devansh Singh --- server/commands.go | 7 ++++++- server/generic.go | 36 ++++++++++++++++++++++++++++++++++++ server/set.go | 24 ++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/server/commands.go b/server/commands.go index e5828fb..8886915 100644 --- a/server/commands.go +++ b/server/commands.go @@ -1,6 +1,11 @@ package server -import "log" +import ( + "errors" + "log" +) + +var ErrNotEnoughArgs = errors.New("not enough arguments for command") func (c *Command) write(msg string) { _, err := c.Conn.Write([]byte(msg + "\n")) diff --git a/server/generic.go b/server/generic.go index eb3b5ee..5829654 100644 --- a/server/generic.go +++ b/server/generic.go @@ -7,6 +7,10 @@ import ( ) func (s *Server) get(cmd Command) { + if len(cmd.Args) < 1 { + cmd.error(ErrNotEnoughArgs) + return + } result, err := s.DB.Get(cmd.Args[0]) if err != nil { cmd.error(err) @@ -16,6 +20,10 @@ func (s *Server) get(cmd Command) { } func (s *Server) set(cmd Command) { + if len(cmd.Args) < 2 { + cmd.error(ErrNotEnoughArgs) + return + } err := s.DB.Set(cmd.Args[0], []byte(cmd.Args[1])) if err != nil { cmd.error(err) @@ -23,6 +31,10 @@ func (s *Server) set(cmd Command) { } func (s *Server) setEx(cmd Command) { + if len(cmd.Args) < 3 { + cmd.error(ErrNotEnoughArgs) + return + } ttl, err := strconv.Atoi(cmd.Args[2]) if err != nil { cmd.error(err) @@ -34,6 +46,10 @@ func (s *Server) setEx(cmd Command) { } func (s *Server) del(cmd Command) { + if len(cmd.Args) < 1 { + cmd.error(ErrNotEnoughArgs) + return + } err := s.DB.Del(cmd.Args[0]) if err != nil { cmd.error(err) @@ -48,6 +64,10 @@ func (s *Server) mGet(cmd Command) { } func (s *Server) expire(cmd Command) { + if len(cmd.Args) < 2 { + cmd.error(ErrNotEnoughArgs) + return + } expiration, err := strconv.Atoi(cmd.Args[1]) if err != nil { cmd.error(err) @@ -66,6 +86,10 @@ func (s *Server) keys(cmd Command) { } func (s *Server) exists(cmd Command) { + if len(cmd.Args) < 1 { + cmd.error(ErrNotEnoughArgs) + return + } ok := s.DB.Exists(cmd.Args[0]) if !ok { cmd.write("FALSE") @@ -75,6 +99,10 @@ func (s *Server) exists(cmd Command) { } func (s *Server) persist(cmd Command) { + if len(cmd.Args) < 1 { + cmd.error(ErrNotEnoughArgs) + return + } err := s.DB.Persist(cmd.Args[0]) if err != nil { cmd.error(err) @@ -82,6 +110,10 @@ func (s *Server) persist(cmd Command) { } func (s *Server) expireTime(cmd Command) { + if len(cmd.Args) < 1 { + cmd.error(ErrNotEnoughArgs) + return + } exp, err := s.DB.ExpireTime(cmd.Args[0]) if err != nil { cmd.error(err) @@ -91,6 +123,10 @@ func (s *Server) expireTime(cmd Command) { } func (s *Server) ttl(cmd Command) { + if len(cmd.Args) < 1 { + cmd.error(ErrNotEnoughArgs) + return + } ttl, err := s.DB.TTL(cmd.Args[0]) if err != nil { cmd.error(err) diff --git a/server/set.go b/server/set.go index 13c6c87..4a46c88 100644 --- a/server/set.go +++ b/server/set.go @@ -6,10 +6,18 @@ import ( ) func (s *Server) sAdd(cmd Command) { + if len(cmd.Args) < 2 { + cmd.error(ErrNotEnoughArgs) + return + } s.DB.SAdd(cmd.Args[0], cmd.Args[1]) } func (s *Server) sMembers(cmd Command) { + if len(cmd.Args) < 1 { + cmd.error(ErrNotEnoughArgs) + return + } elements, err := s.DB.SMembers(cmd.Args[0]) if err != nil { cmd.error(err) @@ -21,6 +29,10 @@ func (s *Server) sMembers(cmd Command) { } func (s *Server) sCard(cmd Command) { + if len(cmd.Args) < 1 { + cmd.error(ErrNotEnoughArgs) + return + } size, err := s.DB.SCard(cmd.Args[0]) if err != nil { cmd.error(err) @@ -30,6 +42,10 @@ func (s *Server) sCard(cmd Command) { } func (s *Server) sIsMember(cmd Command) { + if len(cmd.Args) < 2 { + cmd.error(ErrNotEnoughArgs) + return + } ok, err := s.DB.SIsMember(cmd.Args[0], cmd.Args[1]) if err != nil { cmd.error(err) @@ -43,6 +59,10 @@ func (s *Server) sIsMember(cmd Command) { } func (s *Server) sDiff(cmd Command) { + if len(cmd.Args) < 2 { + cmd.error(ErrNotEnoughArgs) + return + } elements, err := s.DB.SDiff(cmd.Args[0], cmd.Args[1]) if err != nil { cmd.error(err) @@ -54,6 +74,10 @@ func (s *Server) sDiff(cmd Command) { } func (s *Server) sDiffStore(cmd Command) { + if len(cmd.Args) < 3 { + cmd.error(ErrNotEnoughArgs) + return + } err := s.DB.SDiffStore(cmd.Args[0], cmd.Args[1], cmd.Args[2]) if err != nil { cmd.error(err)