diff --git a/components/ledger/internal/storage/ledger/transactions.go b/components/ledger/internal/storage/ledger/transactions.go index 347e6afe4e..497982c2fc 100644 --- a/components/ledger/internal/storage/ledger/transactions.go +++ b/components/ledger/internal/storage/ledger/transactions.go @@ -549,6 +549,9 @@ func (s *Store) RevertTransaction(ctx context.Context, id int) (*ledger.Transact } reverted = rowsAffected > 0 + if !reverted { + return nil, nil + } return pointer.For(ret.toCore()), nil }) diff --git a/components/ledger/pkg/testserver/api.go b/components/ledger/pkg/testserver/api.go index 35b6b5fc28..423f1a64a1 100644 --- a/components/ledger/pkg/testserver/api.go +++ b/components/ledger/pkg/testserver/api.go @@ -55,6 +55,16 @@ func ListTransactions(ctx context.Context, srv *Server, request operations.V2Lis return &response.V2TransactionsCursorResponse.Cursor, nil } +func ListLedgers(ctx context.Context, srv *Server, request operations.V2ListLedgersRequest) (*components.Cursor, error) { + response, err := srv.Client().Ledger.V2.ListLedgers(ctx, request) + + if err != nil { + return nil, mapSDKError(err) + } + + return &response.V2LedgerListResponse.Cursor, nil +} + func GetAggregatedBalances(ctx context.Context, srv *Server, request operations.V2GetBalancesAggregatedRequest) (map[string]*big.Int, error) { response, err := srv.Client().Ledger.V2.GetBalancesAggregated(ctx, request) diff --git a/components/ledger/test/e2e/stress_test.go b/components/ledger/test/e2e/stress_test.go index 9846c6a5fd..1048732307 100644 --- a/components/ledger/test/e2e/stress_test.go +++ b/components/ledger/test/e2e/stress_test.go @@ -3,6 +3,7 @@ package test_suite import ( + "context" "fmt" "github.com/alitto/pond" ledger "github.com/formancehq/ledger/internal" @@ -14,8 +15,11 @@ import ( "github.com/formancehq/stack/libs/go-libs/testing/platform/pgtesting" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/onsi/gomega/types" "math/big" "math/rand" + "sync" + "sync/atomic" ) var _ = Context("Ledger stress tests", func() { @@ -55,14 +59,20 @@ var _ = Context("Ledger stress tests", func() { } }) When(fmt.Sprintf("creating %d transactions across the same account pool", countTransactions), func() { + var ( + createdTransactions map[string][]*big.Int + mu sync.Mutex + ) BeforeEach(func() { + createdTransactions = map[string][]*big.Int{} wp := pond.New(80, 80) for range countTransactions { wp.Submit(func() { defer GinkgoRecover() - _, err := CreateTransaction(ctx, testServer.GetValue(), operations.V2CreateTransactionRequest{ - Ledger: fmt.Sprintf("ledger%d", rand.Intn(countLedgers)), + ledger := fmt.Sprintf("ledger%d", rand.Intn(countLedgers)) + createdTx, err := CreateTransaction(ctx, testServer.GetValue(), operations.V2CreateTransactionRequest{ + Ledger: ledger, V2PostTransaction: components.V2PostTransaction{ // todo: add another postings Postings: []components.V2Posting{{ @@ -75,6 +85,12 @@ var _ = Context("Ledger stress tests", func() { Force: pointer.For(true), }) Expect(err).ShouldNot(HaveOccurred()) + mu.Lock() + if createdTransactions[ledger] == nil { + createdTransactions[ledger] = []*big.Int{} + } + createdTransactions[ledger] = append(createdTransactions[ledger], createdTx.ID) + mu.Unlock() }) go func() { @@ -84,29 +100,101 @@ var _ = Context("Ledger stress tests", func() { }) When("getting aggregated volumes with no parameters", func() { It("should be zero", func() { - for i := range countLedgers { - ledger := fmt.Sprintf("ledger%d", i) - By("checking ledger "+ledger, func() { - aggregatedBalances, err := GetAggregatedBalances(ctx, testServer.GetValue(), operations.V2GetBalancesAggregatedRequest{ - Ledger: ledger, - UseInsertionDate: pointer.For(true), - }) - Expect(err).To(BeNil()) - if len(aggregatedBalances) == 0 { // it's random, a ledger could not have been targeted - // just in case, check if the ledger has transactions - txs, err := ListTransactions(ctx, testServer.GetValue(), operations.V2ListTransactionsRequest{ - Ledger: ledger, + Expect(testServer.GetValue()).To(HaveCoherentState()) + }) + }) + When("trying to revert concurrently all transactions", func() { + It("should be handled correctly", func() { + const ( + duplicates = 3 + ) + var ( + success atomic.Int64 + failures atomic.Int64 + ) + wp := pond.New(80, 80) + for ledger, ids := range createdTransactions { + for _, id := range ids { + for range duplicates + 1 { + wp.Submit(func() { + defer GinkgoRecover() + + _, err := RevertTransaction(ctx, testServer.GetValue(), operations.V2RevertTransactionRequest{ + Ledger: ledger, + ID: id, + Force: pointer.For(true), + }) + if err == nil { + success.Add(1) + } else { + failures.Add(1) + } }) - Expect(err).To(BeNil()) - Expect(txs.Data).To(HaveLen(0)) - } else { - Expect(aggregatedBalances).To(HaveLen(1)) - Expect(aggregatedBalances["USD"]).To(Equal(big.NewInt(0))) } - }) + } } + wp.StopAndWait() + By("we should have the correct amount of success/failures", func() { + Expect(success.Load()).To(Equal(int64(countTransactions))) + Expect(failures.Load()).To(Equal(int64(duplicates * countTransactions))) + }) + By("we should still have the aggregated balances to 0", func() { + Expect(testServer.GetValue()).To(HaveCoherentState()) + }) }) }) }) }) }) + +type HaveCoherentStateMatcher struct{} + +func (h HaveCoherentStateMatcher) Match(actual interface{}) (success bool, err error) { + srv, ok := actual.(*Server) + if !ok { + return false, fmt.Errorf("expect type %T", new(Server)) + } + ctx := context.Background() + + ledgers, err := ListLedgers(ctx, srv, operations.V2ListLedgersRequest{ + PageSize: pointer.For(int64(100)), + }) + if err != nil { + return false, err + } + + for _, ledger := range ledgers.Data { + aggregatedBalances, err := GetAggregatedBalances(ctx, srv, operations.V2GetBalancesAggregatedRequest{ + Ledger: ledger.Name, + UseInsertionDate: pointer.For(true), + }) + Expect(err).To(BeNil()) + if len(aggregatedBalances) == 0 { // it's random, a ledger could not have been targeted + // just in case, check if the ledger has transactions + txs, err := ListTransactions(ctx, srv, operations.V2ListTransactionsRequest{ + Ledger: ledger.Name, + }) + Expect(err).To(BeNil()) + Expect(txs.Data).To(HaveLen(0)) + } else { + Expect(aggregatedBalances).To(HaveLen(1)) + Expect(aggregatedBalances["USD"]).To(Equal(big.NewInt(0))) + } + } + + return true, nil +} + +func (h HaveCoherentStateMatcher) FailureMessage(_ interface{}) (message string) { + return fmt.Sprintf("server should has coherent state") +} + +func (h HaveCoherentStateMatcher) NegatedFailureMessage(_ interface{}) (message string) { + return fmt.Sprintf("server should not has coherent state but has") +} + +var _ types.GomegaMatcher = (*HaveCoherentStateMatcher)(nil) + +func HaveCoherentState() *HaveCoherentStateMatcher { + return &HaveCoherentStateMatcher{} +}