Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use redis hashmap for storing session #19

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 27 additions & 26 deletions sessiondb/redis/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime
seconds, hasExpiration, found := db.redis.TTL(sid)
if !found {
// not found, create an entry with ttl and return an empty lifetime, session manager will do its job.
if err := db.redis.Set(sid, sid, int64(expires.Seconds())); err != nil {
// set dummy field called `isSet`` since we can't create empty hashmap with expiry.
if err := db.redis.Set(sid, "isSet", true, int64(expires.Seconds())); err != nil {
golog.Debug(err)
}

Expand All @@ -50,40 +51,35 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime

if !hasExpiration {
return sessions.LifeTime{}

}

return sessions.LifeTime{Time: time.Now().Add(time.Duration(seconds) * time.Second)}
}

const delim = "_"

func makeKey(sid, key string) string {
return sid + delim + key
}

// Set sets a key value of a specific session.
// Ignore the "immutable".
// Set sets a field and value in session hashmap.
// TODO: Immutable is not implemented. Need to check if field is already set in hashmap before setting it.
func (db *Database) Set(sid string, lifetime sessions.LifeTime, key string, value interface{}, immutable bool) {
valueBytes, err := sessions.DefaultTranscoder.Marshal(value)
if err != nil {
golog.Error(err)
return
}

if err = db.redis.Set(makeKey(sid, key), valueBytes, int64(lifetime.DurationUntilExpiration().Seconds())); err != nil {
// Set hashmap field
if err = db.redis.Set(sid, key, valueBytes, int64(lifetime.DurationUntilExpiration().Seconds())); err != nil {
golog.Debug(err)
}
}

// Get retrieves a session value based on the key.
// Get retrieves a session field value from session hashmap.
func (db *Database) Get(sid string, key string) (value interface{}) {
db.get(makeKey(sid, key), &value)
db.get(sid, key, &value)
return
}

func (db *Database) get(key string, outPtr interface{}) {
data, err := db.redis.Get(key)
// get retrieves a session field value from session hashmap.
func (db *Database) get(key, field string, outPtr interface{}) {
data, err := db.redis.Get(key, field)
if err != nil {
// not found.
return
Expand All @@ -95,7 +91,7 @@ func (db *Database) get(key string, outPtr interface{}) {
}

func (db *Database) keys(sid string) []string {
keys, err := db.redis.GetKeys(sid + delim)
keys, err := db.redis.GetKeys(sid)
if err != nil {
golog.Debugf("unable to get all redis keys of session '%s': %v", sid, err)
return nil
Expand All @@ -107,10 +103,10 @@ func (db *Database) keys(sid string) []string {
// Visit loops through all session keys and values.
func (db *Database) Visit(sid string, cb func(key string, value interface{})) {
keys := db.keys(sid)
for _, key := range keys {
for _, field := range keys {
var value interface{} // new value each time, we don't know what user will do in "cb".
db.get(key, &value)
cb(key, value)
db.get(sid, field, &value)
cb(field, value)
}
}

Expand All @@ -121,7 +117,7 @@ func (db *Database) Len(sid string) (n int) {

// Delete removes a session key value based on its key.
func (db *Database) Delete(sid string, key string) (deleted bool) {
err := db.redis.Delete(makeKey(sid, key))
err := db.redis.Delete(sid, key)
if err != nil {
golog.Error(err)
}
Expand All @@ -131,20 +127,25 @@ func (db *Database) Delete(sid string, key string) (deleted bool) {
// Clear removes all session key values but it keeps the session entry.
func (db *Database) Clear(sid string) {
keys := db.keys(sid)
var delKeys []string
// Delete all keys except `isSet``
for _, key := range keys {
if err := db.redis.Delete(key); err != nil {
golog.Debugf("unable to delete session '%s' value of key: '%s': %v", sid, key, err)
if key != "isSet" {
delKeys = append(delKeys, key)
}
}

if err := db.redis.DeleteMulti(sid, delKeys...); err != nil {
golog.Debugf("unable to delete session '%s' value of keys: '%v': %v", sid, keys, err)
}
}

// Release destroys the session, it clears and removes the session entry,
// session manager will create a new session ID on the next request after this call.
func (db *Database) Release(sid string) {
// clear all $sid-$key.
db.Clear(sid)
// and remove the $sid.
db.redis.Delete(sid)
if err := db.redis.DeleteAll(sid); err != nil {
golog.Debugf("unable to delete session '%s'", sid)
}
}

// Close terminates the redis connection.
Expand Down
95 changes: 52 additions & 43 deletions sessiondb/redis/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package service

import (
"errors"
"fmt"
"time"

"github.com/gomodule/redigo/redis"
Expand Down Expand Up @@ -45,40 +44,44 @@ func (r *Service) CloseConnection() error {

// Set sets a key-value to the redis store.
// The expiration is setted by the MaxAgeSeconds.
func (r *Service) Set(key string, value interface{}, secondsLifetime int64) (err error) {
func (r *Service) Set(key, field string, value interface{}, secondsLifetime int64) (err error) {
c := r.pool.Get()
defer c.Close()
if c.Err() != nil {
return c.Err()
}

// if has expiration, then use the "EX" to delete the key automatically.
_, err = c.Do("HSET", r.Config.Prefix+key, field, value)
if err != nil {
return err
}

// If lifetime is given then expire the map
if secondsLifetime > 0 {
_, err = c.Do("SETEX", r.Config.Prefix+key, secondsLifetime, value)
} else {
_, err = c.Do("SET", r.Config.Prefix+key, value)
_, err = c.Do("EXPIRE", r.Config.Prefix+key, secondsLifetime)
return err
}

return
}

// Get returns value, err by its key
//returns nil and a filled error if something bad happened.
func (r *Service) Get(key string) (interface{}, error) {
func (r *Service) Get(key, field string) (interface{}, error) {
c := r.pool.Get()
defer c.Close()
if err := c.Err(); err != nil {
return nil, err
}

redisVal, err := c.Do("GET", r.Config.Prefix+key)

redisVal, err := c.Do("HGET", r.Config.Prefix+key, field)
if err != nil {
return nil, err
}
if redisVal == nil {
return nil, ErrKeyNotFound
}

return redisVal, nil
}

Expand Down Expand Up @@ -108,75 +111,56 @@ func (r *Service) GetAll() (interface{}, error) {
}

redisVal, err := c.Do("SCAN", 0) // 0 -> cursor

if err != nil {
return nil, err
}

if redisVal == nil {
return nil, err
}

return redisVal, nil
}

// GetKeys returns all redis keys using the "SCAN" with MATCH command.
// Read more at: https://redis.io/commands/scan#the-match-option.
func (r *Service) GetKeys(prefix string) ([]string, error) {
// GetKeys returns all fields in session hash map
func (r *Service) GetKeys(key string) ([]string, error) {
c := r.pool.Get()
defer c.Close()
if err := c.Err(); err != nil {
return nil, err
}

if err := c.Send("SCAN", 0, "MATCH", r.Config.Prefix+prefix+"*", "COUNT", 9999999999); err != nil {
redisVal, err := c.Do("HKEYS")
if err != nil {
return nil, err
}

if err := c.Flush(); err != nil {
return nil, err
if redisVal == nil {
return nil, ErrKeyNotFound
}

reply, err := c.Receive()
if err != nil || reply == nil {
return nil, err
valIfce := redisVal.([]interface{})
keys := make([]string, len(valIfce))
for i, v := range valIfce {
keys[i] = v.(string)
}

// it returns []interface, with two entries, the first one is "0" and the second one is a slice of the keys as []interface{uint8....}.

if keysInterface, ok := reply.([]interface{}); ok {
if len(keysInterface) == 2 {
// take the second, it must contain the slice of keys.
if keysSliceAsBytes, ok := keysInterface[1].([]interface{}); ok {
keys := make([]string, len(keysSliceAsBytes), len(keysSliceAsBytes))
for i, k := range keysSliceAsBytes {
keys[i] = fmt.Sprintf("%s", k)
}

return keys, nil
}

}
}

return nil, nil
return keys, nil
}

// GetBytes returns value, err by its key
// you can use utils.Deserialize((.GetBytes("yourkey"),&theobject{})
//returns nil and a filled error if something wrong happens
func (r *Service) GetBytes(key string) ([]byte, error) {
func (r *Service) GetBytes(key, field string) ([]byte, error) {
c := r.pool.Get()
defer c.Close()
if err := c.Err(); err != nil {
return nil, err
}

redisVal, err := c.Do("GET", r.Config.Prefix+key)

redisVal, err := c.Do("HGET", r.Config.Prefix+key, field)
if err != nil {
return nil, err
}

if redisVal == nil {
return nil, ErrKeyNotFound
}
Expand All @@ -185,7 +169,32 @@ func (r *Service) GetBytes(key string) ([]byte, error) {
}

// Delete removes redis entry by specific key
func (r *Service) Delete(key string) error {
func (r *Service) Delete(key, field string) error {
c := r.pool.Get()
defer c.Close()

_, err := c.Do("HDEL", r.Config.Prefix+key, field)
return err
}

// DeleteMulti removes multiple fields from hashmap
func (r *Service) DeleteMulti(key string, fields ...string) error {
c := r.pool.Get()
defer c.Close()

// Make list of args for HDEL
args := make([]interface{}, len(fields)+1)
args[0] = r.Config.Prefix + key
for i := range fields {
args[i+1] = fields[i]
}

_, err := c.Do("HDEL", args...)
return err
}

// DeleteAll deletes session hash map
func (r *Service) DeleteAll(key string) error {
c := r.pool.Get()
defer c.Close()

Expand Down