Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: make use of contexts in more places #1261

Merged
merged 3 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions api/api_interface_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ type ListRefresher interface {
}

type Querier interface {
Query(question string, qType dns.Type) (*model.Response, error)
Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error)
}

type CacheControl interface {
FlushCaches()
FlushCaches(ctx context.Context)
}

func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
Expand Down Expand Up @@ -137,13 +137,13 @@ func (i *OpenAPIInterfaceImpl) ListRefresh(_ context.Context,
return ListRefresh200Response{}, nil
}

func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObject) (QueryResponseObject, error) {
func (i *OpenAPIInterfaceImpl) Query(ctx context.Context, request QueryRequestObject) (QueryResponseObject, error) {
qType := dns.Type(dns.StringToType[request.Body.Type])
if qType == dns.Type(dns.TypeNone) {
return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil
}

resp, err := i.querier.Query(dns.Fqdn(request.Body.Query), qType)
resp, err := i.querier.Query(ctx, dns.Fqdn(request.Body.Query), qType)
if err != nil {
return nil, err
}
Expand All @@ -156,10 +156,10 @@ func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObje
}), nil
}

func (i *OpenAPIInterfaceImpl) CacheFlush(_ context.Context,
func (i *OpenAPIInterfaceImpl) CacheFlush(ctx context.Context,
_ CacheFlushRequestObject,
) (CacheFlushResponseObject, error) {
i.cacheControl.FlushCaches()
i.cacheControl.FlushCaches(ctx)

return CacheFlush200Response{}, nil
}
39 changes: 22 additions & 17 deletions api/api_interface_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"time"

// . "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
Expand Down Expand Up @@ -54,14 +53,14 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus {
return args.Get(0).(BlockingStatus)
}

func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, error) {
args := m.Called(question, qType)
func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
args := m.Called(ctx, question, qType)

return args.Get(0).(*model.Response), args.Error(1)
}

