-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.go
194 lines (180 loc) · 4.79 KB
/
dataloader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
package dataloader
import (
"context"
"sync"
"time"
)
// Config for a generic dataloader.
type Config[T any] struct {
	// Fetch sets the function for fetching data. It receives every key
	// accumulated in a batch; result i is matched to keys[i].
	Fetch func(ctx context.Context, keys []string) ([]T, error)
	// Wait sets the duration to wait before fetching data.
	Wait time.Duration
	// MaxBatch sets the max batch size when fetching data.
	// Zero means no limit on batch size.
	MaxBatch int
	// Copy sets the function for how to copy values when priming
	// the cache, see Dataloader.Prime.
	// Copy must be specified if you intend to use Dataloader.Prime.
	Copy func(src T) T
}
// Dataloader is a generic dataloader. It groups individual loads into
// batched fetches and caches successful results.
type Dataloader[T any] struct {
	ctx    context.Context // passed to each batch and on to config.Fetch
	config Config[T]
	mu     sync.Mutex           // protects mutable state below
	cache  map[string]T         // successful results by key; lazily allocated in unsafeSet
	batch  *dataloaderBatch[T]  // batch currently accumulating keys; nil when none is open
}
// New creates a new dataloader that fetches values via config.Fetch,
// using ctx for all batched fetches.
func New[T any](
	ctx context.Context,
	config Config[T],
) *Dataloader[T] {
	loader := &Dataloader[T]{}
	loader.ctx = ctx
	loader.config = config
	return loader
}
// dataloaderBatch accumulates the keys for one grouped fetch.
type dataloaderBatch[T any] struct {
	ctx     context.Context
	keys    []string      // keys collected for this batch, in arrival order
	data    []T           // results from Fetch; data[i] corresponds to keys[i]
	err     error         // error from Fetch, shared by every key in the batch
	closing bool          // set once finalizing has begun after hitting MaxBatch
	done    chan struct{} // closed when Fetch has completed (see end)
}
// Load a result by key, batching and caching will be applied automatically.
func (l *Dataloader[T]) Load(key string) (T, error) {
	thunk := l.LoadThunk(key)
	return thunk()
}
// LoadThunk returns a function that when called will block waiting for a result.
// This method should be used if you want one goroutine to make requests to
// different data loaders without blocking until the thunk is called.
func (l *Dataloader[T]) LoadThunk(key string) func() (T, error) {
	l.mu.Lock()
	// Cache hit: return a thunk that resolves immediately.
	if it, ok := l.cache[key]; ok {
		l.mu.Unlock()
		return func() (T, error) {
			return it, nil
		}
	}
	// Lazily open a new batch. It is finalized either by the Wait timer
	// (startTimer) or by reaching MaxBatch (keyIndex).
	if l.batch == nil {
		l.batch = &dataloaderBatch[T]{ctx: l.ctx, done: make(chan struct{})}
	}
	// Capture the batch and this key's position while still holding the
	// lock; l.batch may be swapped for a new batch after we release it.
	batch := l.batch
	pos := batch.keyIndex(l, key)
	l.mu.Unlock()
	return func() (T, error) {
		// Block until the batch fetch has completed.
		<-batch.done
		var data T
		// Guard against Fetch returning fewer results than keys; an
		// out-of-range position yields the zero value of T.
		if pos < len(batch.data) {
			data = batch.data[pos]
		}
		// Cache only on success, so a failed key can be retried later.
		if batch.err == nil {
			l.mu.Lock()
			l.unsafeSet(key, data)
			l.mu.Unlock()
		}
		return data, batch.err
	}
}
// LoadAll fetches many keys at once.
// It will be broken into appropriately sized sub-batches based on how the dataloader is configured.
func (l *Dataloader[T]) LoadAll(keys []string) ([]T, error) {
	// Register every key first so they can share batches, then resolve.
	thunks := make([]func() (T, error), len(keys))
	for i, k := range keys {
		thunks[i] = l.LoadThunk(k)
	}
	out := make([]T, len(keys))
	for i, thunk := range thunks {
		v, err := thunk()
		if err != nil {
			return nil, err
		}
		out[i] = v
	}
	return out, nil
}
// LoadAllThunk returns a function that when called will block waiting for results.
// This method should be used if you want one goroutine to make requests to many
// different data loaders without blocking until the thunk is called.
func (l *Dataloader[T]) LoadAllThunk(keys []string) func() ([]T, error) {
	// Register every key up front; resolution is deferred to the thunk.
	thunks := make([]func() (T, error), len(keys))
	for i, k := range keys {
		thunks[i] = l.LoadThunk(k)
	}
	return func() ([]T, error) {
		out := make([]T, len(keys))
		for i, thunk := range thunks {
			v, err := thunk()
			if err != nil {
				return nil, err
			}
			out[i] = v
		}
		return out, nil
	}
}
// Prime the cache with the provided key and value.
// If the key already exists, no change is made and false is returned.
// Calling Prime without specifying Copy in Config will panic.
func (l *Dataloader[T]) Prime(key string, value T) bool {
	if l.config.Copy == nil {
		panic("Copy must be specified in dataloader.Config before calling Prime.")
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	if _, exists := l.cache[key]; exists {
		return false
	}
	// make a copy when writing to the cache, it's easy to pass a pointer in from a loop var
	// and end up with the whole cache pointing to the same value.
	l.unsafeSet(key, l.config.Copy(value))
	return true
}
// unsafeSet stores value under key, allocating the cache map on first
// use. Caller must hold l.mu.
func (l *Dataloader[T]) unsafeSet(key string, value T) {
	if l.cache == nil {
		l.cache = make(map[string]T)
	}
	l.cache[key] = value
}
// keyIndex will return the location of the key in the batch, if its not found
// it will add the key to the batch. Caller must hold l.mu.
func (b *dataloaderBatch[T]) keyIndex(l *Dataloader[T], key string) int {
	// Deduplicate: reuse the slot of an already-requested key.
	for i, k := range b.keys {
		if k == key {
			return i
		}
	}
	pos := len(b.keys)
	b.keys = append(b.keys, key)
	// First key in the batch arms the Wait timer.
	if pos == 0 {
		go b.startTimer(l)
	}
	// Batch is full: detach it from the loader and fetch asynchronously.
	if max := l.config.MaxBatch; max != 0 && pos >= max-1 && !b.closing {
		b.closing = true
		l.batch = nil
		go b.end(l)
	}
	return pos
}
// startTimer finalizes the batch once the configured Wait duration has
// elapsed, or immediately if the batch context is canceled first —
// resolving the TODO that the sleep did not respect the context.
// If the batch was already closed by hitting MaxBatch (see keyIndex),
// it does nothing.
func (b *dataloaderBatch[T]) startTimer(l *Dataloader[T]) {
	timer := time.NewTimer(l.config.Wait)
	defer timer.Stop()
	select {
	case <-timer.C:
		// Normal path: the wait window has elapsed.
	case <-b.ctx.Done():
		// Context canceled: finalize now so waiting thunks are unblocked.
		// Fetch receives the canceled context and can react accordingly.
	}
	l.mu.Lock()
	// we must have hit a batch limit and are already finalizing this batch
	if b.closing {
		l.mu.Unlock()
		return
	}
	l.batch = nil
	l.mu.Unlock()
	b.end(l)
}
// end runs the batch fetch and then signals every waiting thunk by
// closing the done channel.
func (b *dataloaderBatch[T]) end(l *Dataloader[T]) {
	defer close(b.done)
	b.data, b.err = l.config.Fetch(b.ctx, b.keys)
}