diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a5bcd30b..4c496bef 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,7 +59,7 @@ jobs: - name: Test run: | set -euo pipefail - go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all + go test -count=1 -race -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all shell: bash env: POSTGRES_HOST: localhost diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index dea3efe1..f3132ae8 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -84,7 +84,11 @@ func (ps *PubSub) Notify(chanName string, p Payload) error { return fmt.Errorf("notify with payload %v timed out", p.Type()) } if ps.bufferSize == 0 { + // for some reason go test -race flags this as racing with calls + // to close(ch), despite the fact that it _should_ be thread-safe :S + ps.mu.Lock() ch <- &emptyPayload{} + ps.mu.Unlock() } return nil } diff --git a/state/accumulator_test.go b/state/accumulator_test.go index db358830..5d95b0b4 100644 --- a/state/accumulator_test.go +++ b/state/accumulator_test.go @@ -4,12 +4,14 @@ import ( "context" "encoding/json" "fmt" - "github.com/matrix-org/sliding-sync/testutils" "reflect" "sort" "sync" + "sync/atomic" "testing" + "github.com/matrix-org/sliding-sync/testutils" + "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/sqlutil" "github.com/matrix-org/sliding-sync/sync2" @@ -680,7 +682,7 @@ func TestAccumulatorConcurrency(t *testing.T) { []byte(`{"event_id":"con_4", "type":"m.room.name", "state_key":"", "content":{"name":"4"}}`), []byte(`{"event_id":"con_5", "type":"m.room.name", "state_key":"", "content":{"name":"5"}}`), } - totalNumNew := 0 + var totalNumNew atomic.Int64 var wg sync.WaitGroup wg.Add(len(newEvents)) for i := 0; i < len(newEvents); i++ { @@ -689,7 +691,7 @@ func TestAccumulatorConcurrency(t *testing.T) { subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { result, err := accumulator.Accumulate(txn, userID, roomID, sync2.TimelineResponse{Events: subset}) - totalNumNew += result.NumNew + totalNumNew.Add(int64(result.NumNew)) return err }) if err != nil { @@ -698,8 +700,8 @@ func TestAccumulatorConcurrency(t *testing.T) { }(i) } wg.Wait() // wait for all goroutines to finish - if totalNumNew != len(newEvents) { - t.Errorf("got %d total new events, want %d", totalNumNew, len(newEvents)) + if int(totalNumNew.Load()) != len(newEvents) { + t.Errorf("got %d total new events, want %d", totalNumNew.Load(), len(newEvents)) } // check that the name of the room is "5" snapshot := currentSnapshotNIDs(t, accumulator.snapshotTable, roomID)