diff --git a/.gitignore b/.gitignore index e1985f0..91aa830 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /_example/_example /_example/cache/ +/.cache/ *.txt diff --git a/diskcache.go b/diskcache.go index 74f73ec..892f026 100644 --- a/diskcache.go +++ b/diskcache.go @@ -13,6 +13,7 @@ package diskcache import ( "bufio" "bytes" + "context" "crypto/sha256" "errors" "fmt" @@ -176,7 +177,7 @@ func (c *Cache) EvictKey(key string) error { // or if the cached response is stale the request will be executed and cached. func (c *Cache) Fetch(key string, p Policy, req *http.Request, force bool) (bool, time.Time, *http.Response, error) { // check stale - stale, mod, err := c.Stale(key, p.TTL) + stale, mod, err := c.Stale(req.Context(), key, p.TTL) if err != nil { return false, time.Time{}, nil, err } @@ -213,7 +214,7 @@ func (c *Cache) Mod(key string) (time.Time, error) { } // Stale returns whether or not the key is stale, based on ttl. -func (c *Cache) Stale(key string, ttl time.Duration) (bool, time.Time, error) { +func (c *Cache) Stale(ctx context.Context, key string, ttl time.Duration) (bool, time.Time, error) { mod, err := c.Mod(key) switch { case err != nil && errors.Is(err, fs.ErrNotExist): @@ -221,6 +222,9 @@ func (c *Cache) Stale(key string, ttl time.Duration) (bool, time.Time, error) { case err != nil: return false, time.Time{}, err } + if d, ok := TTL(ctx); ok { + ttl = d + } return ttl != 0 && time.Now().After(mod.Add(ttl)), mod, nil } @@ -230,7 +234,7 @@ func (c *Cache) Cached(req *http.Request) (bool, error) { if err != nil { return false, err } - stale, _, err := c.Stale(key, p.TTL) + stale, _, err := c.Stale(req.Context(), key, p.TTL) if err != nil { return false, err } @@ -347,3 +351,22 @@ func UserCacheDir(paths ...string) (string, error) { } return filepath.Join(append([]string{dir}, paths...)...), nil } + +// contextKey is a context key. +type contextKey string + +// context keys. +const ( + ttlKey contextKey = "ttl" +) + +// WithContextTTL adds the ttl to the context. +func WithContextTTL(parent context.Context, ttl time.Duration) context.Context { + return context.WithValue(parent, ttlKey, ttl) +} + +// TTL returns the ttl from the context. +func TTL(ctx context.Context) (time.Duration, bool) { + ttl, ok := ctx.Value(ttlKey).(time.Duration) + return ttl, ok +} diff --git a/diskcache_test.go b/diskcache_test.go new file mode 100644 index 0000000..661340a --- /dev/null +++ b/diskcache_test.go @@ -0,0 +1,97 @@ +package diskcache + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "sync/atomic" + "testing" + "time" +) + +func TestWithContextTTL(t *testing.T) { + // set up simple test server for demonstration + var count uint64 + s := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + fmt.Fprintf(res, "%d\n", atomic.AddUint64(&count, 1)) + })) + defer s.Close() + baseDir := setupDir(t, "test-with-context-ttl") + // create disk cache + c, err := New( + WithBasePathFs(baseDir), + WithErrorTruncator(), + WithTTL(365*24*time.Hour), + ) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + cl := &http.Client{ + Transport: c, + } + ctx := context.Background() + for i := 0; i < 3; i++ { + v, err := doReq(ctx, cl, s.URL) + switch { + case err != nil: + t.Fatalf("expected no error, got: %v", err) + case v != 1: + t.Errorf("expected %d, got: %d", 1, v) + } + } + if count != 1 { + t.Fatalf("expected count == %d, got: %d", 1, count) + } + for i := 1; i < 5; i++ { + v, err := doReq(WithContextTTL(ctx, 1*time.Millisecond), cl, s.URL) + switch { + case err != nil: + t.Fatalf("expected no error, got: %v", err) + case v != i+1: + t.Errorf("expected %d, got: %d", i+1, v) + } + <-time.After(2 * time.Millisecond) + } +} + +func doReq(ctx context.Context, cl *http.Client, urlstr string) (int, error) { + req, err := http.NewRequestWithContext(ctx, "GET", urlstr, nil) + if err != nil { + return -1, err + } + res, err := cl.Do(req) + if err != nil { + return -1, err + } + defer res.Body.Close() + buf, err := io.ReadAll(res.Body) + if err != nil { + return -1, err + } + return strconv.Atoi(string(bytes.TrimSpace(buf))) +} + +func setupDir(t *testing.T, name string) string { + t.Helper() + wd, err := os.Getwd() + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + dir := filepath.Join(wd, ".cache", name) + switch err := os.RemoveAll(dir); { + case errors.Is(err, os.ErrNotExist): + case err != nil: + t.Fatalf("expected no error, got: %v", err) + } + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + return dir +}