Skip to content

Commit

Permalink
fix race condition in connector persister (#857)
Browse files Browse the repository at this point in the history
Co-authored-by: Haris Osmanagić <[email protected]>
  • Loading branch information
lovromazgon and hariso authored Feb 17, 2023
1 parent ab0ee32 commit c8eabf7
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 49 deletions.
10 changes: 8 additions & 2 deletions pkg/connector/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ func (d *Destination) initPlugin(ctx context.Context) (plugin.DestinationPlugin,
d.Instance.logger.Debug(ctx).Msg("configuring destination connector plugin")
err = dest.Configure(ctx, d.Instance.Config.Settings)
if err != nil {
_ = dest.Teardown(ctx)
tdErr := dest.Teardown(ctx)
err = cerrors.LogOrReplace(err, tdErr, func() {
d.Instance.logger.Err(ctx, tdErr).Msg("could not tear down destination connector plugin")
})
return nil, err
}

Expand All @@ -88,7 +91,10 @@ func (d *Destination) Open(ctx context.Context) error {
err = dest.Start(streamCtx)
if err != nil {
cancelStreamCtx()
_ = dest.Teardown(ctx)
tdErr := dest.Teardown(ctx)
err = cerrors.LogOrReplace(err, tdErr, func() {
d.Instance.logger.Err(ctx, tdErr).Msg("could not tear down destination connector plugin")
})
return err
}

Expand Down
55 changes: 32 additions & 23 deletions pkg/connector/persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"sync"
"time"

"github.com/conduitio/conduit/pkg/foundation/cerrors"
"github.com/conduitio/conduit/pkg/foundation/database"
"github.com/conduitio/conduit/pkg/foundation/log"
)
Expand All @@ -43,14 +44,19 @@ type Persister struct {
// m guards all private variables below it.
m sync.Mutex
bundleCount int
callbacks map[*Instance]PersistCallback
batch map[string]persistData
flushTimer *time.Timer
flushWg sync.WaitGroup
}

// PersistCallback is a function that's called when a connector is persisted.
type PersistCallback func(error)

type persistData struct {
callback PersistCallback
storeFunc func(context.Context) error
}

// NewPersister creates a new persister that stores data into the supplied
// database when the thresholds are met.
func NewPersister(
Expand Down Expand Up @@ -92,30 +98,39 @@ func (p *Persister) ConnectorStopped() {
// connectors until either the number of detected changes reaches the configured
// threshold or the configured delay is reached (whichever comes first), then
// the connectors are flushed and a new batch starts to be collected.
func (p *Persister) Persist(ctx context.Context, conn *Instance, callback PersistCallback) {
func (p *Persister) Persist(ctx context.Context, conn *Instance, callback PersistCallback) error {
p.m.Lock()
defer p.m.Unlock()

p.logger.Trace(ctx).
Str(log.ConnectorIDField, conn.ID).
Msg("adding connector to next persist batch")
if p.callbacks == nil {
p.callbacks = make(map[*Instance]PersistCallback)
if p.batch == nil {
p.batch = make(map[string]persistData)
}

storeFunc, err := p.store.PrepareSet(conn.ID, conn)
if err != nil {
return cerrors.Errorf("failed to prepare connector for persistance: %w", err)
}
p.batch[conn.ID] = persistData{
callback: callback,
storeFunc: storeFunc,
}
p.callbacks[conn] = callback
p.bundleCount++

if p.bundleCount == p.bundleCountThreshold {
p.logger.Trace(ctx).Msg("reached bundle count threshold")
p.triggerFlush(context.Background()) // use a new context because action happens in background
return
return nil
}

if p.flushTimer == nil {
p.flushTimer = time.AfterFunc(p.delayThreshold, func() {
p.Flush(context.Background()) // use a new context because action happens in background
})
}
return nil
}

// Wait waits for all connectors to stop running and for the last flush to be executed.
Expand All @@ -139,24 +154,24 @@ func (p *Persister) triggerFlush(ctx context.Context) {
p.flushTimer.Stop()
p.flushTimer = nil
}
if p.callbacks == nil {
if p.batch == nil {
return
}

// wait for any running flusher to finish
p.flushWg.Wait()

// reset callbacks and bundle count
callbacks := p.callbacks
p.callbacks = nil
batch := p.batch
p.batch = nil
p.bundleCount = 0

p.flushWg.Add(1)
go p.flushNow(ctx, callbacks)
go p.flushNow(ctx, batch)
}

// flushNow will flush the state to the store.
func (p *Persister) flushNow(ctx context.Context, callbacks map[*Instance]PersistCallback) {
func (p *Persister) flushNow(ctx context.Context, batch map[string]persistData) {
defer p.flushWg.Done()
start := time.Now()

Expand All @@ -168,31 +183,25 @@ func (p *Persister) flushNow(ctx context.Context, callbacks map[*Instance]Persis
}

defer tx.Discard()
for conn := range callbacks {
err := p.flushSingle(ctx, conn)
for id, data := range batch {
err := data.storeFunc(ctx)
if err != nil {
p.logger.Err(ctx, err).
Str(log.ConnectorIDField, conn.ID).
Str(log.ConnectorIDField, id).
Msg("error while saving connector")
}
}
if err == nil {
err = tx.Commit()
}
for _, c := range callbacks {
for _, data := range batch {
// execute callbacks in go routines to make sure they can't block this function
go c(err)
go data.callback(err)
}

p.logger.Debug(ctx).
Err(err).
Int("count", len(callbacks)).
Int("count", len(batch)).
Dur(log.DurationField, time.Since(start)).
Msg("persisted connectors")
}

func (p *Persister) flushSingle(ctx context.Context, conn *Instance) error {
conn.Lock()
defer conn.Unlock()
return p.store.Set(ctx, conn.ID, conn)
}
16 changes: 11 additions & 5 deletions pkg/connector/persister_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ func TestPersister_PersistFlushesAfterDelayThreshold(t *testing.T) {
conn := &Instance{ID: uuid.NewString(), Type: TypeDestination}
callbackCalled := make(chan struct{})
persistAt := time.Now()
persister.Persist(ctx, conn, func(err error) {
err := persister.Persist(ctx, conn, func(err error) {
if err != nil {
t.Fatalf("expected nil error, got: %v", err)
}
close(callbackCalled)
})
is.NoErr(err)

// we are testing a delay which is not exact, this is the acceptable margin
maxDelay := delayThreshold + time.Millisecond*10
Expand Down Expand Up @@ -90,16 +91,18 @@ func TestPersister_PersistFlushesAfterBundleCountThreshold(t *testing.T) {

for i := 0; i < bundleCountThreshold/2; i++ {
conn := &Instance{ID: uuid.NewString(), Type: TypeDestination}
persister.Persist(ctx, conn, func(err error) {
err := persister.Persist(ctx, conn, func(err error) {
t.Fatal("expected callback to be overwritten!")
})
is.NoErr(err)
// second persist will overwrite first callback
persister.Persist(ctx, conn, func(err error) {
err = persister.Persist(ctx, conn, func(err error) {
if err != nil {
t.Fatalf("expected nil error, got: %v", err)
}
wgCallbacks.Done()
})
is.NoErr(err)
}
lastPersistAt := time.Now()

Expand Down Expand Up @@ -127,12 +130,13 @@ func TestPersister_FlushStoresRightAway(t *testing.T) {
conn := &Instance{ID: uuid.NewString(), Type: TypeDestination}
callbackCalled := make(chan struct{})
timeAtPersist := time.Now()
persister.Persist(ctx, conn, func(err error) {
err := persister.Persist(ctx, conn, func(err error) {
if err != nil {
t.Fatalf("expected nil error, got: %v", err)
}
close(callbackCalled)
})
is.NoErr(err)

// flush right away
persister.Flush(ctx)
Expand All @@ -155,6 +159,7 @@ func TestPersister_FlushStoresRightAway(t *testing.T) {
}

func TestPersister_WaitsForOpenConnectorsAndFlush(t *testing.T) {
is := is.New(t)
ctx := context.Background()
persister, _ := initPersisterTest(time.Millisecond*100, 2)

Expand All @@ -171,7 +176,8 @@ func TestPersister_WaitsForOpenConnectorsAndFlush(t *testing.T) {
persister.ConnectorStopped()
// before last stop we persist another change which should be flushed
// automatically when the connector is stopped
persister.Persist(ctx, conn, func(err error) {})
err := persister.Persist(ctx, conn, func(err error) {})
is.NoErr(err)
persister.ConnectorStopped()
}()

Expand Down
22 changes: 16 additions & 6 deletions pkg/connector/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ func (s *Source) initPlugin(ctx context.Context) (plugin.SourcePlugin, error) {
s.Instance.logger.Debug(ctx).Msg("configuring source connector plugin")
err = src.Configure(ctx, s.Instance.Config.Settings)
if err != nil {
_ = src.Teardown(ctx)
tdErr := src.Teardown(ctx)
err = cerrors.LogOrReplace(err, tdErr, func() {
s.Instance.logger.Err(ctx, tdErr).Msg("could not tear down source connector plugin")
})
return nil, err
}

Expand Down Expand Up @@ -93,7 +96,10 @@ func (s *Source) Open(ctx context.Context) error {
err = src.Start(streamCtx, state.Position)
if err != nil {
cancelStreamCtx()
_ = src.Teardown(ctx)
tdErr := src.Teardown(ctx)
err = cerrors.LogOrReplace(err, tdErr, func() {
s.Instance.logger.Err(ctx, tdErr).Msg("could not tear down source connector plugin")
})
return err
}

Expand Down Expand Up @@ -201,16 +207,20 @@ func (s *Source) Ack(ctx context.Context, p record.Position) error {
return err
}

// lock to prevent race condition with connector persister
// lock as we are updating the state and leave it locked so the persister
// can safely prepare the connector before it stores it
s.Instance.Lock()
defer s.Instance.Unlock()
s.Instance.State = SourceState{Position: p}
s.Instance.Unlock()

s.Instance.persister.Persist(ctx, s.Instance, func(err error) {
err = s.Instance.persister.Persist(ctx, s.Instance, func(err error) {
if err != nil {
s.errs <- err
}
})
if err != nil {
return cerrors.Errorf("failed to persist source connector: %w", err)
}

return nil
}

Expand Down
32 changes: 22 additions & 10 deletions pkg/connector/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,25 +140,37 @@ func (s *Store) migratePre041(ctx context.Context) {
}
}

// Set stores connector under the key id and returns nil on success, error
// otherwise.
func (s *Store) Set(ctx context.Context, id string, c *Instance) error {
// PrepareSet encodes the connector instance and returns a function that stores
// the connector. This can be used to prepare everything needed to store an
// instance without actually storing it yet.
func (s *Store) PrepareSet(id string, instance *Instance) (func(context.Context) error, error) {
if id == "" {
return cerrors.Errorf("can't store connector: %w", cerrors.ErrEmptyID)
return nil, cerrors.Errorf("can't store connector: %w", cerrors.ErrEmptyID)
}

raw, err := s.encode(c)
raw, err := s.encode(instance)
if err != nil {
return err
return nil, err
}
key := s.addKeyPrefix(id)

err = s.db.Set(ctx, key, raw)
return func(ctx context.Context) error {
err = s.db.Set(ctx, key, raw)
if err != nil {
return cerrors.Errorf("failed to store connector with ID %q: %w", id, err)
}
return nil
}, nil
}

// Set stores connector under the key id and returns nil on success, error
// otherwise.
func (s *Store) Set(ctx context.Context, id string, c *Instance) error {
prepared, err := s.PrepareSet(id, c)
if err != nil {
return cerrors.Errorf("failed to store connector with ID %q: %w", id, err)
return err
}

return nil
return prepared(ctx)
}

// Delete deletes connector under the key id and returns nil on success, error
Expand Down
42 changes: 39 additions & 3 deletions pkg/connector/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,43 @@ import (
"github.com/matryer/is"
)

func TestConfigStore_SetGet(t *testing.T) {
func TestStore_PrepareSet(t *testing.T) {
is := is.New(t)
ctx := context.Background()
logger := log.Nop()
db := &inmemory.DB{}

s := NewStore(db, logger)

want := &Instance{
ID: uuid.NewString(),
Type: TypeSource,
State: SourceState{
Position: []byte(uuid.NewString()),
},
CreatedAt: time.Now().UTC(),
}

// prepare only prepares the connector for storage, but doesn't store it yet
set, err := s.PrepareSet(want.ID, want)
is.NoErr(err)

// at this point the store should still be empty
got, err := s.Get(ctx, want.ID)
is.True(cerrors.Is(err, database.ErrKeyNotExist)) // expected error for non-existing key
is.True(got == nil)

// now we actually store the connector
err = set(ctx)
is.NoErr(err)

// get should return the connector now
got, err = s.Get(ctx, want.ID)
is.NoErr(err)
is.Equal(want, got)
}

func TestStore_SetGet(t *testing.T) {
is := is.New(t)
ctx := context.Background()
logger := log.Nop()
Expand All @@ -53,7 +89,7 @@ func TestConfigStore_SetGet(t *testing.T) {
is.Equal(want, got)
}

func TestConfigStore_GetAll(t *testing.T) {
func TestStore_GetAll(t *testing.T) {
is := is.New(t)
ctx := context.Background()
logger := log.Nop()
Expand Down Expand Up @@ -88,7 +124,7 @@ func TestConfigStore_GetAll(t *testing.T) {
is.Equal(want, got)
}

func TestConfigStore_Delete(t *testing.T) {
func TestStore_Delete(t *testing.T) {
is := is.New(t)
ctx := context.Background()
logger := log.Nop()
Expand Down

0 comments on commit c8eabf7

Please sign in to comment.