Skip to content

Commit

Permalink
chore(pkg/middlewares): use settings for cache and filter constructor…
Browse files Browse the repository at this point in the history
…s (future proofing for compatibility)
  • Loading branch information
qdm12 committed Nov 15, 2023
1 parent e07c215 commit 38380c0
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 14 deletions.
6 changes: 5 additions & 1 deletion examples/doh-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ func main() {
if err != nil {
log.Fatal(err)
}
cacheMiddleware := cachemiddleware.New(cache)

cacheMiddleware, err := cachemiddleware.New(cachemiddleware.Settings{Cache: cache})
if err != nil {
log.Fatal(err)
}

server, err := doh.NewServer(doh.ServerSettings{
Middlewares: []doh.Middleware{cacheMiddleware},
Expand Down
6 changes: 5 additions & 1 deletion examples/dot-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ func main() {
if err != nil {
log.Fatal(err)
}
cacheMiddleware := cachemiddleware.New(cache)

cacheMiddleware, err := cachemiddleware.New(cachemiddleware.Settings{Cache: cache})
if err != nil {
log.Fatal(err)
}

server, err := dot.NewServer(dot.ServerSettings{
Middlewares: []dot.Middleware{cacheMiddleware},
Expand Down
13 changes: 11 additions & 2 deletions internal/setup/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,19 @@ func DNS(userSettings settings.Settings, //nolint:ireturn
server Service, err error) {
var middlewares []Middleware

middlewares = append(middlewares, cachemiddleware.New(cache))
cacheMiddleware, err := cachemiddleware.New(cachemiddleware.Settings{Cache: cache})
if err != nil {
return nil, fmt.Errorf("creating cache middleware: %w", err)
}
middlewares = append(middlewares, cacheMiddleware)

filterMiddleware, err := filtermiddleware.New(filtermiddleware.Settings{Filter: filter})
if err != nil {
return nil, fmt.Errorf("creating filter middleware: %w", err)
}
// Note the filter middleware must be wrapping the cache middleware
// to catch filtered responses found from the cache.
middlewares = append(middlewares, filtermiddleware.New(filter))
middlewares = append(middlewares, filterMiddleware)

commonPrometheus := prometheus.Settings{
Prefix: *userSettings.Metrics.Prometheus.Subsystem,
Expand Down
6 changes: 4 additions & 2 deletions pkg/doh/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ func Test_Server_Mocks(t *testing.T) {
cache.EXPECT().Add(
mockhelp.NewMatcherRequest(expectedRequestAAAA),
mockhelp.NewMatcherResponse(expectedResponseAAAA))
cacheMiddleware := cachemiddleware.New(cache)
cacheMiddleware, err := cachemiddleware.New(cachemiddleware.Settings{Cache: cache})
require.NoError(t, err)

filter := NewMockfilter(ctrl)
filter.EXPECT().
Expand All @@ -234,7 +235,8 @@ func Test_Server_Mocks(t *testing.T) {
filter.EXPECT().
FilterResponse(mockhelp.NewMatcherResponse(expectedResponseAAAA)).
Return(false)
filterMiddleware := filtermiddleware.New(filter)
filterMiddleware, err := filtermiddleware.New(filtermiddleware.Settings{Filter: filter})
require.NoError(t, err)

logger := NewMockLogger(ctrl)
logger.EXPECT().Info(mockhelp.NewMatcherRegex("DNS server listening on .*:[1-9][0-9]{0,4}"))
Expand Down
6 changes: 4 additions & 2 deletions pkg/dot/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ func Test_Server_Mocks(t *testing.T) {
cache.EXPECT().Add(
mockhelp.NewMatcherRequest(expectedRequestAAAA),
mockhelp.NewMatcherResponse(expectedResponseAAAA))
cacheMiddleware := cachemiddleware.New(cache)
cacheMiddleware, err := cachemiddleware.New(cachemiddleware.Settings{Cache: cache})
require.NoError(t, err)

filter := NewMockfilter(ctrl)
filter.EXPECT().
Expand All @@ -230,7 +231,8 @@ func Test_Server_Mocks(t *testing.T) {
filter.EXPECT().
FilterResponse(mockhelp.NewMatcherResponse(expectedResponseAAAA)).
Return(false)
filterMiddleware := filtermiddleware.New(filter)
filterMiddleware, err := filtermiddleware.New(filtermiddleware.Settings{Filter: filter})
require.NoError(t, err)

logger := NewMockLogger(ctrl)
logger.EXPECT().Info(mockhelp.NewMatcherRegex("DNS server listening on .*:[1-9][0-9]{0,4}"))
Expand Down
13 changes: 10 additions & 3 deletions pkg/middlewares/cache/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cache

import (
"fmt"

"github.com/miekg/dns"
"github.com/qdm12/dns/v2/internal/stateful"
)
Expand All @@ -9,10 +11,15 @@ type Middleware struct {
cache Cache
}

func New(cache Cache) *Middleware {
return &Middleware{
cache: cache,
func New(settings Settings) (middleware *Middleware, err error) {
err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

return &Middleware{
cache: settings.Cache,
}, nil
}

func (m *Middleware) Wrap(next dns.Handler) dns.Handler { //nolint:ireturn
Expand Down
22 changes: 22 additions & 0 deletions pkg/middlewares/cache/settings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package cache

import (
"errors"
"fmt"
)

type Settings struct {
Cache Cache
}

var (
ErrCacheMustBeSet = errors.New("cache must be set")
)

func (s *Settings) Validate() (err error) {
if s.Cache == nil {
return fmt.Errorf("%w", ErrCacheMustBeSet)
}

return nil
}
13 changes: 10 additions & 3 deletions pkg/middlewares/filter/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package filter

import (
"fmt"

"github.com/miekg/dns"
"github.com/qdm12/dns/v2/internal/stateful"
)
Expand All @@ -9,10 +11,15 @@ type Middleware struct {
filter Filter
}

func New(filter Filter) *Middleware {
return &Middleware{
filter: filter,
func New(settings Settings) (middleware *Middleware, err error) {
err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

return &Middleware{
filter: settings.Filter,
}, nil
}

func (m *Middleware) Wrap(next dns.Handler) dns.Handler { //nolint:ireturn
Expand Down
22 changes: 22 additions & 0 deletions pkg/middlewares/filter/settings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package filter

import (
"errors"
"fmt"
)

type Settings struct {
Filter Filter
}

var (
ErrFilterMustBeSet = errors.New("filter must be set")
)

func (s *Settings) Validate() (err error) {
if s.Filter == nil {
return fmt.Errorf("%w", ErrFilterMustBeSet)
}

return nil
}

0 comments on commit 38380c0

Please sign in to comment.