diff --git a/sessiondb/redis/database.go b/sessiondb/redis/database.go index 7999647..9fe76fd 100644 --- a/sessiondb/redis/database.go +++ b/sessiondb/redis/database.go @@ -41,7 +41,8 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime seconds, hasExpiration, found := db.redis.TTL(sid) if !found { // not found, create an entry with ttl and return an empty lifetime, session manager will do its job. - if err := db.redis.Set(sid, sid, int64(expires.Seconds())); err != nil { + // set dummy field called `isSet`` since we can't create empty hashmap with expiry. + if err := db.redis.Set(sid, "isSet", true, int64(expires.Seconds())); err != nil { golog.Debug(err) } @@ -50,20 +51,13 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime if !hasExpiration { return sessions.LifeTime{} - } return sessions.LifeTime{Time: time.Now().Add(time.Duration(seconds) * time.Second)} } -const delim = "_" - -func makeKey(sid, key string) string { - return sid + delim + key -} - -// Set sets a key value of a specific session. -// Ignore the "immutable". +// Set sets a field and value in session hashmap. +// TODO: Immutable is not implemented. Need to check if field is already set in hashmap before setting it. func (db *Database) Set(sid string, lifetime sessions.LifeTime, key string, value interface{}, immutable bool) { valueBytes, err := sessions.DefaultTranscoder.Marshal(value) if err != nil { @@ -71,19 +65,21 @@ func (db *Database) Set(sid string, lifetime sessions.LifeTime, key string, valu return } - if err = db.redis.Set(makeKey(sid, key), valueBytes, int64(lifetime.DurationUntilExpiration().Seconds())); err != nil { + // Set hashmap field + if err = db.redis.Set(sid, key, valueBytes, int64(lifetime.DurationUntilExpiration().Seconds())); err != nil { golog.Debug(err) } } -// Get retrieves a session value based on the key. +// Get retrieves a session field value from session hashmap. func (db *Database) Get(sid string, key string) (value interface{}) { - db.get(makeKey(sid, key), &value) + db.get(sid, key, &value) return } -func (db *Database) get(key string, outPtr interface{}) { - data, err := db.redis.Get(key) +// get retrieves a session field value from session hashmap. +func (db *Database) get(key, field string, outPtr interface{}) { + data, err := db.redis.Get(key, field) if err != nil { // not found. return @@ -95,7 +91,7 @@ func (db *Database) get(key string, outPtr interface{}) { } func (db *Database) keys(sid string) []string { - keys, err := db.redis.GetKeys(sid + delim) + keys, err := db.redis.GetKeys(sid) if err != nil { golog.Debugf("unable to get all redis keys of session '%s': %v", sid, err) return nil @@ -107,10 +103,10 @@ func (db *Database) keys(sid string) []string { // Visit loops through all session keys and values. func (db *Database) Visit(sid string, cb func(key string, value interface{})) { keys := db.keys(sid) - for _, key := range keys { + for _, field := range keys { var value interface{} // new value each time, we don't know what user will do in "cb". - db.get(key, &value) - cb(key, value) + db.get(sid, field, &value) + cb(field, value) } } @@ -121,7 +117,7 @@ func (db *Database) Len(sid string) (n int) { // Delete removes a session key value based on its key. func (db *Database) Delete(sid string, key string) (deleted bool) { - err := db.redis.Delete(makeKey(sid, key)) + err := db.redis.Delete(sid, key) if err != nil { golog.Error(err) } @@ -131,20 +127,25 @@ func (db *Database) Delete(sid string, key string) (deleted bool) { // Clear removes all session key values but it keeps the session entry. func (db *Database) Clear(sid string) { keys := db.keys(sid) + var delKeys []string + // Delete all keys except `isSet`` for _, key := range keys { - if err := db.redis.Delete(key); err != nil { - golog.Debugf("unable to delete session '%s' value of key: '%s': %v", sid, key, err) + if key != "isSet" { + delKeys = append(delKeys, key) } } + + if err := db.redis.DeleteMulti(sid, delKeys...); err != nil { + golog.Debugf("unable to delete session '%s' value of keys: '%v': %v", sid, keys, err) + } } // Release destroys the session, it clears and removes the session entry, // session manager will create a new session ID on the next request after this call. func (db *Database) Release(sid string) { - // clear all $sid-$key. - db.Clear(sid) - // and remove the $sid. - db.redis.Delete(sid) + if err := db.redis.DeleteAll(sid); err != nil { + golog.Debugf("unable to delete session '%s'", sid) + } } // Close terminates the redis connection. diff --git a/sessiondb/redis/service/service.go b/sessiondb/redis/service/service.go index b33bc61..d8b4553 100644 --- a/sessiondb/redis/service/service.go +++ b/sessiondb/redis/service/service.go @@ -2,7 +2,6 @@ package service import ( "errors" - "fmt" "time" "github.com/gomodule/redigo/redis" @@ -45,18 +44,22 @@ func (r *Service) CloseConnection() error { // Set sets a key-value to the redis store. // The expiration is setted by the MaxAgeSeconds. -func (r *Service) Set(key string, value interface{}, secondsLifetime int64) (err error) { +func (r *Service) Set(key, field string, value interface{}, secondsLifetime int64) (err error) { c := r.pool.Get() defer c.Close() if c.Err() != nil { return c.Err() } - // if has expiration, then use the "EX" to delete the key automatically. + _, err = c.Do("HSET", r.Config.Prefix+key, field, value) + if err != nil { + return err + } + + // If lifetime is given then expire the map if secondsLifetime > 0 { - _, err = c.Do("SETEX", r.Config.Prefix+key, secondsLifetime, value) - } else { - _, err = c.Do("SET", r.Config.Prefix+key, value) + _, err = c.Do("EXPIRE", r.Config.Prefix+key, secondsLifetime) + return err } return @@ -64,21 +67,21 @@ func (r *Service) Set(key string, value interface{}, secondsLifetime int64) (err // Get returns value, err by its key //returns nil and a filled error if something bad happened. -func (r *Service) Get(key string) (interface{}, error) { +func (r *Service) Get(key, field string) (interface{}, error) { c := r.pool.Get() defer c.Close() if err := c.Err(); err != nil { return nil, err } - redisVal, err := c.Do("GET", r.Config.Prefix+key) - + redisVal, err := c.Do("HGET", r.Config.Prefix+key, field) if err != nil { return nil, err } if redisVal == nil { return nil, ErrKeyNotFound } + return redisVal, nil } @@ -108,11 +111,9 @@ func (r *Service) GetAll() (interface{}, error) { } redisVal, err := c.Do("SCAN", 0) // 0 -> cursor - if err != nil { return nil, err } - if redisVal == nil { return nil, err } @@ -120,63 +121,46 @@ func (r *Service) GetAll() (interface{}, error) { return redisVal, nil } -// GetKeys returns all redis keys using the "SCAN" with MATCH command. -// Read more at: https://redis.io/commands/scan#the-match-option. -func (r *Service) GetKeys(prefix string) ([]string, error) { +// GetKeys returns all fields in session hash map +func (r *Service) GetKeys(key string) ([]string, error) { c := r.pool.Get() defer c.Close() if err := c.Err(); err != nil { return nil, err } - if err := c.Send("SCAN", 0, "MATCH", r.Config.Prefix+prefix+"*", "COUNT", 9999999999); err != nil { + redisVal, err := c.Do("HKEYS") + if err != nil { return nil, err } - - if err := c.Flush(); err != nil { - return nil, err + if redisVal == nil { + return nil, ErrKeyNotFound } - reply, err := c.Receive() - if err != nil || reply == nil { - return nil, err + valIfce := redisVal.([]interface{}) + keys := make([]string, len(valIfce)) + for i, v := range valIfce { + keys[i] = v.(string) } - // it returns []interface, with two entries, the first one is "0" and the second one is a slice of the keys as []interface{uint8....}. - - if keysInterface, ok := reply.([]interface{}); ok { - if len(keysInterface) == 2 { - // take the second, it must contain the slice of keys. - if keysSliceAsBytes, ok := keysInterface[1].([]interface{}); ok { - keys := make([]string, len(keysSliceAsBytes), len(keysSliceAsBytes)) - for i, k := range keysSliceAsBytes { - keys[i] = fmt.Sprintf("%s", k) - } - - return keys, nil - } - - } - } - - return nil, nil + return keys, nil } // GetBytes returns value, err by its key // you can use utils.Deserialize((.GetBytes("yourkey"),&theobject{}) //returns nil and a filled error if something wrong happens -func (r *Service) GetBytes(key string) ([]byte, error) { +func (r *Service) GetBytes(key, field string) ([]byte, error) { c := r.pool.Get() defer c.Close() if err := c.Err(); err != nil { return nil, err } - redisVal, err := c.Do("GET", r.Config.Prefix+key) - + redisVal, err := c.Do("HGET", r.Config.Prefix+key, field) if err != nil { return nil, err } + if redisVal == nil { return nil, ErrKeyNotFound } @@ -185,7 +169,32 @@ func (r *Service) GetBytes(key string) ([]byte, error) { } // Delete removes redis entry by specific key -func (r *Service) Delete(key string) error { +func (r *Service) Delete(key, field string) error { + c := r.pool.Get() + defer c.Close() + + _, err := c.Do("HDEL", r.Config.Prefix+key, field) + return err +} + +// DeleteMulti removes multiple fields from hashmap +func (r *Service) DeleteMulti(key string, fields ...string) error { + c := r.pool.Get() + defer c.Close() + + // Make list of args for HDEL + args := make([]interface{}, len(fields)+1) + args[0] = r.Config.Prefix + key + for i := range fields { + args[i+1] = fields[i] + } + + _, err := c.Do("HDEL", args...) + return err +} + +// DeleteAll deletes session hash map +func (r *Service) DeleteAll(key string) error { c := r.pool.Get() defer c.Close()