diff --git a/cmd/swat4master/container/container.go b/cmd/swat4master/container/container.go index 0aeaeac..0c482f5 100644 --- a/cmd/swat4master/container/container.go +++ b/cmd/swat4master/container/container.go @@ -5,7 +5,6 @@ import ( "github.com/sergeii/swat4master/cmd/swat4master/config" "github.com/sergeii/swat4master/internal/core/usecases/addserver" - "github.com/sergeii/swat4master/internal/core/usecases/cleanservers" "github.com/sergeii/swat4master/internal/core/usecases/getserver" "github.com/sergeii/swat4master/internal/core/usecases/listservers" "github.com/sergeii/swat4master/internal/core/usecases/probeserver" @@ -44,7 +43,6 @@ func NewUseCaseConfigs(cfg config.Config) UseCaseConfigs { type Container struct { AddServer addserver.UseCase - CleanServers cleanservers.UseCase GetServer getserver.UseCase ListServers listservers.UseCase ProbeServer probeserver.UseCase @@ -57,7 +55,6 @@ type Container struct { func NewContainer( addServerUseCase addserver.UseCase, - cleanServersUseCase cleanservers.UseCase, getServerUseCase getserver.UseCase, listServersUseCase listservers.UseCase, probeServerUseCase probeserver.UseCase, @@ -69,7 +66,6 @@ func NewContainer( ) Container { return Container{ AddServer: addServerUseCase, - CleanServers: cleanServersUseCase, GetServer: getServerUseCase, ListServers: listServersUseCase, ProbeServer: probeServerUseCase, @@ -88,7 +84,6 @@ var Module = fx.Module("container", fx.Provide(reportserver.New), fx.Provide(renewserver.New), fx.Provide(removeserver.New), - fx.Provide(cleanservers.New), fx.Provide(refreshservers.New), fx.Provide(reviveservers.New), fx.Provide(probeserver.New), diff --git a/cmd/swat4master/modules/cleaner/cleaner.go b/cmd/swat4master/modules/cleaner/cleaner.go index cf61ec8..83f7a86 100644 --- a/cmd/swat4master/modules/cleaner/cleaner.go +++ b/cmd/swat4master/modules/cleaner/cleaner.go @@ -2,15 +2,15 @@ package cleaner import ( "context" - "time" "github.com/jonboulle/clockwork" "github.com/rs/zerolog" "go.uber.org/fx" "github.com/sergeii/swat4master/cmd/swat4master/config" - "github.com/sergeii/swat4master/internal/core/usecases/cleanservers" - "github.com/sergeii/swat4master/internal/metrics" + "github.com/sergeii/swat4master/internal/cleanup" + "github.com/sergeii/swat4master/internal/cleanup/cleaners/instancecleaner" + "github.com/sergeii/swat4master/internal/cleanup/cleaners/servercleaner" ) type Cleaner struct{} @@ -20,8 +20,7 @@ func Run( stopped chan struct{}, clock clockwork.Clock, logger *zerolog.Logger, - metrics *metrics.Collector, - uc cleanservers.UseCase, + manager *cleanup.Manager, cfg config.Config, ) { ticker := clock.NewTicker(cfg.CleanInterval) @@ -42,7 +41,7 @@ func Run( close(stopped) return case <-tickerCh: - clean(ctx, clock, logger, metrics, uc, cfg.CleanRetention) + manager.Clean(ctx) } } } @@ -51,8 +50,7 @@ func NewCleaner( lc fx.Lifecycle, cfg config.Config, clock clockwork.Clock, - metrics *metrics.Collector, - uc cleanservers.UseCase, + manager *cleanup.Manager, logger *zerolog.Logger, ) *Cleaner { stopped := make(chan struct{}) @@ -60,7 +58,7 @@ func NewCleaner( lc.Append(fx.Hook{ OnStart: func(context.Context) error { - go Run(stop, stopped, clock, logger, metrics, uc, cfg) // nolint: contextcheck + go Run(stop, stopped, clock, logger, manager, cfg) // nolint: contextcheck return nil }, OnStop: func(context.Context) error { @@ -73,24 +71,30 @@ func NewCleaner( return &Cleaner{} } -func clean( - ctx context.Context, - clock clockwork.Clock, - logger *zerolog.Logger, - metrics *metrics.Collector, - uc cleanservers.UseCase, - retention time.Duration, -) { - resp, err := uc.Execute(ctx, clock.Now().Add(-retention)) - if err != nil { - logger.Error(). - Err(err). - Msg("Failed to clean outdated servers") +type Opts struct { + fx.Out + + ServerCleanerOpts servercleaner.Opts + InstanceCleanerOpts instancecleaner.Opts +} + +func provideCleanerConfigs(cfg config.Config) Opts { + return Opts{ + ServerCleanerOpts: servercleaner.Opts{ + Retention: cfg.CleanRetention, + }, + InstanceCleanerOpts: instancecleaner.Opts{ + Retention: cfg.CleanRetention, + }, } - metrics.CleanerRemovals.Add(float64(resp.Count)) - metrics.CleanerErrors.Add(float64(resp.Errors)) } var Module = fx.Module("cleaner", + fx.Provide(cleanup.NewManager), + fx.Provide(provideCleanerConfigs), + fx.Invoke( + servercleaner.New, + instancecleaner.New, + ), fx.Provide(NewCleaner), ) diff --git a/internal/cleanup/cleaner.go b/internal/cleanup/cleaner.go new file mode 100644 index 0000000..bc92583 --- /dev/null +++ b/internal/cleanup/cleaner.go @@ -0,0 +1,31 @@ +package cleanup + +import ( + "context" + "sync" +) + +type Cleaner interface { + Clean(ctx context.Context) +} + +type Manager struct { + mutex sync.Mutex + cleaners []Cleaner +} + +func NewManager() *Manager { + return &Manager{} +} + +func (m *Manager) AddCleaner(c Cleaner) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.cleaners = append(m.cleaners, c) +} + +func (m *Manager) Clean(ctx context.Context) { + for _, c := range m.cleaners { + go c.Clean(ctx) + } +} diff --git a/internal/cleanup/cleaners/instancecleaner/instancecleaner.go b/internal/cleanup/cleaners/instancecleaner/instancecleaner.go new file mode 100644 index 0000000..3400a1e --- /dev/null +++ b/internal/cleanup/cleaners/instancecleaner/instancecleaner.go @@ -0,0 +1,63 @@ +package instancecleaner + +import ( + "context" + "time" + + "github.com/jonboulle/clockwork" + "github.com/rs/zerolog" + + "github.com/sergeii/swat4master/internal/cleanup" + "github.com/sergeii/swat4master/internal/core/entities/filterset" + "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/metrics" +) + +type Opts struct { + Retention time.Duration +} + +type InstanceCleaner struct { + opts Opts + instanceRepo repositories.InstanceRepository + clock clockwork.Clock + metrics *metrics.Collector + logger *zerolog.Logger +} + +func New( + manager *cleanup.Manager, + opts Opts, + instanceRepo repositories.InstanceRepository, + clock clockwork.Clock, + metrics *metrics.Collector, + logger *zerolog.Logger, +) InstanceCleaner { + cleaner := InstanceCleaner{ + opts: opts, + instanceRepo: instanceRepo, + clock: clock, + metrics: metrics, + logger: logger, + } + manager.AddCleaner(&cleaner) + return cleaner +} + +func (c InstanceCleaner) Clean(ctx context.Context) { + // Calculate the cutoff time for cleaning instances. + cleanUntil := c.clock.Now().Add(-c.opts.Retention) + fs := filterset.NewInstanceFilterSet().UpdatedBefore(cleanUntil) + + c.logger.Info().Stringer("until", cleanUntil).Msg("Starting to clean instances") + + count, err := c.instanceRepo.Clear(ctx, fs) + if err != nil { + c.metrics.CleanerErrors.WithLabelValues("instances").Inc() + c.logger.Error().Err(err).Stringer("until", cleanUntil).Msg("Failed to clean instances") + return + } + + c.metrics.CleanerRemovals.WithLabelValues("instances").Add(float64(count)) + c.logger.Info().Stringer("until", cleanUntil).Int("removed", count).Msg("Finished cleaning instances") +} diff --git a/internal/cleanup/cleaners/instancecleaner/instancecleaner_test.go b/internal/cleanup/cleaners/instancecleaner/instancecleaner_test.go new file mode 100644 index 0000000..adb19a7 --- /dev/null +++ b/internal/cleanup/cleaners/instancecleaner/instancecleaner_test.go @@ -0,0 +1,111 @@ +package instancecleaner_test + +import ( + "context" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/sergeii/swat4master/internal/cleanup" + "github.com/sergeii/swat4master/internal/cleanup/cleaners/instancecleaner" + "github.com/sergeii/swat4master/internal/core/entities/filterset" + "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/metrics" +) + +type MockInstanceRepository struct { + mock.Mock + repositories.InstanceRepository +} + +func (m *MockInstanceRepository) Clear(ctx context.Context, fs filterset.InstanceFilterSet) (int, error) { + args := m.Called(ctx, fs) + return args.Get(0).(int), args.Error(1) // nolint: forcetypeassert +} + +func TestInstanceCleaner_Clean_OK(t *testing.T) { + ctx := context.TODO() + + manager := cleanup.NewManager() + collector := metrics.New() + clock := clockwork.NewFakeClock() + logger := zerolog.Nop() + options := instancecleaner.Opts{ + Retention: time.Hour, + } + + instanceRepo := new(MockInstanceRepository) + instanceRepo.On("Clear", ctx, mock.Anything).Return(37, nil) + + cleaner := instancecleaner.New( + manager, + options, + instanceRepo, + clock, + collector, + &logger, + ) + cleaner.Clean(ctx) + + instanceRepo.AssertCalled( + t, + "Clear", + ctx, + mock.MatchedBy(func(fs filterset.InstanceFilterSet) bool { + updatedBefore, ok := fs.GetUpdatedBefore() + wantUpdatedBefore := ok && updatedBefore.Equal(clock.Now().Add(-time.Hour)) + return wantUpdatedBefore + }), + ) + + cleanerRemovalsWithInstancesValue := testutil.ToFloat64(collector.CleanerRemovals.WithLabelValues("instances")) + assert.Equal(t, float64(37), cleanerRemovalsWithInstancesValue) + cleanerErrorsWithInstancesValue := testutil.ToFloat64(collector.CleanerErrors.WithLabelValues("instances")) + assert.Equal(t, float64(0), cleanerErrorsWithInstancesValue) +} + +func TestInstanceCleaner_Clean_RepoError(t *testing.T) { + ctx := context.TODO() + + manager := cleanup.NewManager() + collector := metrics.New() + clock := clockwork.NewFakeClock() + logger := zerolog.Nop() + options := instancecleaner.Opts{ + Retention: time.Hour, + } + + instanceRepo := new(MockInstanceRepository) + instanceRepo.On("Clear", ctx, mock.Anything).Return(0, assert.AnError) + + cleaner := instancecleaner.New( + manager, + options, + instanceRepo, + clock, + collector, + &logger, + ) + cleaner.Clean(ctx) + + instanceRepo.AssertCalled( + t, + "Clear", + ctx, + mock.MatchedBy(func(fs filterset.InstanceFilterSet) bool { + updatedBefore, ok := fs.GetUpdatedBefore() + wantUpdatedBefore := ok && updatedBefore.Equal(clock.Now().Add(-time.Hour)) + return wantUpdatedBefore + }), + ) + + cleanerRemovalsWithInstancesValue := testutil.ToFloat64(collector.CleanerRemovals.WithLabelValues("instances")) + assert.Equal(t, float64(0), cleanerRemovalsWithInstancesValue) + cleanerErrorsWithInstancesValue := testutil.ToFloat64(collector.CleanerErrors.WithLabelValues("instances")) + assert.Equal(t, float64(1), cleanerErrorsWithInstancesValue) +} diff --git a/internal/cleanup/cleaners/servercleaner/servercleaner.go b/internal/cleanup/cleaners/servercleaner/servercleaner.go new file mode 100644 index 0000000..bdd4096 --- /dev/null +++ b/internal/cleanup/cleaners/servercleaner/servercleaner.go @@ -0,0 +1,97 @@ +package servercleaner + +import ( + "context" + "time" + + "github.com/jonboulle/clockwork" + "github.com/rs/zerolog" + + "github.com/sergeii/swat4master/internal/cleanup" + "github.com/sergeii/swat4master/internal/core/entities/filterset" + "github.com/sergeii/swat4master/internal/core/entities/server" + "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/metrics" +) + +type Opts struct { + Retention time.Duration +} + +type ServerCleaner struct { + opts Opts + serverRepo repositories.ServerRepository + clock clockwork.Clock + metrics *metrics.Collector + logger *zerolog.Logger +} + +func New( + manager *cleanup.Manager, + opts Opts, + serverRepo repositories.ServerRepository, + clock clockwork.Clock, + metrics *metrics.Collector, + logger *zerolog.Logger, +) ServerCleaner { + cleaner := ServerCleaner{ + opts: opts, + serverRepo: serverRepo, + clock: clock, + metrics: metrics, + logger: logger, + } + manager.AddCleaner(&cleaner) + return cleaner +} + +func (c ServerCleaner) Clean(ctx context.Context) { + cleanUntil := c.clock.Now().Add(-c.opts.Retention) + fs := filterset.NewServerFilterSet().UpdatedBefore(cleanUntil) + + c.logger.Info().Stringer("until", cleanUntil).Msg("Starting to clean outdated servers") + + outdatedServers, err := c.serverRepo.Filter(ctx, fs) + if err != nil { + c.logger.Error().Err(err).Msg("Unable to obtain servers for cleanup") + return + } + + removed, errors := c.cleanServers(ctx, outdatedServers, cleanUntil) + + c.metrics.CleanerRemovals.WithLabelValues("servers").Add(float64(removed)) + c.metrics.CleanerErrors.WithLabelValues("servers").Add(float64(errors)) + c.logger.Info(). + Stringer("until", cleanUntil). + Int("removed", removed).Int("errors", errors). + Msg("Finished cleaning servers") +} + +func (c ServerCleaner) cleanServers( + ctx context.Context, + servers []server.Server, + cleanUntil time.Time, +) (int, int) { + removed, errors := 0, 0 + for _, svr := range servers { + err := c.serverRepo.Remove(ctx, svr, func(conflict *server.Server) bool { + if conflict.RefreshedAt.After(cleanUntil) { + c.logger.Info(). + Stringer("server", conflict).Stringer("refreshed", conflict.RefreshedAt). + Msg("Removed server is more recent") + return false + } + return true + }) + if err != nil { + c.logger.Error(). + Err(err). + Stringer("until", cleanUntil).Stringer("addr", svr.Addr). + Msg("Failed to remove outdated server") + errors++ + continue + } + removed++ + } + return removed, errors +} diff --git a/internal/cleanup/cleaners/servercleaner/servercleaner_test.go b/internal/cleanup/cleaners/servercleaner/servercleaner_test.go new file mode 100644 index 0000000..1de1360 --- /dev/null +++ b/internal/cleanup/cleaners/servercleaner/servercleaner_test.go @@ -0,0 +1,196 @@ +package servercleaner_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/sergeii/swat4master/internal/cleanup" + "github.com/sergeii/swat4master/internal/cleanup/cleaners/servercleaner" + "github.com/sergeii/swat4master/internal/core/entities/filterset" + "github.com/sergeii/swat4master/internal/core/entities/server" + "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/metrics" + "github.com/sergeii/swat4master/internal/testutils/factories/serverfactory" +) + +type MockServerRepository struct { + mock.Mock + repositories.ServerRepository +} + +func (m *MockServerRepository) Filter(ctx context.Context, fs filterset.ServerFilterSet) ([]server.Server, error) { + args := m.Called(ctx, fs) + return args.Get(0).([]server.Server), args.Error(1) // nolint: forcetypeassert +} + +func (m *MockServerRepository) Remove( + ctx context.Context, + svr server.Server, + onConflict func(*server.Server) bool, +) error { + args := m.Called(ctx, svr, onConflict) + return args.Error(0) +} + +// func TestCleanServersUseCase_RemoveErrors(t *testing.T) { +// ctx := context.TODO() +// logger := zerolog.Nop() +// +// until := time.Now().Add(-24 * time.Hour) // Example time filter +// +// svr1 := serverfactory.BuildRandom() +// svr2 := serverfactory.BuildRandom() +// svr3 := serverfactory.BuildRandom() +// outdatedServers := []server.Server{svr1, svr2, svr3} +// +// serverRepo := new(MockServerRepository) +// serverRepo.On("Count", ctx).Return(3, nil).Once() +// serverRepo.On("Count", ctx).Return(2, nil).Once() +// serverRepo.On("Filter", ctx, mock.Anything).Return(outdatedServers, nil).Once() +// serverRepo.On("Remove", ctx, svr1, mock.Anything).Return(nil).Once() +// serverRepo.On("Remove", ctx, svr2, mock.Anything).Return(nil).Once() +// serverRepo.On("Remove", ctx, svr3, mock.Anything).Return(errors.New("error")).Once() +// +// uc := cleanservers.New(serverRepo, &logger) +// response, err := uc.Execute(ctx, until) +// +// assert.NoError(t, err) +// assert.Equal(t, 2, response.Count) +// assert.Equal(t, 1, response.Errors) +// +// serverRepo.AssertExpectations(t) +// serverRepo.AssertNumberOfCalls(t, "Remove", 3) +// } + +func TestServerCleaner_Clean_OK(t *testing.T) { + ctx := context.TODO() + + manager := cleanup.NewManager() + collector := metrics.New() + clock := clockwork.NewFakeClock() + logger := zerolog.Nop() + options := servercleaner.Opts{ + Retention: time.Hour * 24, + } + + outdatedServers := []server.Server{ + serverfactory.BuildRandom(), + serverfactory.BuildRandom(), + } + + serverRepo := new(MockServerRepository) + serverRepo.On("Filter", ctx, mock.Anything).Return(outdatedServers, nil).Once() + serverRepo.On("Remove", ctx, mock.Anything, mock.Anything).Return(nil).Times(2) + + cleaner := servercleaner.New( + manager, + options, + serverRepo, + clock, + collector, + &logger, + ) + cleaner.Clean(ctx) + + serverRepo.AssertExpectations(t) + serverRepo.AssertCalled( + t, + "Filter", + ctx, + mock.MatchedBy(func(fs filterset.ServerFilterSet) bool { + updatedBefore, ok := fs.GetUpdatedBefore() + wantUpdatedBefore := ok && updatedBefore.Equal(clock.Now().Add(-time.Hour*24)) + return wantUpdatedBefore + }), + ) + for _, svr := range outdatedServers { + serverRepo.AssertCalled(t, "Remove", ctx, svr, mock.Anything) + } + + cleanerRemovalsWithServersValue := testutil.ToFloat64(collector.CleanerRemovals.WithLabelValues("servers")) + assert.Equal(t, float64(2), cleanerRemovalsWithServersValue) + cleanerErrorsWithServersValue := testutil.ToFloat64(collector.CleanerErrors.WithLabelValues("servers")) + assert.Equal(t, float64(0), cleanerErrorsWithServersValue) +} + +func TestServerCleaner_Clean_NothingToClean(t *testing.T) { + ctx := context.TODO() + + manager := cleanup.NewManager() + collector := metrics.New() + clock := clockwork.NewFakeClock() + logger := zerolog.Nop() + options := servercleaner.Opts{ + Retention: time.Hour * 24, + } + + serverRepo := new(MockServerRepository) + serverRepo.On("Filter", ctx, mock.Anything).Return([]server.Server{}, nil).Once() + + cleaner := servercleaner.New( + manager, + options, + serverRepo, + clock, + collector, + &logger, + ) + cleaner.Clean(ctx) + + serverRepo.AssertExpectations(t) + serverRepo.AssertNotCalled(t, "Remove", mock.Anything, mock.Anything) + + cleanerRemovalsWithServersValue := testutil.ToFloat64(collector.CleanerRemovals.WithLabelValues("servers")) + assert.Equal(t, float64(0), cleanerRemovalsWithServersValue) + cleanerErrorsWithServersValue := testutil.ToFloat64(collector.CleanerErrors.WithLabelValues("servers")) + assert.Equal(t, float64(0), cleanerErrorsWithServersValue) +} + +func TestServerCleaner_Clean_RepoErrors(t *testing.T) { + ctx := context.TODO() + + manager := cleanup.NewManager() + collector := metrics.New() + clock := clockwork.NewFakeClock() + logger := zerolog.Nop() + options := servercleaner.Opts{ + Retention: time.Hour * 24, + } + + svr1 := serverfactory.BuildRandom() + svr2 := serverfactory.BuildRandom() + svr3 := serverfactory.BuildRandom() + outdatedServers := []server.Server{svr1, svr2, svr3} + + serverRepo := new(MockServerRepository) + serverRepo.On("Filter", ctx, mock.Anything).Return(outdatedServers, nil).Once() + serverRepo.On("Remove", ctx, svr1, mock.Anything).Return(nil).Once() + serverRepo.On("Remove", ctx, svr2, mock.Anything).Return(nil).Once() + serverRepo.On("Remove", ctx, svr3, mock.Anything).Return(errors.New("error")).Once() + + cleaner := servercleaner.New( + manager, + options, + serverRepo, + clock, + collector, + &logger, + ) + cleaner.Clean(ctx) + + serverRepo.AssertExpectations(t) + serverRepo.AssertNumberOfCalls(t, "Remove", 3) + + cleanerRemovalsWithServersValue := testutil.ToFloat64(collector.CleanerRemovals.WithLabelValues("servers")) + assert.Equal(t, float64(2), cleanerRemovalsWithServersValue) + cleanerErrorsWithServersValue := testutil.ToFloat64(collector.CleanerErrors.WithLabelValues("servers")) + assert.Equal(t, float64(1), cleanerErrorsWithServersValue) +} diff --git a/internal/core/entities/instance/instance.go b/internal/core/entities/instance/instance.go index ead4478..c31a1e4 100644 --- a/internal/core/entities/instance/instance.go +++ b/internal/core/entities/instance/instance.go @@ -1,19 +1,43 @@ package instance import ( + "fmt" "net" "github.com/sergeii/swat4master/internal/core/entities/addr" ) +type Identifier [4]byte + +func NewID(bytes []byte) (Identifier, error) { + if len(bytes) != 4 { + return Identifier{}, fmt.Errorf("instance ID must be 4 bytes long, got %d", len(bytes)) + } + var id Identifier + copy(id[:], bytes) + return id, nil +} + +func MustNewID(bytes []byte) Identifier { + id, err := NewID(bytes) + if err != nil { + panic(err) + } + return id +} + +func (id Identifier) Hex() string { + return fmt.Sprintf("%x", id) +} + type Instance struct { - ID string + ID Identifier Addr addr.Addr } var Blank Instance // nolint: gochecknoglobals -func New(id string, ip net.IP, port int) (Instance, error) { +func New(id Identifier, ip net.IP, port int) (Instance, error) { insAddr, err := addr.New(ip, port) if err != nil { return Blank, err @@ -21,7 +45,7 @@ func New(id string, ip net.IP, port int) (Instance, error) { return Instance{id, insAddr}, nil } -func MustNew(id string, ip net.IP, port int) Instance { +func MustNew(id Identifier, ip net.IP, port int) Instance { ins, err := New(id, ip, port) if err != nil { panic(err) diff --git a/internal/core/entities/instance/instance_test.go b/internal/core/entities/instance/instance_test.go index 5bcfd7a..a3077da 100644 --- a/internal/core/entities/instance/instance_test.go +++ b/internal/core/entities/instance/instance_test.go @@ -5,15 +5,81 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/sergeii/swat4master/internal/core/entities/instance" ) -func TestInstance_New(t *testing.T) { - ins, err := instance.New("foo", net.ParseIP("2.2.2.2"), 10480) +func TestIdentifier_New_OK(t *testing.T) { + id, err := instance.NewID([]byte("test")) assert.NoError(t, err) - assert.Equal(t, "foo", ins.ID) + assert.Equal(t, "test", string(id[:])) +} + +func TestIdentifier_New_Errors(t *testing.T) { + tests := []struct { + name string + id []byte + }{ + { + name: "empty", + id: []byte{}, + }, + { + name: "short", + id: []byte("123"), + }, + { + name: "long", + id: []byte("12345"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := instance.NewID(tt.id) + assert.ErrorContains(t, err, "instance ID must be 4 bytes long") + }) + } +} + +func TestInstance_New_OK(t *testing.T) { + id := instance.MustNewID([]byte("test")) + ins, err := instance.New(id, net.ParseIP("2.2.2.2"), 10480) + assert.NoError(t, err) + assert.Equal(t, id, ins.ID) assert.Equal(t, "2.2.2.2", ins.Addr.GetDottedIP()) assert.Equal(t, 10480, ins.Addr.Port) assert.Equal(t, "2.2.2.2:10480", ins.Addr.String()) } + +func TestInstance_New_Errors(t *testing.T) { + tests := []struct { + name string + id instance.Identifier + ip string + port int + wantErrMsg string + }{ + { + name: "invalid ip", + id: instance.MustNewID([]byte("test")), + ip: "256.256.256.256", + port: 10480, + wantErrMsg: "invalid IP address", + }, + { + name: "invalid port", + id: instance.MustNewID([]byte("test")), + ip: "1.1.1.1", + port: 0, + wantErrMsg: "invalid port number", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := instance.New(tt.id, net.ParseIP(tt.ip), tt.port) + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErrMsg) + }) + } +} diff --git a/internal/core/repositories/instance.go b/internal/core/repositories/instance.go index 5bf6f7c..96f2fb3 100644 --- a/internal/core/repositories/instance.go +++ b/internal/core/repositories/instance.go @@ -12,8 +12,8 @@ var ErrInstanceNotFound = errors.New("the requested instance was not found") type InstanceRepository interface { Add(context.Context, instance.Instance) error - Get(context.Context, string) (instance.Instance, error) - Remove(context.Context, string) error + Get(context.Context, instance.Identifier) (instance.Instance, error) + Remove(context.Context, instance.Identifier) error Clear(context.Context, filterset.InstanceFilterSet) (int, error) Count(context.Context) (int, error) } diff --git a/internal/core/usecases/cleanservers/cleanservers.go b/internal/core/usecases/cleanservers/cleanservers.go deleted file mode 100644 index 32bde8a..0000000 --- a/internal/core/usecases/cleanservers/cleanservers.go +++ /dev/null @@ -1,87 +0,0 @@ -package cleanservers - -import ( - "context" - "time" - - "github.com/rs/zerolog" - - "github.com/sergeii/swat4master/internal/core/entities/filterset" - "github.com/sergeii/swat4master/internal/core/entities/server" - "github.com/sergeii/swat4master/internal/core/repositories" -) - -type UseCase struct { - serverRepo repositories.ServerRepository - logger *zerolog.Logger -} - -func New( - serverRepo repositories.ServerRepository, - logger *zerolog.Logger, -) UseCase { - return UseCase{ - serverRepo: serverRepo, - logger: logger, - } -} - -type Response struct { - Count int - Errors int -} - -var NoResponse = Response{} - -func (uc UseCase) Execute(ctx context.Context, until time.Time) (Response, error) { - var before, after, removed, errors int - var err error - - if before, err = uc.serverRepo.Count(ctx); err != nil { - return NoResponse, err - } - - uc.logger.Info(). - Stringer("until", until).Int("servers", before). - Msg("Starting to clean outdated servers") - - fs := filterset.NewServerFilterSet().UpdatedBefore(until) - outdatedServers, err := uc.serverRepo.Filter(ctx, fs) - if err != nil { - uc.logger.Error().Err(err).Msg("Unable to obtain servers for cleanup") - return NoResponse, err - } - - for _, svr := range outdatedServers { - err := uc.serverRepo.Remove(ctx, svr, func(conflict *server.Server) bool { - if conflict.RefreshedAt.After(until) { - uc.logger.Info(). - Stringer("server", conflict).Stringer("refreshed", conflict.RefreshedAt). - Msg("Removed server is more recent") - return false - } - return true - }) - if err != nil { - uc.logger.Error(). - Err(err). - Stringer("until", until).Stringer("addr", svr.Addr). - Msg("Failed to remove outdated server") - errors++ - continue - } - removed++ - } - - if after, err = uc.serverRepo.Count(ctx); err != nil { - return NoResponse, err - } - - uc.logger.Info(). - Stringer("until", until). - Int("removed", removed).Int("errors", errors). - Int("before", before).Int("after", after). - Msg("Finished cleaning outdated servers") - - return Response{Count: removed, Errors: errors}, nil -} diff --git a/internal/core/usecases/cleanservers/cleanservers_test.go b/internal/core/usecases/cleanservers/cleanservers_test.go deleted file mode 100644 index 983b7eb..0000000 --- a/internal/core/usecases/cleanservers/cleanservers_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package cleanservers_test - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/rs/zerolog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/sergeii/swat4master/internal/core/entities/filterset" - "github.com/sergeii/swat4master/internal/core/entities/server" - "github.com/sergeii/swat4master/internal/core/repositories" - "github.com/sergeii/swat4master/internal/core/usecases/cleanservers" - "github.com/sergeii/swat4master/internal/testutils/factories/serverfactory" -) - -type MockServerRepository struct { - mock.Mock - repositories.ServerRepository -} - -func (m *MockServerRepository) Count(ctx context.Context) (int, error) { - args := m.Called(ctx) - return args.Int(0), args.Error(1) -} - -func (m *MockServerRepository) Filter( - ctx context.Context, - fs filterset.ServerFilterSet, -) ([]server.Server, error) { - args := m.Called(ctx, fs) - return args.Get(0).([]server.Server), args.Error(1) // nolint: forcetypeassert -} - -func (m *MockServerRepository) Remove( - ctx context.Context, - svr server.Server, - onConflict func(*server.Server) bool, -) error { - args := m.Called(ctx, svr, onConflict) - return args.Error(0) -} - -func TestCleanServersUseCase_Success(t *testing.T) { - ctx := context.TODO() - logger := zerolog.Nop() - - until := time.Now().Add(-24 * time.Hour) // Example time filter - - outdatedServers := []server.Server{ - serverfactory.BuildRandom(), - serverfactory.BuildRandom(), - } - - serverRepo := new(MockServerRepository) - serverRepo.On("Count", ctx).Return(10, nil).Once() - serverRepo.On("Count", ctx).Return(8, nil).Once() - serverRepo.On("Filter", ctx, mock.Anything).Return(outdatedServers, nil).Once() - serverRepo.On("Remove", ctx, mock.Anything, mock.Anything).Return(nil).Times(2) - - uc := cleanservers.New(serverRepo, &logger) - response, err := uc.Execute(ctx, until) - - assert.NoError(t, err) - assert.Equal(t, 2, response.Count) - assert.Equal(t, 0, response.Errors) - - serverRepo.AssertExpectations(t) - serverRepo.AssertCalled( - t, - "Filter", - ctx, - mock.MatchedBy(func(fs filterset.ServerFilterSet) bool { - updatedBefore, _ := fs.GetUpdatedBefore() - return updatedBefore.Equal(until) - }), - ) - for _, svr := range outdatedServers { - serverRepo.AssertCalled(t, "Remove", ctx, svr, mock.Anything) - } -} - -func TestCleanServersUseCase_NothingToClean(t *testing.T) { - ctx := context.TODO() - logger := zerolog.Nop() - - until := time.Now().Add(-24 * time.Hour) // Example time filter - - serverRepo := new(MockServerRepository) - serverRepo.On("Count", ctx).Return(0, nil).Times(2) - serverRepo.On("Filter", ctx, mock.Anything).Return([]server.Server{}, nil).Once() - - uc := cleanservers.New(serverRepo, &logger) - response, err := uc.Execute(ctx, until) - - assert.NoError(t, err) - assert.Equal(t, 0, response.Count) - assert.Equal(t, 0, response.Errors) - - serverRepo.AssertExpectations(t) - serverRepo.AssertNotCalled(t, "Remove", mock.Anything, mock.Anything, mock.Anything) -} - -func TestCleanServersUseCase_RemoveErrors(t *testing.T) { - ctx := context.TODO() - logger := zerolog.Nop() - - until := time.Now().Add(-24 * time.Hour) // Example time filter - - svr1 := serverfactory.BuildRandom() - svr2 := serverfactory.BuildRandom() - svr3 := serverfactory.BuildRandom() - outdatedServers := []server.Server{svr1, svr2, svr3} - - serverRepo := new(MockServerRepository) - serverRepo.On("Count", ctx).Return(3, nil).Once() - serverRepo.On("Count", ctx).Return(2, nil).Once() - serverRepo.On("Filter", ctx, mock.Anything).Return(outdatedServers, nil).Once() - serverRepo.On("Remove", ctx, svr1, mock.Anything).Return(nil).Once() - serverRepo.On("Remove", ctx, svr2, mock.Anything).Return(nil).Once() - serverRepo.On("Remove", ctx, svr3, mock.Anything).Return(errors.New("error")).Once() - - uc := cleanservers.New(serverRepo, &logger) - response, err := uc.Execute(ctx, until) - - assert.NoError(t, err) - assert.Equal(t, 2, response.Count) - assert.Equal(t, 1, response.Errors) - - serverRepo.AssertExpectations(t) - serverRepo.AssertNumberOfCalls(t, "Remove", 3) -} - -func TestCleanServersUseCase_CountError(t *testing.T) { - ctx := context.TODO() - logger := zerolog.Nop() - - until := time.Now().Add(-24 * time.Hour) // Example time filter - countErr := errors.New("error") - - serverRepo := new(MockServerRepository) - serverRepo.On("Count", ctx).Return(0, countErr).Once() - - uc := cleanservers.New(serverRepo, &logger) - response, err := uc.Execute(ctx, until) - - assert.ErrorIs(t, err, countErr) - assert.Equal(t, cleanservers.NoResponse, response) - - serverRepo.AssertExpectations(t) - serverRepo.AssertNumberOfCalls(t, "Filter", 0) - serverRepo.AssertNumberOfCalls(t, "Remove", 0) -} diff --git a/internal/core/usecases/removeserver/removeserver.go b/internal/core/usecases/removeserver/removeserver.go index 195ded2..2d44372 100644 --- a/internal/core/usecases/removeserver/removeserver.go +++ b/internal/core/usecases/removeserver/removeserver.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog" "github.com/sergeii/swat4master/internal/core/entities/addr" + "github.com/sergeii/swat4master/internal/core/entities/instance" "github.com/sergeii/swat4master/internal/core/entities/server" "github.com/sergeii/swat4master/internal/core/repositories" ) @@ -37,11 +38,11 @@ func New( } type Request struct { - instanceID string + instanceID []byte svrAddr addr.Addr } -func NewRequest(instanceID string, svrAddr addr.Addr) Request { +func NewRequest(instanceID []byte, svrAddr addr.Addr) Request { return Request{ instanceID: instanceID, svrAddr: svrAddr, @@ -49,37 +50,18 @@ func NewRequest(instanceID string, svrAddr addr.Addr) Request { } func (uc UseCase) Execute(ctx context.Context, req Request) error { - svr, err := uc.serverRepo.Get(ctx, req.svrAddr) + uc.logger.Info(). + Stringer("addr", req.svrAddr).Str("instance", fmt.Sprintf("% x", req.instanceID)). + Msg("Removing server on request") + + svr, err := uc.getServer(ctx, req.svrAddr) if err != nil { - switch { - case errors.Is(err, repositories.ErrServerNotFound): - uc.logger.Info(). - Stringer("addr", req.svrAddr).Str("instance", fmt.Sprintf("% x", req.instanceID)). - Msg("Removed server not found") - return ErrServerNotFound - default: - return err - } + return err } - inst, err := uc.instanceRepo.Get(ctx, req.instanceID) + inst, err := uc.getInstance(ctx, req.instanceID, svr.Addr) if err != nil { - switch { - // this could be a race condition - ignore - case errors.Is(err, repositories.ErrInstanceNotFound): - uc.logger.Info(). - Stringer("addr", req.svrAddr). - Stringer("server", svr). - Str("instance", fmt.Sprintf("% x", req.instanceID)). - Msg("Instance for removed server not found") - return ErrInstanceNotFound - default: - return err - } - } - // make sure to verify the "owner" of the provided instance id - if inst.Addr.GetDottedIP() != svr.Addr.GetDottedIP() { - return ErrInstanceAddrMismatch + return err } if err = uc.serverRepo.Remove(ctx, svr, func(_ *server.Server) bool { @@ -87,13 +69,55 @@ func (uc UseCase) Execute(ctx context.Context, req Request) error { }); err != nil { return err } - if err = uc.instanceRepo.Remove(ctx, req.instanceID); err != nil { + if err = uc.instanceRepo.Remove(ctx, inst.ID); err != nil { return err } uc.logger.Info(). Stringer("addr", req.svrAddr).Str("instance", fmt.Sprintf("% x", req.instanceID)). - Msg("Successfully removed server on request") + Msg("Removed server on request") return nil } + +func (uc UseCase) getServer(ctx context.Context, svrAddr addr.Addr) (server.Server, error) { + svr, err := uc.serverRepo.Get(ctx, svrAddr) + if err != nil { + if errors.Is(err, repositories.ErrServerNotFound) { + uc.logger.Info().Stringer("addr", svrAddr).Msg("Removed server not found") + return server.Blank, ErrServerNotFound + } + return server.Blank, err + } + return svr, nil +} + +func (uc UseCase) getInstance( + ctx context.Context, + instanceID []byte, + svrAddr addr.Addr, +) (instance.Instance, error) { + instID, err := instance.NewID(instanceID) + if err != nil { + return instance.Blank, err + } + + inst, err := uc.instanceRepo.Get(ctx, instID) + if err != nil { + // this could be a race condition - ignore + if errors.Is(err, repositories.ErrInstanceNotFound) { + uc.logger.Info(). + Str("instance", fmt.Sprintf("% x", instanceID)). + Msg("Instance for removed server not found") + return instance.Blank, ErrInstanceNotFound + } + return instance.Blank, err + } + + // make sure to verify the "owner" of the provided instance id + if inst.Addr.GetDottedIP() != svrAddr.GetDottedIP() { + return instance.Blank, ErrInstanceAddrMismatch + } + + return inst, nil +} diff --git a/internal/core/usecases/removeserver/removeserver_test.go b/internal/core/usecases/removeserver/removeserver_test.go index 0d25a91..11d3c9c 100644 --- a/internal/core/usecases/removeserver/removeserver_test.go +++ b/internal/core/usecases/removeserver/removeserver_test.go @@ -17,6 +17,10 @@ import ( "github.com/sergeii/swat4master/internal/testutils/factories/serverfactory" ) +const DEADBEEF = "\xde\xad\xbe\xef" + +type b []byte + type MockServerRepository struct { mock.Mock repositories.ServerRepository @@ -41,12 +45,12 @@ type MockInstanceRepository struct { repositories.InstanceRepository } -func (m *MockInstanceRepository) Get(ctx context.Context, instanceID string) (instance.Instance, error) { +func (m *MockInstanceRepository) Get(ctx context.Context, instanceID instance.Identifier) (instance.Instance, error) { args := m.Called(ctx, instanceID) return args.Get(0).(instance.Instance), args.Error(1) // nolint: forcetypeassert } -func (m *MockInstanceRepository) Remove(ctx context.Context, instanceID string) error { +func (m *MockInstanceRepository) Remove(ctx context.Context, instanceID instance.Identifier) error { args := m.Called(ctx, instanceID) return args.Error(0) } @@ -56,24 +60,23 @@ func TestRemoveServerUseCase_Success(t *testing.T) { logger := zerolog.Nop() svr := serverfactory.BuildRandom() - inst := instance.MustNew("foo", svr.Addr.GetIP(), svr.Addr.Port) + instID := instance.MustNewID(b(DEADBEEF)) + inst := instance.MustNew(instID, svr.Addr.GetIP(), svr.Addr.Port) serverRepo := new(MockServerRepository) serverRepo.On("Get", ctx, svr.Addr).Return(svr, nil) serverRepo.On("Remove", ctx, svr, mock.Anything).Return(nil) instanceRepo := new(MockInstanceRepository) - instanceRepo.On("Get", ctx, "foo").Return(inst, nil) - instanceRepo.On("Remove", ctx, "foo").Return(nil) + instanceRepo.On("Get", ctx, instID).Return(inst, nil) + instanceRepo.On("Remove", ctx, instID).Return(nil) uc := removeserver.New(serverRepo, instanceRepo, &logger) - err := uc.Execute(ctx, removeserver.NewRequest("foo", svr.Addr)) + err := uc.Execute(ctx, removeserver.NewRequest(b(DEADBEEF), svr.Addr)) assert.NoError(t, err) - serverRepo.AssertCalled(t, "Get", ctx, svr.Addr) - instanceRepo.AssertCalled(t, "Get", ctx, "foo") - serverRepo.AssertCalled(t, "Remove", ctx, svr, mock.Anything) - instanceRepo.AssertCalled(t, "Remove", ctx, "foo") + serverRepo.AssertExpectations(t) + instanceRepo.AssertExpectations(t) } func TestRemoveServerUseCase_ServerAlreadyDeleted(t *testing.T) { @@ -88,7 +91,9 @@ func TestRemoveServerUseCase_ServerAlreadyDeleted(t *testing.T) { instanceRepo := new(MockInstanceRepository) uc := removeserver.New(serverRepo, instanceRepo, &logger) - err := uc.Execute(ctx, removeserver.NewRequest("foo", svr.Addr)) + ucReq := removeserver.NewRequest(b(DEADBEEF), svr.Addr) + + err := uc.Execute(ctx, ucReq) assert.ErrorIs(t, err, removeserver.ErrServerNotFound) serverRepo.AssertCalled(t, "Get", ctx, svr.Addr) @@ -102,20 +107,21 @@ func TestRemoveServerUseCase_InstanceAlreadyDeleted(t *testing.T) { logger := zerolog.Nop() svr := serverfactory.BuildRandom() + instID := instance.MustNewID(b(DEADBEEF)) serverRepo := new(MockServerRepository) serverRepo.On("Get", ctx, svr.Addr).Return(svr, nil) serverRepo.On("Remove", ctx, svr, mock.Anything).Return(nil) instanceRepo := new(MockInstanceRepository) - instanceRepo.On("Get", ctx, "foo").Return(instance.Blank, repositories.ErrInstanceNotFound) + instanceRepo.On("Get", ctx, instID).Return(instance.Blank, repositories.ErrInstanceNotFound) uc := removeserver.New(serverRepo, instanceRepo, &logger) - err := uc.Execute(ctx, removeserver.NewRequest("foo", svr.Addr)) + err := uc.Execute(ctx, removeserver.NewRequest(b(DEADBEEF), svr.Addr)) assert.ErrorIs(t, err, removeserver.ErrInstanceNotFound) serverRepo.AssertCalled(t, "Get", ctx, svr.Addr) - instanceRepo.AssertCalled(t, "Get", ctx, "foo") + instanceRepo.AssertCalled(t, "Get", ctx, instID) serverRepo.AssertNotCalled(t, "Remove", mock.Anything, mock.Anything, mock.Anything) instanceRepo.AssertNotCalled(t, "Remove", mock.Anything, mock.Anything) } @@ -125,22 +131,23 @@ func TestRemoveServerUseCase_InstanceAddrDoesNotMatch(t *testing.T) { logger := zerolog.Nop() svr := serverfactory.BuildRandom() - inst := instance.MustNew("foo", testutils.GenRandomIP(), svr.Addr.Port) + instID := instance.MustNewID(b(DEADBEEF)) + inst := instance.MustNew(instID, testutils.GenRandomIP(), svr.Addr.Port) serverRepo := new(MockServerRepository) serverRepo.On("Get", ctx, svr.Addr).Return(svr, nil) serverRepo.On("Remove", ctx, svr, mock.Anything).Return(nil) instanceRepo := new(MockInstanceRepository) - instanceRepo.On("Get", ctx, "foo").Return(inst, nil) - instanceRepo.On("Remove", ctx, "foo").Return(nil) + instanceRepo.On("Get", ctx, instID).Return(inst, nil) + instanceRepo.On("Remove", ctx, instID).Return(nil) uc := removeserver.New(serverRepo, instanceRepo, &logger) - err := uc.Execute(ctx, removeserver.NewRequest("foo", svr.Addr)) + err := uc.Execute(ctx, removeserver.NewRequest(b(DEADBEEF), svr.Addr)) assert.ErrorIs(t, err, removeserver.ErrInstanceAddrMismatch) serverRepo.AssertCalled(t, "Get", ctx, svr.Addr) - instanceRepo.AssertCalled(t, "Get", ctx, "foo") + instanceRepo.AssertCalled(t, "Get", ctx, instID) serverRepo.AssertNotCalled(t, "Remove", mock.Anything, mock.Anything, mock.Anything) instanceRepo.AssertNotCalled(t, "Remove", mock.Anything, mock.Anything) } diff --git a/internal/core/usecases/renewserver/renewserver.go b/internal/core/usecases/renewserver/renewserver.go index b3dfc59..ffc2b9d 100644 --- a/internal/core/usecases/renewserver/renewserver.go +++ b/internal/core/usecases/renewserver/renewserver.go @@ -7,6 +7,7 @@ import ( "github.com/jonboulle/clockwork" + "github.com/sergeii/swat4master/internal/core/entities/instance" "github.com/sergeii/swat4master/internal/core/entities/server" "github.com/sergeii/swat4master/internal/core/repositories" ) @@ -32,11 +33,11 @@ func New( } type Request struct { - instanceID string + instanceID []byte ipAddr net.IP } -func NewRequest(instanceID string, ipAddr net.IP) Request { +func NewRequest(instanceID []byte, ipAddr net.IP) Request { return Request{ instanceID: instanceID, ipAddr: ipAddr, @@ -44,17 +45,21 @@ func NewRequest(instanceID string, ipAddr net.IP) Request { } func (uc UseCase) Execute(ctx context.Context, req Request) error { - instance, err := uc.instanceRepo.Get(ctx, req.instanceID) + instID, err := instance.NewID(req.instanceID) + if err != nil { + return err + } + inst, err := uc.instanceRepo.Get(ctx, instID) if err != nil { return err } // the addressed must match, otherwise it could be a spoofing attempt - if !instance.Addr.GetIP().Equal(req.ipAddr.To4()) { + if !inst.Addr.GetIP().Equal(req.ipAddr.To4()) { return ErrUnknownInstanceID } - svr, err := uc.serverRepo.Get(ctx, instance.Addr) + svr, err := uc.serverRepo.Get(ctx, inst.Addr) if err != nil { return err } diff --git a/internal/core/usecases/renewserver/renewserver_test.go b/internal/core/usecases/renewserver/renewserver_test.go index e1140f1..dd0c707 100644 --- a/internal/core/usecases/renewserver/renewserver_test.go +++ b/internal/core/usecases/renewserver/renewserver_test.go @@ -18,6 +18,10 @@ import ( "github.com/sergeii/swat4master/internal/testutils/factories/serverfactory" ) +const DEADBEEF = "\xde\xad\xbe\xef" + +type b []byte + type MockServerRepository struct { mock.Mock repositories.ServerRepository @@ -42,7 +46,7 @@ type MockInstanceRepository struct { repositories.InstanceRepository } -func (m *MockInstanceRepository) Get(ctx context.Context, instanceID string) (instance.Instance, error) { +func (m *MockInstanceRepository) Get(ctx context.Context, instanceID instance.Identifier) (instance.Instance, error) { args := m.Called(ctx, instanceID) return args.Get(0).(instance.Instance), args.Error(1) // nolint: forcetypeassert } @@ -52,7 +56,8 @@ func TestRenewServerUseCase_Success(t *testing.T) { clock := clockwork.NewFakeClock() svr := serverfactory.BuildRandom() - inst := instance.MustNew("foo", svr.Addr.GetIP(), svr.Addr.Port) + instID := instance.MustNewID(b(DEADBEEF)) + inst := instance.MustNew(instID, svr.Addr.GetIP(), svr.Addr.Port) clock.Advance(time.Second) passedTime := clock.Now() @@ -65,7 +70,7 @@ func TestRenewServerUseCase_Success(t *testing.T) { instanceRepo.On("Get", ctx, inst.ID).Return(inst, nil) uc := renewserver.New(instanceRepo, serverRepo, clock) - err := uc.Execute(ctx, renewserver.NewRequest(inst.ID, svr.Addr.GetIP())) + err := uc.Execute(ctx, renewserver.NewRequest(b(DEADBEEF), svr.Addr.GetIP())) assert.NoError(t, err) updatedSvr := svr @@ -81,7 +86,8 @@ func TestRenewServerUseCase_InstanceNotFound(t *testing.T) { clock := clockwork.NewFakeClock() svr := serverfactory.BuildRandom() - inst := instance.MustNew("foo", svr.Addr.GetIP(), svr.Addr.Port) + instID := instance.MustNewID(b(DEADBEEF)) + inst := instance.MustNew(instID, svr.Addr.GetIP(), svr.Addr.Port) serverRepo := new(MockServerRepository) serverRepo.On("Get", ctx, svr.Addr).Return(svr, nil) @@ -90,7 +96,7 @@ func TestRenewServerUseCase_InstanceNotFound(t *testing.T) { instanceRepo.On("Get", ctx, inst.ID).Return(instance.Blank, repositories.ErrInstanceNotFound) uc := renewserver.New(instanceRepo, serverRepo, clock) - err := uc.Execute(ctx, renewserver.NewRequest(inst.ID, svr.Addr.GetIP())) + err := uc.Execute(ctx, renewserver.NewRequest(b(DEADBEEF), svr.Addr.GetIP())) assert.ErrorIs(t, err, repositories.ErrInstanceNotFound) instanceRepo.AssertCalled(t, "Get", ctx, inst.ID) @@ -103,7 +109,8 @@ func TestRenewServerUseCase_ServerNotFound(t *testing.T) { clock := clockwork.NewFakeClock() svr := serverfactory.BuildRandom() - inst := instance.MustNew("foo", svr.Addr.GetIP(), svr.Addr.Port) + instID := instance.MustNewID(b(DEADBEEF)) + inst := instance.MustNew(instID, svr.Addr.GetIP(), svr.Addr.Port) serverRepo := new(MockServerRepository) serverRepo.On("Get", ctx, svr.Addr).Return(server.Blank, repositories.ErrServerNotFound) @@ -112,7 +119,7 @@ func TestRenewServerUseCase_ServerNotFound(t *testing.T) { instanceRepo.On("Get", ctx, inst.ID).Return(inst, nil) uc := renewserver.New(instanceRepo, serverRepo, clock) - err := uc.Execute(ctx, renewserver.NewRequest(inst.ID, svr.Addr.GetIP())) + err := uc.Execute(ctx, renewserver.NewRequest(b(DEADBEEF), svr.Addr.GetIP())) assert.ErrorIs(t, err, repositories.ErrServerNotFound) instanceRepo.AssertCalled(t, "Get", ctx, inst.ID) @@ -125,7 +132,8 @@ func TestRenewServerUseCase_InstanceAddressMismatch(t *testing.T) { clock := clockwork.NewFakeClock() svr := serverfactory.BuildRandom() - inst := instance.MustNew("foo", testutils.GenRandomIP(), svr.Addr.Port) + instID := instance.MustNewID(b(DEADBEEF)) + inst := instance.MustNew(instID, testutils.GenRandomIP(), svr.Addr.Port) serverRepo := new(MockServerRepository) serverRepo.On("Get", ctx, svr.Addr).Return(server.Blank, repositories.ErrServerNotFound) @@ -134,7 +142,7 @@ func TestRenewServerUseCase_InstanceAddressMismatch(t *testing.T) { instanceRepo.On("Get", ctx, inst.ID).Return(inst, nil) uc := renewserver.New(instanceRepo, serverRepo, clock) - err := uc.Execute(ctx, renewserver.NewRequest(inst.ID, svr.Addr.GetIP())) + err := uc.Execute(ctx, renewserver.NewRequest(b(DEADBEEF), svr.Addr.GetIP())) assert.ErrorIs(t, err, renewserver.ErrUnknownInstanceID) instanceRepo.AssertCalled(t, "Get", ctx, inst.ID) diff --git a/internal/core/usecases/reportserver/reportserver.go b/internal/core/usecases/reportserver/reportserver.go index 7e8ba50..c486776 100644 --- a/internal/core/usecases/reportserver/reportserver.go +++ b/internal/core/usecases/reportserver/reportserver.go @@ -61,14 +61,14 @@ func New( type Request struct { svrAddr addr.Addr queryPort int - instanceID string + instanceID []byte fields map[string]string } func NewRequest( svrAddr addr.Addr, queryPort int, - instanceID string, + instanceID []byte, fields map[string]string, ) Request { return Request{ @@ -85,12 +85,8 @@ func (uc UseCase) Execute(ctx context.Context, req Request) error { return err } - inst, err := instance.New(req.instanceID, svr.Addr.GetIP(), svr.Addr.Port) + inst, err := uc.prepareInstance(req.instanceID, req.svrAddr) if err != nil { - uc.logger.Error(). - Err(err). - Stringer("addr", req.svrAddr).Str("instance", fmt.Sprintf("% x", req.instanceID)). - Msg("Failed to create an instance") return err } @@ -206,3 +202,21 @@ func (uc UseCase) obtainServerByAddr( } return svr, nil } + +func (uc UseCase) prepareInstance(instanceID []byte, svrAddr addr.Addr) (instance.Instance, error) { + instID, err := instance.NewID(instanceID) + if err != nil { + return instance.Blank, err + } + + inst, err := instance.New(instID, svrAddr.GetIP(), svrAddr.Port) + if err != nil { + uc.logger.Error(). + Err(err). + Stringer("addr", svrAddr).Str("instance", fmt.Sprintf("% x", instanceID)). + Msg("Failed to create an instance") + return instance.Blank, err + } + + return inst, nil +} diff --git a/internal/core/usecases/reportserver/reportserver_test.go b/internal/core/usecases/reportserver/reportserver_test.go index 93e41c5..813ad82 100644 --- a/internal/core/usecases/reportserver/reportserver_test.go +++ b/internal/core/usecases/reportserver/reportserver_test.go @@ -25,6 +25,10 @@ import ( "github.com/sergeii/swat4master/internal/validation" ) +const DEADBEEF = "\xde\xad\xbe\xef" + +type b []byte + type MockServerRepository struct { mock.Mock repositories.ServerRepository @@ -104,7 +108,7 @@ func TestReportServerUseCase_ReportNewServer(t *testing.T) { } uc := reportserver.New(serverRepo, instanceRepo, probeRepo, ucOpts, validate, collector, clock, &logger) - req := reportserver.NewRequest(svrAddr, svrQueryPort, "foo", svrParams) + req := reportserver.NewRequest(svrAddr, svrQueryPort, b(DEADBEEF), svrParams) err := uc.Execute(ctx, req) assert.NoError(t, err) @@ -130,7 +134,7 @@ func TestReportServerUseCase_ReportNewServer(t *testing.T) { ctx, mock.MatchedBy(func(createdInstance instance.Instance) bool { hasAddr := createdInstance.Addr == svrAddr - hasID := createdInstance.ID == "foo" + hasID := createdInstance.ID == instance.MustNewID(b(DEADBEEF)) return hasAddr && hasID }), ) @@ -236,7 +240,7 @@ func TestReportServerUseCase_ReportExistingServer(t *testing.T) { MaxProbeRetries: 3, } uc := reportserver.New(serverRepo, instanceRepo, probeRepo, ucOpts, validate, collector, clock, &logger) - req := reportserver.NewRequest(svr.Addr, svr.QueryPort, "foo", updatedParams) + req := reportserver.NewRequest(svr.Addr, svr.QueryPort, b(DEADBEEF), updatedParams) err := uc.Execute(ctx, req) assert.NoError(t, err) @@ -262,7 +266,7 @@ func TestReportServerUseCase_ReportExistingServer(t *testing.T) { ctx, mock.MatchedBy(func(createdInstance instance.Instance) bool { hasAddr := createdInstance.Addr == svr.Addr - hasID := createdInstance.ID == "foo" + hasID := createdInstance.ID == instance.MustNewID(b(DEADBEEF)) return hasAddr && hasID }), ) @@ -351,7 +355,7 @@ func TestReportServerUseCase_InvalidFields(t *testing.T) { MaxProbeRetries: 3, } uc := reportserver.New(serverRepo, instanceRepo, probeRepo, ucOpts, validate, collector, clock, &logger) - req := reportserver.NewRequest(svrAddr, svrQueryPort, "foo", tt.params) + req := reportserver.NewRequest(svrAddr, svrQueryPort, b(DEADBEEF), tt.params) err := uc.Execute(ctx, req) assert.ErrorIs(t, err, reportserver.ErrInvalidRequestPayload) diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 9c9b6ba..412050a 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -26,8 +26,8 @@ type Collector struct { BrowserSent prometheus.Counter BrowserDurations prometheus.Histogram - CleanerRemovals prometheus.Counter - CleanerErrors prometheus.Counter + CleanerRemovals *prometheus.CounterVec + CleanerErrors *prometheus.CounterVec DiscoveryWorkersBusy prometheus.Gauge DiscoveryWorkersAvailable prometheus.Gauge @@ -105,14 +105,14 @@ func New() *Collector { Name: "browser_duration_seconds", Help: "Duration of server browsing requests", }), - CleanerRemovals: promauto.With(registry).NewCounter(prometheus.CounterOpts{ + CleanerRemovals: promauto.With(registry).NewCounterVec(prometheus.CounterOpts{ Name: "cleaner_removals_total", Help: "The total number of inactive servers removed", - }), - CleanerErrors: promauto.With(registry).NewCounter(prometheus.CounterOpts{ + }, []string{"kind"}), + CleanerErrors: promauto.With(registry).NewCounterVec(prometheus.CounterOpts{ Name: "cleaner_errors_total", Help: "The total number of errors occurred during cleaner runs", - }), + }, []string{"kind"}), DiscoveryWorkersBusy: promauto.With(registry).NewGauge(prometheus.GaugeOpts{ Name: "discovery_busy_workers", Help: "The total number of busy discovery workers", diff --git a/internal/persistence/redis/instances/instances.go b/internal/persistence/redis/instances/instances.go index 478b152..3988260 100644 --- a/internal/persistence/redis/instances/instances.go +++ b/internal/persistence/redis/instances/instances.go @@ -2,10 +2,10 @@ package instances import ( "context" - "encoding/hex" "encoding/json" "errors" "fmt" + "net" "strconv" "github.com/jonboulle/clockwork" @@ -22,6 +22,12 @@ const ( updatesKey = "instances:updated" ) +type storedInstance struct { + ID [4]byte `json:"id"` + IP net.IP `json:"ip"` + Port int `json:"port"` +} + type Repository struct { client *redis.Client clock clockwork.Clock @@ -40,13 +46,12 @@ func (r *Repository) Add(ctx context.Context, ins instance.Instance) error { return err } _, err = r.client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { - hexID := encodeID(ins.ID) // Add or update the instance in the hash set - pipe.HSet(ctx, itemsKey, hexID, item) + pipe.HSet(ctx, itemsKey, ins.ID.Hex(), item) // Update the timestamp in the sorted set pipe.ZAdd(ctx, updatesKey, redis.Z{ Score: float64(r.clock.Now().UnixNano()), - Member: hexID, + Member: ins.ID.Hex(), }) return nil }) @@ -56,8 +61,8 @@ func (r *Repository) Add(ctx context.Context, ins instance.Instance) error { return nil } -func (r *Repository) Get(ctx context.Context, id string) (instance.Instance, error) { - item, err := r.client.HGet(ctx, itemsKey, encodeID(id)).Result() +func (r *Repository) Get(ctx context.Context, id instance.Identifier) (instance.Instance, error) { + item, err := r.client.HGet(ctx, itemsKey, id.Hex()).Result() if err != nil { if errors.Is(err, redis.Nil) { return instance.Blank, repositories.ErrInstanceNotFound @@ -67,11 +72,10 @@ func (r *Repository) Get(ctx context.Context, id string) (instance.Instance, err return decodeInstance(item) } -func (r *Repository) Remove(ctx context.Context, id string) error { +func (r *Repository) Remove(ctx context.Context, id instance.Identifier) error { _, err := r.client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { - hexID := encodeID(id) - pipe.HDel(ctx, itemsKey, hexID) - pipe.ZRem(ctx, updatesKey, hexID) + pipe.HDel(ctx, itemsKey, id.Hex()) + pipe.ZRem(ctx, updatesKey, id.Hex()) return nil }) if err != nil { @@ -125,21 +129,12 @@ func (r *Repository) Count(ctx context.Context) (int, error) { return int(count), nil } -func encodeID(id string) string { - return fmt.Sprintf("%x", id) -} - -func decodeID(hexID string) (string, error) { - id, err := hex.DecodeString(hexID) - if err != nil { - return "", fmt.Errorf("failed to decode hex id: %w", err) - } - return string(id), nil -} - func encodeInstance(ins instance.Instance) ([]byte, error) { - item := newStoredItem(ins.ID, ins.Addr) - encoded, err := json.Marshal(item) + encoded, err := json.Marshal(storedInstance{ + ID: ins.ID, + IP: ins.Addr.GetIP(), + Port: ins.Addr.Port, + }) if err != nil { return nil, fmt.Errorf("failed to marshal instance item: %w", err) } @@ -147,13 +142,13 @@ func encodeInstance(ins instance.Instance) ([]byte, error) { } func decodeInstance(val interface{}) (instance.Instance, error) { - var item storedItem + var decoded storedInstance encoded, ok := val.(string) if !ok { return instance.Blank, fmt.Errorf("unexpected type %T, %v", val, val) } - if err := json.Unmarshal([]byte(encoded), &item); err != nil { + if err := json.Unmarshal([]byte(encoded), &decoded); err != nil { return instance.Blank, fmt.Errorf("failed to unmarshal instance item: %w", err) } - return item.convert() + return instance.New(decoded.ID, decoded.IP, decoded.Port) } diff --git a/internal/persistence/redis/instances/instances_test.go b/internal/persistence/redis/instances/instances_test.go index 1a5abb4..d29b87b 100644 --- a/internal/persistence/redis/instances/instances_test.go +++ b/internal/persistence/redis/instances/instances_test.go @@ -16,14 +16,24 @@ import ( "github.com/sergeii/swat4master/internal/core/entities/filterset" "github.com/sergeii/swat4master/internal/core/entities/instance" "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/persistence/redis/instances" "github.com/sergeii/swat4master/internal/testutils" "github.com/sergeii/swat4master/internal/testutils/factories/instancefactory" "github.com/sergeii/swat4master/internal/testutils/testredis" ) +const ( + DEADBEEF = "\xde\xad\xbe\xef" + FEEDFOOD = "\xfe\xed\xf0\x0d" + CAFEBABE = "\xca\xfe\xba\xbe" + BAADCODE = "\xba\xad\xc0\xde" +) + +type b []byte + type storedItem struct { - ID string `json:"id"` + ID []byte `json:"id"` IP net.IP `json:"ip"` Port int `json:"port"` } @@ -53,9 +63,9 @@ func collectStorageState(ctx context.Context, rdb *redis.Client) storageState { items := make(map[string]instance.Instance) for k, v := range hItems { var item storedItem - id := string(testutils.Must(hex.DecodeString(k))) + id4bytes := testutils.Must(hex.DecodeString(k)) testutils.MustNoError(json.Unmarshal([]byte(v), &item)) - items[id] = instance.MustNew(id, item.IP, item.Port) + items[string(id4bytes)] = instance.MustNew(instance.MustNewID(id4bytes), item.IP, item.Port) } return storageState{ @@ -75,7 +85,10 @@ func TestInstancesRedisRepo_Add_New(t *testing.T) { repo := instances.New(rdb, c) // ...and a new instance to add - ins1 := instancefactory.Build(instancefactory.WithID("foo"), instancefactory.WithRandomServerAddress()) + ins1 := instancefactory.Build( + instancefactory.WithStringID(DEADBEEF), + instancefactory.WithRandomServerAddress(), + ) // When adding the instance to the repository err := repo.Add(ctx, ins1) @@ -84,40 +97,44 @@ func TestInstancesRedisRepo_Add_New(t *testing.T) { // Then the instance is stored in redis state := collectStorageState(ctx, rdb) require.Len(t, state.Items, 1) - require.Equal(t, ins1, state.Items["foo"]) + require.Equal(t, ins1, state.Items[DEADBEEF]) // And the update time is stored in the sorted set require.Len(t, state.Updates, 1) - assert.Equal(t, "foo", state.Updates[0].ID) + assert.Equal(t, DEADBEEF, state.Updates[0].ID) assert.Equal(t, float64(now.UnixNano()), state.Updates[0].Time) // When another instance is added at a later time c.Advance(time.Millisecond * 100) - ins2 := instancefactory.Build(instancefactory.WithID("bar"), instancefactory.WithRandomServerAddress()) + ins2 := instancefactory.Build( + instancefactory.WithStringID(FEEDFOOD), + instancefactory.WithRandomServerAddress(), + ) err = repo.Add(ctx, ins2) require.NoError(t, err) // Then the second instance is also stored in redis state = collectStorageState(ctx, rdb) require.Len(t, state.Items, 2) - require.Equal(t, ins2, state.Items["bar"]) + require.Equal(t, ins2, state.Items[FEEDFOOD]) // And the added instances are sorted by the time of addition require.Len(t, state.Updates, 2) - assert.Equal(t, "bar", state.Updates[1].ID) + assert.Equal(t, FEEDFOOD, state.Updates[1].ID) assert.Equal(t, float64(c.Now().UnixNano()), state.Updates[1].Time) // When another instance is added at the same time as the last one - ins3 := instancefactory.Build(instancefactory.WithID("baz"), instancefactory.WithRandomServerAddress()) + ins3 := instancefactory.Build( + instancefactory.WithStringID("baz3"), + instancefactory.WithRandomServerAddress(), + ) err = repo.Add(ctx, ins3) require.NoError(t, err) // Then the third instance is also stored in redis state = collectStorageState(ctx, rdb) require.Len(t, state.Items, 3) - require.Equal(t, ins3, state.Items["baz"]) + require.Equal(t, ins3, state.Items["baz3"]) // And the added instances are sorted by the time of addition require.Len(t, state.Updates, 3) - assert.Equal(t, "baz", state.Updates[2].ID) - assert.Equal(t, float64(c.Now().UnixNano()), state.Updates[2].Time) } func TestInstancesRedisRepo_Add_Existing(t *testing.T) { @@ -129,26 +146,35 @@ func TestInstancesRedisRepo_Add_Existing(t *testing.T) { // Given a repository... repo := instances.New(rdb, c) // ...with an instance previously added - ins := instancefactory.Build(instancefactory.WithID("foo"), instancefactory.WithServerAddress("1.1.1.1", 10480)) + ins := instancefactory.Build( + instancefactory.WithStringID(DEADBEEF), + instancefactory.WithServerAddress("1.1.1.1", 10480), + ) err := repo.Add(ctx, ins) require.NoError(t, err) // And the instance is stored in the storage state := collectStorageState(ctx, rdb) require.Len(t, state.Items, 1) - require.Equal(t, ins, state.Items["foo"]) + require.Equal(t, ins, state.Items[DEADBEEF]) require.Len(t, state.Updates, 1) assert.Equal(t, float64(then.UnixNano()), state.Updates[0].Time) // When adding another instance with the same ID at a later time c.Advance(time.Millisecond * 100) - other := instancefactory.Build(instancefactory.WithID("foo"), instancefactory.WithServerAddress("2.2.2.2", 10580)) + other := instancefactory.Build( + instancefactory.WithStringID(DEADBEEF), + instancefactory.WithServerAddress( + "2.2.2.2", + 10580, + ), + ) err = repo.Add(ctx, other) require.NoError(t, err) // Then the instance is replaced in the storage state = collectStorageState(ctx, rdb) require.Len(t, state.Items, 1) - assert.Equal(t, other, state.Items["foo"]) + assert.Equal(t, other, state.Items[DEADBEEF]) assert.Len(t, state.Updates, 1) assert.Equal(t, float64(then.Add(time.Millisecond*100).UnixNano()), state.Updates[0].Time) } @@ -161,10 +187,16 @@ func TestInstancesRedisRepo_Get_OK(t *testing.T) { // Given a repository with 2 instances added at the same time... repo := instances.New(rdb, c) - ins1 := instancefactory.Build(instancefactory.WithID("foo"), instancefactory.WithRandomServerAddress()) - ins2 := instancefactory.Build(instancefactory.WithID("bar"), instancefactory.WithRandomServerAddress()) + ins1 := instancefactory.Build( + instancefactory.WithStringID(DEADBEEF), + instancefactory.WithRandomServerAddress(), + ) + ins2 := instancefactory.Build( + instancefactory.WithStringID(FEEDFOOD), + instancefactory.WithRandomServerAddress(), + ) ins3 := instancefactory.Build( - instancefactory.WithID(string([]byte{0xfe, 0xed, 0xf0, 0x0d})), + instancefactory.WithStringID(CAFEBABE), instancefactory.WithRandomServerAddress(), ) @@ -177,21 +209,21 @@ func TestInstancesRedisRepo_Get_OK(t *testing.T) { // When retrieving an instance by ID for _, pair := range []struct { - id string + id []byte want instance.Instance }{ - {"foo", ins1}, - {"bar", ins2}, - {string([]byte{0xfe, 0xed, 0xf0, 0x0d}), ins3}, + {b(DEADBEEF), ins1}, + {b(FEEDFOOD), ins2}, + {b(CAFEBABE), ins3}, } { - got, err := repo.Get(ctx, pair.id) + got, err := repo.Get(ctx, instance.MustNewID(pair.id)) // Then the instance should be retrieved successfully require.NoError(t, err) assert.Equal(t, pair.want, got) } // When retrieving a non-existent instance - _, err := repo.Get(ctx, "qux") + _, err := repo.Get(ctx, instance.MustNewID(b(BAADCODE))) // Then the operation should fail with an error assert.ErrorIs(t, err, repositories.ErrInstanceNotFound) } @@ -199,23 +231,23 @@ func TestInstancesRedisRepo_Get_OK(t *testing.T) { func TestInstancesRedisRepo_Remove_OK(t *testing.T) { tests := []struct { name string - id string + id []byte wantRemainingIDs []string }{ { - name: "remove 'foo'", - id: "foo", - wantRemainingIDs: []string{"bar", "baz"}, + name: "remove DEADBEEF", + id: b(DEADBEEF), + wantRemainingIDs: []string{FEEDFOOD, CAFEBABE}, }, { - name: "remove 'bar'", - id: "bar", - wantRemainingIDs: []string{"foo", "baz"}, + name: "remove FEEDFOOD", + id: b(FEEDFOOD), + wantRemainingIDs: []string{DEADBEEF, CAFEBABE}, }, { name: "remove non-existent", - id: "qux", - wantRemainingIDs: []string{"foo", "bar", "baz"}, + id: b(BAADCODE), + wantRemainingIDs: []string{DEADBEEF, FEEDFOOD, CAFEBABE}, }, } @@ -229,20 +261,29 @@ func TestInstancesRedisRepo_Remove_OK(t *testing.T) { repo := instances.New(rdb, c) for _, ins := range []instance.Instance{ - instancefactory.Build(instancefactory.WithID("foo"), instancefactory.WithRandomServerAddress()), - instancefactory.Build(instancefactory.WithID("bar"), instancefactory.WithRandomServerAddress()), + instancefactory.Build( + instancefactory.WithStringID(DEADBEEF), + instancefactory.WithRandomServerAddress(), + ), + instancefactory.Build( + instancefactory.WithStringID(FEEDFOOD), + instancefactory.WithRandomServerAddress(), + ), } { c.Advance(time.Millisecond * 100) err := repo.Add(ctx, ins) require.NoError(t, err) } // And another instance added at the same time as the last one - ins := instancefactory.Build(instancefactory.WithID("baz"), instancefactory.WithRandomServerAddress()) + ins := instancefactory.Build( + instancefactory.WithStringID(CAFEBABE), + instancefactory.WithRandomServerAddress(), + ) err := repo.Add(ctx, ins) require.NoError(t, err) // When removing an instance by ID - err = repo.Remove(ctx, tt.id) + err = repo.Remove(ctx, instance.MustNewID(tt.id)) require.NoError(t, err) // Then the instance should be removed from the storage @@ -279,7 +320,7 @@ func TestInstancesRedisRepo_Clear_OK(t *testing.T) { return fs.UpdatedBefore(now) }, wantAffected: 0, - wantRemainingIDs: []string{"foo", "bar", "baz", "qux"}, + wantRemainingIDs: []string{DEADBEEF, FEEDFOOD, CAFEBABE, BAADCODE}, }, { name: "filter by time in the future", @@ -290,28 +331,28 @@ func TestInstancesRedisRepo_Clear_OK(t *testing.T) { wantRemainingIDs: []string{}, }, { - name: "filter by time before 'baz' and 'qux'", + name: "filter by time before CAFEBABE and BAADCODE", factory: func(fs filterset.InstanceFilterSet, now time.Time) filterset.InstanceFilterSet { return fs.UpdatedBefore(now.Add(time.Millisecond * 200)) }, wantAffected: 2, - wantRemainingIDs: []string{"baz", "qux"}, + wantRemainingIDs: []string{CAFEBABE, BAADCODE}, }, { - name: "filter by time before 'bar'", + name: "filter by time before FEEDFOOD", factory: func(fs filterset.InstanceFilterSet, now time.Time) filterset.InstanceFilterSet { return fs.UpdatedBefore(now.Add(time.Millisecond * 100)) }, wantAffected: 1, - wantRemainingIDs: []string{"bar", "baz", "qux"}, + wantRemainingIDs: []string{FEEDFOOD, CAFEBABE, BAADCODE}, }, { - name: "filter by time before 'foo'", + name: "filter by time before DEADBEEF", factory: func(fs filterset.InstanceFilterSet, now time.Time) filterset.InstanceFilterSet { return fs.UpdatedBefore(now.Add(time.Millisecond * 99)) }, wantAffected: 0, - wantRemainingIDs: []string{"foo", "bar", "baz", "qux"}, + wantRemainingIDs: []string{DEADBEEF, FEEDFOOD, CAFEBABE, BAADCODE}, }, } for _, tt := range tests { @@ -325,16 +366,28 @@ func TestInstancesRedisRepo_Clear_OK(t *testing.T) { before := c.Now() for _, ins := range []instance.Instance{ - instancefactory.Build(instancefactory.WithID("foo"), instancefactory.WithRandomServerAddress()), - instancefactory.Build(instancefactory.WithID("bar"), instancefactory.WithRandomServerAddress()), - instancefactory.Build(instancefactory.WithID("baz"), instancefactory.WithRandomServerAddress()), + instancefactory.Build( + instancefactory.WithStringID(DEADBEEF), + instancefactory.WithRandomServerAddress(), + ), + instancefactory.Build( + instancefactory.WithStringID(FEEDFOOD), + instancefactory.WithRandomServerAddress(), + ), + instancefactory.Build( + instancefactory.WithStringID(CAFEBABE), + instancefactory.WithRandomServerAddress(), + ), } { c.Advance(time.Millisecond * 100) err := repo.Add(ctx, ins) require.NoError(t, err) } // and another instance added at the same time as the last one - ins := instancefactory.Build(instancefactory.WithID("qux"), instancefactory.WithRandomServerAddress()) + ins := instancefactory.Build( + instancefactory.WithStringID(BAADCODE), + instancefactory.WithRandomServerAddress(), + ) err := repo.Add(ctx, ins) require.NoError(t, err) @@ -404,9 +457,18 @@ func TestInstancesRedisRepo_Count_OK(t *testing.T) { repo := instances.New(rdb, c) // Given a repository with 3 instances - ins1 := instancefactory.Build(instancefactory.WithID("foo"), instancefactory.WithRandomServerAddress()) - ins2 := instancefactory.Build(instancefactory.WithID("bar"), instancefactory.WithRandomServerAddress()) - ins3 := instancefactory.Build(instancefactory.WithID("baz"), instancefactory.WithRandomServerAddress()) + ins1 := instancefactory.Build( + instancefactory.WithStringID(DEADBEEF), + instancefactory.WithRandomServerAddress(), + ) + ins2 := instancefactory.Build( + instancefactory.WithStringID(FEEDFOOD), + instancefactory.WithRandomServerAddress(), + ) + ins3 := instancefactory.Build( + instancefactory.WithStringID(CAFEBABE), + instancefactory.WithRandomServerAddress(), + ) for _, ins := range []instance.Instance{ins1, ins2, ins3} { err := repo.Add(ctx, ins) diff --git a/internal/persistence/redis/instances/item.go b/internal/persistence/redis/instances/item.go deleted file mode 100644 index a614f0a..0000000 --- a/internal/persistence/redis/instances/item.go +++ /dev/null @@ -1,30 +0,0 @@ -package instances - -import ( - "net" - - "github.com/sergeii/swat4master/internal/core/entities/addr" - "github.com/sergeii/swat4master/internal/core/entities/instance" -) - -type storedItem struct { - ID string `json:"id"` - IP net.IP `json:"ip"` - Port int `json:"port"` -} - -func newStoredItem(id string, addr addr.Addr) storedItem { - return storedItem{ - ID: encodeID(id), - IP: addr.GetIP(), - Port: addr.Port, - } -} - -func (i storedItem) convert() (instance.Instance, error) { - id, err := decodeID(i.ID) - if err != nil { - return instance.Blank, err - } - return instance.New(id, i.IP, i.Port) -} diff --git a/internal/reporter/dispatcher.go b/internal/reporter/dispatcher.go index d71fa6e..c30c60a 100644 --- a/internal/reporter/dispatcher.go +++ b/internal/reporter/dispatcher.go @@ -111,9 +111,11 @@ func (d *Dispatcher) selectHandler(msgType master.Msg) (Handler, error) { return nil, fmt.Errorf("no associated handler for message type '%s'", msgType) } -func ParseInstanceID(payload []byte) (string, []byte, error) { +func ParseInstanceID(payload []byte) ([]byte, []byte, error) { if len(payload) < 5 { - return "", nil, fmt.Errorf("invalid payload length %d", len(payload)) + return nil, nil, fmt.Errorf("invalid payload length %d", len(payload)) } - return string(payload[1:5]), payload[5:], nil + id := make([]byte, 4) + copy(id, payload[1:5]) + return id, payload[5:], nil } diff --git a/internal/reporter/handlers/heartbeat/heartbeat.go b/internal/reporter/handlers/heartbeat/heartbeat.go index bd774dd..4040051 100644 --- a/internal/reporter/handlers/heartbeat/heartbeat.go +++ b/internal/reporter/handlers/heartbeat/heartbeat.go @@ -77,7 +77,7 @@ func (h Handler) reportServer( connAddr *net.UDPAddr, svrAddr addr.Addr, queryPort int, - instanceID string, + instanceID []byte, fields map[string]string, ) ([]byte, error) { req := reportserver.NewRequest(svrAddr, queryPort, instanceID, fields) @@ -104,7 +104,7 @@ func (h Handler) reportServer( func (h Handler) removeServer( ctx context.Context, svrAddr addr.Addr, - instanceID string, + instanceID []byte, ) ([]byte, error) { req := removeserver.NewRequest(instanceID, svrAddr) if err := h.removeServerUC.Execute(ctx, req); err != nil { diff --git a/internal/testutils/factories/instancefactory/instancefactory.go b/internal/testutils/factories/instancefactory/instancefactory.go index 75deb4e..3639225 100644 --- a/internal/testutils/factories/instancefactory/instancefactory.go +++ b/internal/testutils/factories/instancefactory/instancefactory.go @@ -11,22 +11,34 @@ import ( ) type BuildParams struct { - ID string + ID instance.Identifier IP string Port int } type BuildOption func(*BuildParams) -func WithID(id string) BuildOption { +func WithID(id instance.Identifier) BuildOption { return func(p *BuildParams) { p.ID = id } } +func WithStringID(id string) BuildOption { + return func(p *BuildParams) { + p.ID = instance.MustNewID([]byte(id)) + } +} + +func WithBytesID(id []byte) BuildOption { + return func(p *BuildParams) { + p.ID = instance.MustNewID(id) + } +} + func WithRandomID() BuildOption { return func(p *BuildParams) { - p.ID = string(random.RandBytes(4)) + p.ID = instance.MustNewID(random.RandBytes(4)) } } @@ -46,7 +58,7 @@ func WithRandomServerAddress() BuildOption { func Build(opts ...BuildOption) instance.Instance { params := BuildParams{ - ID: "foo", + ID: instance.MustNewID([]byte("test")), IP: "1.1.1.1", Port: 10480, } @@ -68,3 +80,12 @@ func Save( } return ins } + +func Create( + ctx context.Context, + repo repositories.InstanceRepository, + opts ...BuildOption, +) instance.Instance { + svr := Build(opts...) + return Save(ctx, repo, svr) +} diff --git a/tests/modules/cleaner_test.go b/tests/modules/cleaner_test.go index 1095d5b..54f712e 100644 --- a/tests/modules/cleaner_test.go +++ b/tests/modules/cleaner_test.go @@ -2,12 +2,12 @@ package modules_test import ( "context" - "net" "testing" "time" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/fx" "github.com/sergeii/swat4master/cmd/swat4master/application" @@ -21,6 +21,15 @@ import ( "github.com/sergeii/swat4master/tests/testapp" ) +const ( + DEADBEEF = "\xde\xad\xbe\xef" + FEEDFOOD = "\xfe\xed\xf0\x0d" + CAFEBABE = "\xca\xfe\xba\xbe" + BAADCODE = "\xba\xad\xc0\xde" +) + +type b []byte + func makeAppWithCleaner(extra ...fx.Option) (*fx.App, func()) { fxopts := []fx.Option{ fx.Provide(testapp.ProvidePersistence), @@ -54,13 +63,24 @@ func TestCleaner_OK(t *testing.T) { defer cancel() app.Start(ctx) // nolint: errcheck - ins1 := instance.MustNew("foo", net.ParseIP("1.1.1.1"), 10480) - ins2 := instance.MustNew("bar", net.ParseIP("3.3.3.3"), 10480) - ins4 := instance.MustNew("baz", net.ParseIP("4.4.4.4"), 10480) - - instancefactory.Save(ctx, instanceRepo, ins1) - instancefactory.Save(ctx, instanceRepo, ins2) - instancefactory.Save(ctx, instanceRepo, ins4) + ins1 := instancefactory.Create( + ctx, + instanceRepo, + instancefactory.WithStringID(DEADBEEF), + instancefactory.WithServerAddress("1.1.1.1", 10480), + ) + instancefactory.Create( + ctx, + instanceRepo, + instancefactory.WithStringID(FEEDFOOD), + instancefactory.WithServerAddress("3.3.3.3", 10480), + ) + instancefactory.Create( + ctx, + instanceRepo, + instancefactory.WithStringID(CAFEBABE), + instancefactory.WithServerAddress("4.4.4.4", 10480), + ) gs1 := serverfactory.Create( ctx, @@ -95,39 +115,48 @@ func TestCleaner_OK(t *testing.T) { serverfactory.WithAddress("5.5.5.5", 10480), serverfactory.WithQueryPort(10481), ) + instancefactory.Create( + ctx, + instanceRepo, + instancefactory.WithStringID(BAADCODE), + instancefactory.WithServerAddress("5.5.5.5", 10480), + ) - ins5 := instance.MustNew("qux", net.ParseIP("5.5.5.5"), 10480) - instancefactory.Save(ctx, instanceRepo, ins5) + // refresh the first instance to prevent it from being cleaned + instanceRepo.Add(ctx, ins1) // nolint: errcheck // wait for cleaner to clean servers 2 and 3 <-time.After(time.Millisecond * 150) // check that the refreshed server and the new one are still there svrCount, err := serverRepo.Count(ctx) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 2, svrCount) _, err = serverRepo.Get(ctx, gs1.Addr) - assert.NoError(t, err) + require.NoError(t, err) _, err = serverRepo.Get(ctx, gs5.Addr) - assert.NoError(t, err) + require.NoError(t, err) - // no instance was removed + // 2 instances should be cleaned up insCount, err := instanceRepo.Count(ctx) - assert.NoError(t, err) - assert.Equal(t, 4, insCount) - - _, err = instanceRepo.Get(ctx, "foo") - assert.NoError(t, err) - _, err = instanceRepo.Get(ctx, "baz") - assert.NoError(t, err) - _, err = instanceRepo.Get(ctx, "qux") - assert.NoError(t, err) - - removalValue := testutil.ToFloat64(collector.CleanerRemovals) - assert.Equal(t, 2.0, removalValue) - errorValue := testutil.ToFloat64(collector.CleanerErrors) - assert.Equal(t, 0.0, errorValue) + require.NoError(t, err) + assert.Equal(t, 2, insCount) + + _, err = instanceRepo.Get(ctx, instance.MustNewID(b(DEADBEEF))) + require.NoError(t, err) + _, err = instanceRepo.Get(ctx, instance.MustNewID(b(BAADCODE))) + require.NoError(t, err) + + serverRemovalsValue := testutil.ToFloat64(collector.CleanerRemovals.WithLabelValues("servers")) + assert.Equal(t, 2.0, serverRemovalsValue) + serverErrorsValue := testutil.ToFloat64(collector.CleanerErrors.WithLabelValues("servers")) + assert.Equal(t, 0.0, serverErrorsValue) + + instanceRemovalsValue := testutil.ToFloat64(collector.CleanerRemovals.WithLabelValues("instances")) + assert.Equal(t, 2.0, instanceRemovalsValue) + instanceErrorsValue := testutil.ToFloat64(collector.CleanerErrors.WithLabelValues("instances")) + assert.Equal(t, 0.0, instanceErrorsValue) } func TestCleaner_NoErrorWhenNothingToClean(t *testing.T) { @@ -143,8 +172,10 @@ func TestCleaner_NoErrorWhenNothingToClean(t *testing.T) { // wait for cleaner to run some cycles <-time.After(time.Millisecond * 100) - removalValue := testutil.ToFloat64(collector.CleanerRemovals) - errorValue := testutil.ToFloat64(collector.CleanerErrors) - assert.Equal(t, 0.0, removalValue) - assert.Equal(t, 0.0, errorValue) + for _, kind := range []string{"servers", "instances"} { + removalValue := testutil.ToFloat64(collector.CleanerRemovals.WithLabelValues(kind)) + errorValue := testutil.ToFloat64(collector.CleanerErrors.WithLabelValues(kind)) + assert.Equal(t, 0.0, removalValue) + assert.Equal(t, 0.0, errorValue) + } } diff --git a/tests/modules/exporter_test.go b/tests/modules/exporter_test.go index 41f0a5a..21a8c5b 100644 --- a/tests/modules/exporter_test.go +++ b/tests/modules/exporter_test.go @@ -25,12 +25,12 @@ import ( "github.com/sergeii/swat4master/cmd/swat4master/modules/reporter" "github.com/sergeii/swat4master/internal/core/entities/addr" ds "github.com/sergeii/swat4master/internal/core/entities/discovery/status" - "github.com/sergeii/swat4master/internal/core/entities/instance" "github.com/sergeii/swat4master/internal/core/entities/probe" "github.com/sergeii/swat4master/internal/core/entities/server" "github.com/sergeii/swat4master/internal/core/repositories" "github.com/sergeii/swat4master/internal/testutils" "github.com/sergeii/swat4master/internal/testutils/factories/infofactory" + "github.com/sergeii/swat4master/internal/testutils/factories/instancefactory" "github.com/sergeii/swat4master/pkg/gamespy/serverquery/gs1" "github.com/sergeii/swat4master/tests/testapp" ) @@ -288,15 +288,15 @@ func TestExporter_ReposMetrics(t *testing.T) { serversRepo.Add(ctx, svr3, repositories.ServerOnConflictIgnore) // nolint: errcheck // instances - ins1 := instance.MustNew("foo", net.ParseIP("1.1.1.1"), 10480) - ins2 := instance.MustNew("bar", net.ParseIP("2.2.2.2"), 10480) - instancesRepo.Add(ctx, ins1) // nolint: errcheck - instancesRepo.Add(ctx, ins2) // nolint: errcheck + for range 2 { + inst := instancefactory.Build(instancefactory.WithRandomID(), instancefactory.WithRandomServerAddress()) + testutils.MustNoError(instancesRepo.Add(ctx, inst)) + } probe1 := probe.New(svr1.Addr, svr1.QueryPort, probe.GoalDetails, 0) probe2 := probe.New(svr2.Addr, svr2.QueryPort, probe.GoalDetails, 0) - probesRepo.AddBetween(ctx, probe1, time.Now().Add(time.Hour), repositories.NC) // nolint: errcheck - probesRepo.Add(ctx, probe2) // nolint: errcheck + testutils.MustNoError(probesRepo.AddBetween(ctx, probe1, time.Now().Add(time.Hour), repositories.NC)) + testutils.MustNoError(probesRepo.Add(ctx, probe2)) app.Start(context.TODO()) // nolint: errcheck defer func() { @@ -358,8 +358,13 @@ func TestExporter_CleanerMetrics(t *testing.T) { parser := expfmt.TextParser{} mf, _ := parser.TextToMetricFamilies(resp.Body) - assert.Equal(t, 2, int(mf["cleaner_removals_total"].Metric[0].Counter.GetValue())) + assert.Equal(t, 0, int(mf["cleaner_removals_total"].Metric[0].Counter.GetValue())) + assert.Equal(t, "instances", *mf["cleaner_removals_total"].Metric[0].Label[0].Value) + + assert.Equal(t, 2, int(mf["cleaner_removals_total"].Metric[1].Counter.GetValue())) + assert.Equal(t, "servers", *mf["cleaner_removals_total"].Metric[1].Label[0].Value) assert.Equal(t, 0, int(mf["cleaner_errors_total"].Metric[0].Counter.GetValue())) + assert.Equal(t, "servers", *mf["cleaner_errors_total"].Metric[0].Label[0].Value) } func TestExporter_ProberMetrics(t *testing.T) { diff --git a/tests/modules/reporter_test.go b/tests/modules/reporter_test.go index eb1ee8c..8d17e05 100644 --- a/tests/modules/reporter_test.go +++ b/tests/modules/reporter_test.go @@ -203,7 +203,7 @@ func TestReporter_Heartbeat_ServerIsAddedAndThenUpdated(t *testing.T) { assert.Equal(t, "A-Bomb Nightclub", svr.Info.MapName) // instance is stored with the server's address - inst, err := instanceRepo.Get(ctx, string([]byte{0xfe, 0xed, 0xf0, 0x0d})) + inst, err := instanceRepo.Get(ctx, instance.MustNewID([]byte{0xfe, 0xed, 0xf0, 0x0d})) assert.NoError(t, err) assert.Equal(t, "127.0.0.1:10480", inst.Addr.String()) @@ -254,7 +254,7 @@ func TestReporter_Heartbeat_ServerIsUpdatedWithNewInstanceID(t *testing.T) { req := testutils.PackHeartbeatRequest(oldInstanceID, paramsBefore) testutils.SendUDP("127.0.0.1:33811", req) - ins, err := instanceRepo.Get(ctx, string(oldInstanceID)) + ins, err := instanceRepo.Get(ctx, instance.MustNewID(oldInstanceID)) require.NoError(t, err) svr, err := serverRepo.Get(ctx, ins.Addr) require.NoError(t, err) @@ -282,10 +282,10 @@ func TestReporter_Heartbeat_ServerIsUpdatedWithNewInstanceID(t *testing.T) { assert.Equal(t, "127.0.0.1", svr.Addr.GetDottedIP()) // the server is still accessible by the former instance key until the instance is recycled - ins, err = instanceRepo.Get(ctx, string(oldInstanceID)) + ins, err = instanceRepo.Get(ctx, instance.MustNewID(oldInstanceID)) require.NoError(t, err) assert.Equal(t, "127.0.0.1:10480", ins.Addr.String()) - assert.Equal(t, string(oldInstanceID), ins.ID) + assert.Equal(t, oldInstanceID, ins.ID[:]) } func TestReporter_Heartbeat_ServerPortIsDiscovered(t *testing.T) { @@ -528,7 +528,11 @@ func TestReporter_Heartbeat_ServerRemovalIsValidated(t *testing.T) { client := testutils.NewUDPClient("127.0.0.1:33811", 1024, time.Millisecond*10) svr := serverfactory.Build(serverfactory.WithAddress(tt.ipaddr, 10480), serverfactory.WithQueryPort(10484)) - inst := instance.MustNew(string([]byte{0xfe, 0xed, 0xf0, 0x0d}), svr.Addr.GetIP(), svr.Addr.Port) + inst := instance.MustNew( + instance.MustNewID([]byte{0xfe, 0xed, 0xf0, 0x0d}), + svr.Addr.GetIP(), + svr.Addr.Port, + ) serverRepo.Add(ctx, svr, repositories.ServerOnConflictIgnore) // nolint: errcheck instanceRepo.Add(ctx, inst) // nolint: errcheck @@ -739,7 +743,7 @@ func TestReporter_Keepalive_Errors(t *testing.T) { client := testutils.NewUDPClient("127.0.0.1:33811", 1024, time.Millisecond*10) svr := server.MustNew(net.ParseIP(tt.svrAddr), 10480, 10484) - inst := instance.MustNew(string(tt.instanceID), svr.Addr.GetIP(), svr.Addr.Port) + inst := instance.MustNew(instance.MustNewID(tt.instanceID), svr.Addr.GetIP(), svr.Addr.Port) serverRepo.Add(ctx, svr, repositories.ServerOnConflictIgnore) // nolint: errcheck instanceRepo.Add(ctx, inst) // nolint: errcheck