Skip to content

Commit

Permalink
feat: db cache overhaul (#1017)
Browse files Browse the repository at this point in the history
  • Loading branch information
garrettladley authored Jun 12, 2024
1 parent 53cc070 commit 3698fab
Show file tree
Hide file tree
Showing 40 changed files with 654 additions and 298 deletions.
4 changes: 4 additions & 0 deletions backend/background/jobs/welcome_sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"time"

"github.com/GenerateNU/sac/backend/background"
"github.com/GenerateNU/sac/backend/database/cache"

"github.com/GenerateNU/sac/backend/constants"
"github.com/GenerateNU/sac/backend/entities/models"
"gorm.io/gorm"
Expand Down Expand Up @@ -84,6 +86,8 @@ func (j *Jobs) WelcomeSender(ctx context.Context) background.JobFunc {
}

func (j *Jobs) dequeueWelcomeTask(tx *gorm.DB) (*models.WelcomeTask, error) {
tx = cache.SetUseCache(tx, true)

var task models.WelcomeTask
if err := tx.Raw("SELECT email, name, attempts FROM welcome_tasks FOR UPDATE SKIP LOCKED LIMIT 1").Scan(&task).Error; err != nil {
tx.Rollback()
Expand Down
2 changes: 1 addition & 1 deletion backend/database/cache/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Cache Acknowledgement

## Forked code from @evangwt's [grc package](https://github.com/evangwt/grc) to fit into our internal project structure
## Forked code from @go-gorm's [caches package](https://github.com/go-gorm/caches/tree/master) to fit into our internal project structure
288 changes: 143 additions & 145 deletions backend/database/cache/cache.go
Original file line number Diff line number Diff line change
@@ -1,201 +1,199 @@
package cache

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"log/slog"
"sync"
"time"

"github.com/GenerateNU/sac/backend/config"
go_json "github.com/goccy/go-json"

"gorm.io/gorm/callbacks"

"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)

var (
useCacheKey struct{}
cacheTTLKey struct{}
)

// GormCache is a cache plugin for gorm
type GormCache struct {
name string
client CacheClient
config CacheConfig
type Config struct {
Easer bool
Cacher Cacher
TTL time.Duration
}

// CacheClient is an interface for cache operations
type CacheClient interface {
Get(ctx context.Context, key string) (interface{}, error)
Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error
type Caches struct {
callbacks map[queryType]func(db *gorm.DB)
Conf *Config

queue *sync.Map
}

// CacheConfig is a struct for cache options
type CacheConfig struct {
TTL time.Duration // cache expiration time
Prefix string // cache key prefix
func (c *Caches) Name() string {
return "gorm:caches"
}

// NewGormCache returns a new GormCache instance
func NewGormCache(name string, client CacheClient, config CacheConfig) *GormCache {
return &GormCache{
name: name,
client: client,
config: config,
func (c *Caches) Initialize(db *gorm.DB) error {
if c.Conf == nil {
c.Conf = &Config{
Easer: false,
Cacher: nil,
}
}
}

// Name returns the plugin name
func (g *GormCache) Name() string {
return g.name
}
if c.Conf.Easer {
c.queue = &sync.Map{}
}

// Initialize initializes the plugin
func (g *GormCache) Initialize(db *gorm.DB) error {
return db.Callback().Query().Replace("gorm:query", g.queryCallback)
}
callbacks := make(map[queryType]func(db *gorm.DB), 4)
callbacks[uponQuery] = db.Callback().Query().Get("gorm:query")
callbacks[uponCreate] = db.Callback().Create().Get("gorm:query")
callbacks[uponUpdate] = db.Callback().Update().Get("gorm:query")
callbacks[uponDelete] = db.Callback().Delete().Get("gorm:query")
c.callbacks = callbacks

// queryCallback is a callback function for query operations
func (g *GormCache) queryCallback(db *gorm.DB) {
if db.Error != nil {
return
if err := db.Callback().Query().Replace("gorm:query", c.query); err != nil {
return err
}

enableCache := g.enableCache(db)
if err := db.Callback().Create().Replace("gorm:query", c.getMutatorCb(uponCreate)); err != nil {
return err
}

// build query sql
callbacks.BuildQuerySQL(db)
if db.DryRun || db.Error != nil {
return
if err := db.Callback().Update().Replace("gorm:query", c.getMutatorCb(uponUpdate)); err != nil {
return err
}

var (
key string
err error
hit bool
)
if enableCache {
key = g.cacheKey(db)
if err := db.Callback().Delete().Replace("gorm:query", c.getMutatorCb(uponDelete)); err != nil {
return err
}

// get value from cache
hit, err = g.loadCache(db, key)
if err != nil {
slog.Error("load cache failed", "error", err, "hit", hit)
return
}
return nil
}

// hit cache
if hit {
return
}
func (c *Caches) query(db *gorm.DB) {
useCache, ok := db.Statement.Context.Value(useCacheKey).(bool)
if !ok {
useCache = false
}

if !hit {
g.queryDB(db)

if enableCache {
if err = g.setCache(db, key); err != nil {
slog.Error("set cache failed", "error", err)
}
}
cacheTTL, ok := db.Statement.Context.Value(cacheTTLKey).(time.Duration)
if !ok {
cacheTTL = c.Conf.TTL
}
}

func (g *GormCache) enableCache(db *gorm.DB) bool {
ctx := db.Statement.Context
if !useCache || (!c.Conf.Easer && c.Conf.Cacher == nil) {
c.callbacks[uponQuery](db)
return
}

// check if use cache
useCache, ok := ctx.Value(useCacheKey).(bool)
if !ok || !useCache {
return false // do not use cache, skip this callback
if !c.Conf.Easer && c.Conf.Cacher == nil {
c.callbacks[uponQuery](db)
return
}
return true
}

func (g *GormCache) cacheKey(db *gorm.DB) string {
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
hash := sha256.Sum256([]byte(sql))
key := g.config.Prefix + hex.EncodeToString(hash[:])
return key
}
identifier := buildIdentifier(db)

func (g *GormCache) loadCache(db *gorm.DB, key string) (bool, error) {
value, err := g.client.Get(db.Statement.Context, key)
if err != nil && !errors.Is(err, redis.Nil) {
return false, err
if c.checkCache(db, identifier) {
return
}

if value == nil {
return false, nil
c.ease(db, identifier)
if db.Error != nil {
return
}

// cache hit, scan value to destination
if err = go_json.Unmarshal(value.([]byte), &db.Statement.Dest); err != nil {
return false, err
c.storeInCache(db, identifier, cacheTTL)
if db.Error != nil {
return
}
db.RowsAffected = int64(db.Statement.ReflectValue.Len())
return true, nil
}

func (g *GormCache) setCache(db *gorm.DB, key string) error {
ctx := db.Statement.Context
// getMutatorCb returns a decorator which calls the Cacher's Invalidate method
func (c *Caches) getMutatorCb(typ queryType) func(db *gorm.DB) {
return func(db *gorm.DB) {
if c.Conf.Cacher != nil {
if err := c.Conf.Cacher.Invalidate(db.Statement.Context); err != nil {
_ = db.AddError(err)
}
}
if cb := c.callbacks[typ]; cb != nil { // By default, gorm has no callbacks associated with mutating behaviors
cb(db)
}
}
}

// get cache ttl from context or config
ttl, ok := ctx.Value(cacheTTLKey).(time.Duration)
if !ok {
ttl = g.config.TTL // use default ttl
func (c *Caches) ease(db *gorm.DB, identifier string) {
if !c.Conf.Easer {
c.callbacks[uponQuery](db)
return
}

// set value to cache with ttl
return g.client.Set(ctx, key, db.Statement.Dest, ttl)
}
res := ease(&queryTask{
id: identifier,
db: db,
queryCb: c.callbacks[uponQuery],
}, c.queue).(*queryTask)

func (g *GormCache) queryDB(db *gorm.DB) {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
if err := db.AddError(err); err != nil {
slog.Error("error encountered while adding error", "error", err)
}
if db.Error != nil {
return
}
defer func() {
if err := db.AddError(rows.Close()); err != nil {
slog.Error("error encountered while closing rows", "error", err)
}
}()
gorm.Scan(rows, db, 0)
}

type RedisClient struct {
client *redis.Client
}
if res.db.Statement.Dest == db.Statement.Dest {
return
}

// NewRedisClient returns a new RedisClient instance
func NewRedisClient(settings config.RedisSettings) *RedisClient {
return &RedisClient{
client: settings.Into(),
detachedQuery := &Query[any]{
Dest: db.Statement.Dest,
RowsAffected: db.Statement.RowsAffected,
}

easedQuery := &Query[any]{
Dest: res.db.Statement.Dest,
RowsAffected: res.db.Statement.RowsAffected,
}
if err := easedQuery.copyTo(detachedQuery); err != nil {
_ = db.AddError(err)
}

detachedQuery.replaceOn(db)
}

// Get gets value from redis by key using json encoding/decoding
func (r *RedisClient) Get(ctx context.Context, key string) (interface{}, error) {
data, err := r.client.Get(ctx, key).Bytes()
if err != nil {
return nil, err
func (c *Caches) checkCache(db *gorm.DB, identifier string) bool {
if c.Conf.Cacher != nil {
res, err := c.Conf.Cacher.Get(db.Statement.Context, identifier, &Query[any]{
Dest: db.Statement.Dest,
RowsAffected: db.Statement.RowsAffected,
})
if err != nil {
_ = db.AddError(err)
}

if res != nil {
res.replaceOn(db)
return true
}
}
return data, nil
return false
}

// Set sets value to redis by key with ttl using json encoding/decoding
func (r *RedisClient) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
data, err := go_json.Marshal(value) // encode value to json bytes using json encoding/decoding
if err != nil {
return err
func (c *Caches) storeInCache(db *gorm.DB, identifier string, ttl time.Duration) {
if c.Conf.Cacher != nil {
err := c.Conf.Cacher.Store(db.Statement.Context, identifier, &Query[any]{
Dest: db.Statement.Dest,
RowsAffected: db.Statement.RowsAffected,
},
ttl,
)
if err != nil {
_ = db.AddError(err)
}
}
return r.client.Set(ctx, key, data, ttl).Err()
}

type key byte

const (
useCacheKey key = 0
cacheTTLKey key = 1
)

type queryType int

const (
uponQuery queryType = iota
uponCreate
uponUpdate
uponDelete
)
18 changes: 18 additions & 0 deletions backend/database/cache/cahcer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package cache

import (
"context"
"time"
)

type Cacher interface {
// Get impl should check if a specific key exists in the cache and return its value
// look at Query.Marshal
Get(ctx context.Context, key string, q *Query[any]) (*Query[any], error)
// Store impl should store a cached representation of the val param
// look at Query.Unmarshal
Store(ctx context.Context, key string, val *Query[any], ttl time.Duration) error
// Invalidate impl should invalidate all cached values
// It will be called when INSERT / UPDATE / DELETE queries are sent to the DB
Invalidate(ctx context.Context) error
}
Loading

0 comments on commit 3698fab

Please sign in to comment.