diff --git a/pkg/connector/service.go b/pkg/connector/service.go index 25e3f824e..71dd9525d 100644 --- a/pkg/connector/service.go +++ b/pkg/connector/service.go @@ -233,11 +233,11 @@ func (s *Service) SetState(ctx context.Context, id string, state any) (*Instance if state != nil { switch conn.Type { case TypeSource: - if _, ok := state.(SourceState); ok { + if _, ok := state.(SourceState); !ok { return nil, cerrors.Errorf("expected source state (ID: %s): %w", id, ErrInvalidConnectorStateType) } case TypeDestination: - if _, ok := state.(DestinationState); ok { + if _, ok := state.(DestinationState); !ok { return nil, cerrors.Errorf("expected destination state (ID: %s): %w", id, ErrInvalidConnectorStateType) } default: diff --git a/pkg/connector/service_test.go b/pkg/connector/service_test.go index b7dd1a35f..650c16d10 100644 --- a/pkg/connector/service_test.go +++ b/pkg/connector/service_test.go @@ -24,6 +24,7 @@ import ( "github.com/conduitio/conduit/pkg/foundation/database/inmemory" "github.com/conduitio/conduit/pkg/foundation/database/mock" "github.com/conduitio/conduit/pkg/foundation/log" + "github.com/conduitio/conduit/pkg/record" "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/matryer/is" @@ -388,3 +389,89 @@ func TestService_UpdateInstanceNotFound(t *testing.T) { is.True(cerrors.Is(err, ErrInstanceNotFound)) is.Equal(got, nil) } + +func TestService_SetState(t *testing.T) { + type testCase struct { + name string + connType Type + state any + wantErr error + } + testCases := []testCase{ + { + name: "nil state", + connType: TypeSource, + state: nil, + wantErr: nil, + }, + { + name: "correct state (source)", + connType: TypeSource, + state: SourceState{Position: record.Position("test position")}, + wantErr: nil, + }, + { + name: "correct state (destination)", + connType: TypeDestination, + state: DestinationState{ + Positions: map[string]record.Position{ + "test-connector": record.Position("test-position"), + }, + }, + wantErr: nil, + }, + { + name: "wrong state", + connType: TypeSource, + state: DestinationState{ + Positions: map[string]record.Position{ + "test-connector": record.Position("test-position"), + }, + }, + wantErr: ErrInvalidConnectorStateType, + }, + { + name: "completely wrong state", + connType: TypeSource, + state: testCase{name: "completely wrong state"}, + wantErr: ErrInvalidConnectorStateType, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + is := is.New(t) + ctx := context.Background() + logger := log.Nop() + db := &inmemory.DB{} + + service := NewService(logger, db, nil) + conn, err := service.Create( + ctx, + uuid.NewString(), + tc.connType, + "test-plugin", + uuid.NewString(), + Config{ + Name: "test-connector", + Settings: map[string]string{"foo": "bar"}, + }, + ProvisionTypeAPI, + ) + is.NoErr(err) + + gotConn, err := service.SetState( + ctx, + conn.ID, + tc.state, + ) + if tc.wantErr != nil { + is.True(cerrors.Is(err, tc.wantErr)) + is.True(gotConn == nil) + } else { + is.NoErr(err) + is.Equal(conn, gotConn) + } + }) + } +}