func (m *CacheControlMock) FlushCaches() {
_ = m.Called()
func (m *CacheControlMock) FlushCaches(ctx context.Context) {
_ = m.Called(ctx)
}

var _ = Describe("API implementation tests", func() {
Expand All @@ -71,9 +70,15 @@ var _ = Describe("API implementation tests", func() {
listRefreshMock *ListRefreshMock
cacheControlMock *CacheControlMock
sut *OpenAPIInterfaceImpl

ctx context.Context
cancelFn context.CancelFunc
)

BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)

blockingControlMock = &BlockingControlMock{}
querierMock = &QuerierMock{}
listRefreshMock = &ListRefreshMock{}
Expand All @@ -95,12 +100,12 @@ var _ = Describe("API implementation tests", func() {
)
Expect(err).Should(Succeed())

querierMock.On("Query", "google.com.", A).Return(&model.Response{
querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{
Res: queryResponse,
Reason: "reason",
}, nil)

resp, err := sut.Query(context.Background(), QueryRequestObject{
resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Query: "google.com", Type: "A",
},
Expand All @@ -116,7 +121,7 @@ var _ = Describe("API implementation tests", func() {
})

It("should return 400 on wrong parameter", func() {
resp, err := sut.Query(context.Background(), QueryRequestObject{
resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Query: "google.com",
Type: "WRONGTYPE",
Expand All @@ -135,7 +140,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() {
listRefreshMock.On("RefreshLists").Return(nil)

resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed())
var resp200 ListRefresh200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand All @@ -144,7 +149,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 500 on failure", func() {
listRefreshMock.On("RefreshLists").Return(errors.New("failed"))

resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed())
var resp500 ListRefresh500TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp500))
Expand All @@ -160,7 +165,7 @@ var _ = Describe("API implementation tests", func() {
duration := "3s"
grroups := "gr1,gr2"

resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{
Duration: &duration,
Groups: &grroups,
Expand All @@ -173,7 +178,7 @@ var _ = Describe("API implementation tests", func() {

It("should return 400 on failure", func() {
blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed"))
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{})
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{})
Expect(err).Should(Succeed())
var resp400 DisableBlocking400TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp400))
Expand All @@ -182,7 +187,7 @@ var _ = Describe("API implementation tests", func() {

It("should return 400 on wrong duration parameter", func() {
wrongDuration := "4sds"
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{
Duration: &wrongDuration,
},
Expand All @@ -197,7 +202,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() {
blockingControlMock.On("EnableBlocking").Return()

resp, err := sut.EnableBlocking(context.Background(), EnableBlockingRequestObject{})
resp, err := sut.EnableBlocking(ctx, EnableBlockingRequestObject{})
Expect(err).Should(Succeed())
var resp200 EnableBlocking200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand All @@ -212,7 +217,7 @@ var _ = Describe("API implementation tests", func() {
AutoEnableInSec: 47,
})

resp, err := sut.BlockingStatus(context.Background(), BlockingStatusRequestObject{})
resp, err := sut.BlockingStatus(ctx, BlockingStatusRequestObject{})
Expect(err).Should(Succeed())
var resp200 BlockingStatus200JSONResponse
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand All @@ -227,8 +232,8 @@ var _ = Describe("API implementation tests", func() {
Describe("Cache API", func() {
When("Cache flush is called", func() {
It("should return 200 on success", func() {
cacheControlMock.On("FlushCaches").Return()
resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{})
cacheControlMock.On("FlushCaches", ctx).Return()
resp, err := sut.CacheFlush(ctx, CacheFlushRequestObject{})
Expect(err).Should(Succeed())
var resp200 CacheFlush200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand Down
6 changes: 3 additions & 3 deletions cache/expirationcache/expiration_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type Options struct {
// OnExpirationCallback will be called just before an element gets expired and will
// be removed from cache. This function can return new value and TTL to leave the
// element in the cache or nil to remove it
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
type OnExpirationCallback[T any] func(ctx context.Context, key string) (val *T, ttl time.Duration)

// OnCacheHitCallback will be called on cache get if entry was found
type OnCacheHitCallback func(key string)
Expand All @@ -58,7 +58,7 @@ func NewCacheWithOnExpired[T any](ctx context.Context, options Options,
l, _ := lru.New(defaultSize)
c := &ExpiringLRUCache[T]{
cleanUpInterval: defaultCleanUpInterval,
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
preExpirationFn: func(ctx context.Context, key string) (val *T, ttl time.Duration) {
return nil, 0
},
onCacheHit: func(key string) {},
Expand Down Expand Up @@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() {
var keysToDelete []string

for _, key := range expiredKeys {
newVal, newTTL := e.preExpirationFn(key)
newVal, newTTL := e.preExpirationFn(context.Background(), key)
if newVal != nil {
e.Put(key, newVal, newTTL)
} else {
Expand Down
6 changes: 3 additions & 3 deletions cache/expirationcache/expiration_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ var _ = Describe("Expiration cache", func() {
Describe("preExpiration function", func() {
When("function is defined", func() {
It("should update the value and TTL if function returns values", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "v2"

return &v2, time.Second
Expand All @@ -169,7 +169,7 @@ var _ = Describe("Expiration cache", func() {
})

It("should update the value and TTL if function returns values on cleanup if element is expired", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "val2"

return &v2, time.Second
Expand All @@ -192,7 +192,7 @@ var _ = Describe("Expiration cache", func() {
})

It("should delete the key if function returns nil", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
return nil, 0
}
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)
Expand Down
10 changes: 6 additions & 4 deletions cache/expirationcache/prefetching_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ type cacheValue[T any] struct {
type OnEntryReloadedCallback func(key string)

// ReloadEntryFn reloads a prefetched entry by key
type ReloadEntryFn[T any] func(key string) (*T, time.Duration)
type ReloadEntryFn[T any] func(ctx context.Context, key string) (*T, time.Duration)

type PrefetchingOptions[T any] struct {
Options
ReloadFn func(cacheKey string) (*T, time.Duration)
ReloadFn ReloadEntryFn[T]
PrefetchThreshold int
PrefetchExpires time.Duration
PrefetchMaxItemsCount int
Expand Down Expand Up @@ -70,9 +70,11 @@ func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
}

func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
func (e *PrefetchingExpiringLRUCache[T]) onExpired(
ctx context.Context, cacheKey string,
) (val *cacheValue[T], ttl time.Duration) {
if e.shouldPrefetch(cacheKey) {
loadedVal, ttl := e.reloadFn(cacheKey)
loadedVal, ttl := e.reloadFn(ctx, cacheKey)
if loadedVal != nil {
if e.onPrefetchEntryReloaded != nil {
e.onPrefetchEntryReloaded(cacheKey)
Expand Down
8 changes: 4 additions & 4 deletions cache/expirationcache/prefetching_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand Down Expand Up @@ -86,7 +86,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand All @@ -113,7 +113,7 @@ var _ = Describe("Prefetching expiration cache", func() {
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand Down Expand Up @@ -143,7 +143,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand Down
Loading