diff --git a/components/payments/go.mod b/components/payments/go.mod index fd9733658f..7dc919bbe3 100644 --- a/components/payments/go.mod +++ b/components/payments/go.mod @@ -24,6 +24,7 @@ require ( github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.9.0 github.com/uptrace/bun v1.2.1 + github.com/uptrace/bun/dialect/pgdialect v1.2.1 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 go.opentelemetry.io/otel v1.28.0 go.opentelemetry.io/otel/trace v1.28.0 @@ -36,8 +37,12 @@ require ( ) require ( + dario.cat/mergo v1.0.1 // indirect filippo.io/edwards25519 v1.1.0 // indirect + github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/IBM/sarama v1.43.3 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/ThreeDotsLabs/watermill-http/v2 v2.3.0 // indirect github.com/ThreeDotsLabs/watermill-kafka/v3 v3.0.1 // indirect github.com/ThreeDotsLabs/watermill-nats/v2 v2.0.2 // indirect @@ -57,9 +62,14 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.30.4 // indirect github.com/aws/smithy-go v1.20.4 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/containerd/continuity v0.4.3 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect github.com/dnwe/otelsarama v0.0.0-20231212173111-631a0a53d5d4 // indirect + github.com/docker/cli v26.1.4+incompatible // indirect + github.com/docker/docker v27.2.1+incompatible // indirect + github.com/docker/go-connections v0.5.0 // indirect + github.com/docker/go-units v0.5.0 // indirect github.com/eapache/go-resiliency v1.7.0 // indirect github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 // indirect github.com/eapache/queue v1.1.0 // indirect @@ -82,6 +92,7 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/gorilla/mux v1.8.0 // indirect github.com/gorilla/schema v1.4.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect @@ -114,6 +125,9 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/go-testing-interface v0.0.0-20171004221916-a61a99592b77 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/moby/term v0.5.0 // indirect github.com/muhlemmer/gu v0.3.1 // indirect github.com/muhlemmer/httpforwarded v0.1.0 // indirect github.com/nats-io/nats.go v1.37.0 // indirect @@ -121,6 +135,10 @@ require ( github.com/nats-io/nuid v1.0.1 // indirect github.com/oklog/run v1.0.0 // indirect github.com/oklog/ulid v1.3.1 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.1.0 // indirect + github.com/opencontainers/runc v1.1.14 // indirect + github.com/ory/dockertest/v3 v3.11.0 // indirect github.com/pborman/uuid v1.2.1 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -136,7 +154,6 @@ require ( github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect - github.com/uptrace/bun/dialect/pgdialect v1.2.1 // indirect github.com/uptrace/bun/extra/bunotel v1.2.1 // indirect github.com/uptrace/opentelemetry-go-extra/otellogrus v0.3.1 // indirect github.com/uptrace/opentelemetry-go-extra/otelsql v0.2.4 // indirect @@ -146,6 +163,9 @@ require ( github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/xeipuuv/gojsonschema v1.2.0 // indirect github.com/xo/dburl v0.23.2 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/zitadel/oidc/v2 v2.12.0 // indirect @@ -180,6 +200,7 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/components/payments/go.sum b/components/payments/go.sum index 02d467a767..f31a7c8b2a 100644 --- a/components/payments/go.sum +++ b/components/payments/go.sum @@ -472,6 +472,8 @@ github.com/containerd/continuity v0.4.3 h1:6HVkalIp+2u1ZLH1J/pYX2oBVXlJZvh1X1A7b github.com/containerd/continuity v0.4.3/go.mod h1:F6PTNCKepoxEaXLQp3wDAjygEnImnZ/7o4JzpodfroQ= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= +github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -912,6 +914,7 @@ github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= @@ -1598,6 +1601,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/components/payments/internal/connectors/engine/activities/storage_balances_delete.go b/components/payments/internal/connectors/engine/activities/storage_balances_delete.go index 5d2158326a..5ebede20e3 100644 --- a/components/payments/internal/connectors/engine/activities/storage_balances_delete.go +++ b/components/payments/internal/connectors/engine/activities/storage_balances_delete.go @@ -8,7 +8,7 @@ import ( ) func (a Activities) StorageBalancesDelete(ctx context.Context, connectorID models.ConnectorID) error { - return a.storage.BalancesDeleteForConnectorID(ctx, connectorID) + return a.storage.BalancesDeleteFromConnectorID(ctx, connectorID) } var StorageBalancesDeleteActivity = Activities{}.StorageBalancesDelete diff --git a/components/payments/internal/storage/accounts.go b/components/payments/internal/storage/accounts.go index d0ca97f540..ea0174cd5a 100644 --- a/components/payments/internal/storage/accounts.go +++ b/components/payments/internal/storage/accounts.go @@ -161,7 +161,7 @@ func fromAccountModels(from models.Account) account { return account{ ID: from.ID, ConnectorID: from.ConnectorID, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), Reference: from.Reference, Type: string(from.Type), DefaultAsset: from.DefaultAsset, @@ -176,7 +176,7 @@ func toAccountModels(from account) models.Account { ID: from.ID, ConnectorID: from.ConnectorID, Reference: from.Reference, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), Type: models.AccountType(from.Type), Name: from.Name, DefaultAsset: from.DefaultAsset, diff --git a/components/payments/internal/storage/accounts_test.go b/components/payments/internal/storage/accounts_test.go new file mode 100644 index 0000000000..fadcbc665a --- /dev/null +++ b/components/payments/internal/storage/accounts_test.go @@ -0,0 +1,526 @@ +package storage + +import ( + "context" + "testing" + + "github.com/formancehq/go-libs/bun/bunpaginate" + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/go-libs/pointer" + "github.com/formancehq/go-libs/query" + "github.com/formancehq/go-libs/time" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + defaultAccounts = []models.Account{ + { + ID: models.AccountID{ + Reference: "test1", + ConnectorID: defaultConnector.ID, + }, + ConnectorID: defaultConnector.ID, + Reference: "test1", + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Type: models.ACCOUNT_TYPE_INTERNAL, + Name: pointer.For("test1"), + DefaultAsset: pointer.For("USD/2"), + Metadata: map[string]string{ + "foo": "bar", + }, + Raw: []byte(`{}`), + }, + { + ID: models.AccountID{ + Reference: "test2", + ConnectorID: defaultConnector.ID, + }, + ConnectorID: defaultConnector.ID, + Reference: "test2", + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + Type: models.ACCOUNT_TYPE_INTERNAL, + Metadata: map[string]string{ + "foo2": "bar2", + }, + Raw: []byte(`{}`), + }, + { + ID: models.AccountID{ + Reference: "test3", + ConnectorID: defaultConnector.ID, + }, + ConnectorID: defaultConnector.ID, + Reference: "test3", + CreatedAt: now.Add(-45 * time.Minute).UTC().Time, + Type: models.ACCOUNT_TYPE_EXTERNAL, + Name: pointer.For("test3"), + Metadata: map[string]string{ + "foo3": "bar3", + }, + Raw: []byte(`{}`), + }, + } + + defaultAccounts2 = []models.Account{ + { + ID: models.AccountID{ + Reference: "test1", + ConnectorID: defaultConnector2.ID, + }, + ConnectorID: defaultConnector2.ID, + Reference: "test1", + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + Type: models.ACCOUNT_TYPE_INTERNAL, + Name: pointer.For("test1"), + DefaultAsset: pointer.For("USD/2"), + Metadata: map[string]string{ + "foo5": "bar5", + }, + Raw: []byte(`{}`), + }, + } +) + +func upsertAccounts(t *testing.T, ctx context.Context, storage Storage, accounts []models.Account) { + require.NoError(t, storage.AccountsUpsert(ctx, accounts)) +} + +func TestAccountsUpsert(t *testing.T) { + t.Parallel() + + now := time.Now() + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + + t.Run("same id insert", func(t *testing.T) { + id := models.AccountID{ + Reference: "test1", + ConnectorID: defaultConnector.ID, + } + + // Same account I but different fields + acc := models.Account{ + ID: id, + ConnectorID: defaultConnector.ID, + Reference: "test1", + CreatedAt: now.Add(-12 * time.Minute).UTC().Time, + Type: models.ACCOUNT_TYPE_EXTERNAL, + Name: pointer.For("changed"), + DefaultAsset: pointer.For("EUR"), + Metadata: map[string]string{ + "foo4": "bar4", + }, + Raw: []byte(`{}`), + } + + require.NoError(t, store.AccountsUpsert(ctx, []models.Account{acc})) + + // Check that account was not updated + account, err := store.AccountsGet(ctx, id) + require.NoError(t, err) + + // Accounts should not have changed + require.Equal(t, defaultAccounts[0], *account) + }) + + t.Run("unknown connector id", func(t *testing.T) { + unknownConnectorID := models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + + acc := models.Account{ + ID: models.AccountID{ + Reference: "test_unknown", + ConnectorID: unknownConnectorID, + }, + ConnectorID: unknownConnectorID, + Reference: "test_unknown", + CreatedAt: now.Add(-12 * time.Minute).UTC().Time, + Type: models.ACCOUNT_TYPE_EXTERNAL, + Name: pointer.For("changed"), + DefaultAsset: pointer.For("EUR"), + Metadata: map[string]string{ + "foo4": "bar4", + }, + Raw: []byte(`{}`), + } + + require.Error(t, store.AccountsUpsert(ctx, []models.Account{acc})) + }) +} + +func TestAccountsGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + + t.Run("get account", func(t *testing.T) { + for _, acc := range defaultAccounts { + account, err := store.AccountsGet(ctx, acc.ID) + require.NoError(t, err) + require.Equal(t, acc, *account) + } + }) + + t.Run("get unknown account", func(t *testing.T) { + acc := models.AccountID{ + Reference: "unknown", + ConnectorID: defaultConnector.ID, + } + + account, err := store.AccountsGet(ctx, acc) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + require.Nil(t, account) + }) +} + +func TestAccountsDelete(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertConnector(t, ctx, store, defaultConnector2) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertAccounts(t, ctx, store, defaultAccounts2) + + t.Run("delete account from unknown connector", func(t *testing.T) { + unknownConnectorID := models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + + require.NoError(t, store.AccountsDeleteFromConnectorID(ctx, unknownConnectorID)) + + for _, acc := range defaultAccounts { + account, err := store.AccountsGet(ctx, acc.ID) + require.NoError(t, err) + require.Equal(t, acc, *account) + } + + for _, acc := range defaultAccounts2 { + account, err := store.AccountsGet(ctx, acc.ID) + require.NoError(t, err) + require.Equal(t, acc, *account) + } + }) + + t.Run("delete account from default connector", func(t *testing.T) { + require.NoError(t, store.AccountsDeleteFromConnectorID(ctx, defaultConnector.ID)) + + for _, acc := range defaultAccounts { + account, err := store.AccountsGet(ctx, acc.ID) + require.Error(t, err) + require.Nil(t, account) + require.ErrorIs(t, err, ErrNotFound) + } + + for _, acc := range defaultAccounts2 { + account, err := store.AccountsGet(ctx, acc.ID) + require.NoError(t, err) + require.Equal(t, acc, *account) + } + }) + +} + +func TestAccountsList(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertConnector(t, ctx, store, defaultConnector2) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertAccounts(t, ctx, store, defaultAccounts2) + + t.Run("list accounts by reference", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("reference", "test1")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + require.Equal(t, defaultAccounts2[0], cursor.Data[0]) + require.Equal(t, defaultAccounts[0], cursor.Data[1]) + }) + + t.Run("list accounts by reference 2", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("reference", "test2")), + ) + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Equal(t, defaultAccounts[1], cursor.Data[0]) + }) + + t.Run("list accounts by unknown reference", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("reference", "unknown")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list accounts by connector id", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("connector_id", defaultConnector.ID)), + ) + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 3) + require.False(t, cursor.HasMore) + require.Equal(t, defaultAccounts[1], cursor.Data[0]) + require.Equal(t, defaultAccounts[2], cursor.Data[1]) + require.Equal(t, defaultAccounts[0], cursor.Data[2]) + }) + + t.Run("list accounts by connector id 2", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("connector_id", defaultConnector2.ID)), + ) + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Equal(t, defaultAccounts2[0], cursor.Data[0]) + }) + + t.Run("list accounts by unknown connector id", func(t *testing.T) { + unknownConnectorID := models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("connector_id", unknownConnectorID)), + ) + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list accounts by type", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("type", models.ACCOUNT_TYPE_INTERNAL)), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 3) + require.False(t, cursor.HasMore) + require.Equal(t, defaultAccounts[1], cursor.Data[0]) + require.Equal(t, defaultAccounts2[0], cursor.Data[1]) + require.Equal(t, defaultAccounts[0], cursor.Data[2]) + }) + + t.Run("list accounts by unknown type", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("type", "unknown")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list accounts by default asset", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("default_asset", "USD/2")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + require.Equal(t, defaultAccounts2[0], cursor.Data[0]) + require.Equal(t, defaultAccounts[0], cursor.Data[1]) + }) + + t.Run("list accounts by unknown default asset", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("default_asset", "unknown")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list accounts by name", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "test1")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + require.Equal(t, defaultAccounts2[0], cursor.Data[0]) + require.Equal(t, defaultAccounts[0], cursor.Data[1]) + }) + + t.Run("list accounts by name 2", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "test3")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Equal(t, defaultAccounts[2], cursor.Data[0]) + }) + + t.Run("list accounts by unknown name", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "unknown")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list accounts by metadata", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("metadata[foo]", "bar")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Equal(t, defaultAccounts[0], cursor.Data[0]) + }) + + t.Run("list accounts by unknown metadata", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("metadata[foo]", "unknown")), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list accounts test cursor", func(t *testing.T) { + q := NewListAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(AccountQuery{}). + WithPageSize(1), + ) + + cursor, err := store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Next) + require.Empty(t, cursor.Previous) + require.Equal(t, defaultAccounts[1], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Next) + require.NotEmpty(t, cursor.Previous) + require.Equal(t, defaultAccounts[2], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Next) + require.NotEmpty(t, cursor.Previous) + require.Equal(t, defaultAccounts2[0], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Next) + require.NotEmpty(t, cursor.Previous) + require.Equal(t, defaultAccounts[0], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Next) + require.NotEmpty(t, cursor.Previous) + require.Equal(t, defaultAccounts2[0], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.AccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Next) + require.NotEmpty(t, cursor.Previous) + require.Equal(t, defaultAccounts[2], cursor.Data[0]) + }) +} diff --git a/components/payments/internal/storage/balances.go b/components/payments/internal/storage/balances.go index e21e0fc767..7a9f619ee5 100644 --- a/components/payments/internal/storage/balances.go +++ b/components/payments/internal/storage/balances.go @@ -74,7 +74,7 @@ func (s *store) BalancesUpsert(ctx context.Context, balances []models.Balance) e return e("failed to commit transaction", tx.Commit()) } -func (s *store) BalancesDeleteForConnectorID(ctx context.Context, connectorID models.ConnectorID) error { +func (s *store) BalancesDeleteFromConnectorID(ctx context.Context, connectorID models.ConnectorID) error { _, err := s.db.NewDelete(). Model((*balance)(nil)). Where("connector_id = ?", connectorID). @@ -155,6 +155,7 @@ func (s *store) BalancesList(ctx context.Context, q ListBalancesQuery) (*bunpagi cursor, err := paginateWithOffset[bunpaginate.PaginatedQueryOptions[BalanceQuery], balance](s, ctx, (*bunpaginate.OffsetPaginatedQuery[bunpaginate.PaginatedQueryOptions[BalanceQuery]])(&q), func(query *bun.SelectQuery) *bun.SelectQuery { + query = applyBalanceQuery(query, q.Options.Options) query = query.Order("created_at DESC") @@ -244,11 +245,11 @@ func fromBalancesModels(from []models.Balance) []balance { func fromBalanceModels(from models.Balance) balance { return balance{ AccountID: from.AccountID, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), Asset: from.Asset, ConnectorID: from.AccountID.ConnectorID, Balance: from.Balance, - LastUpdatedAt: from.LastUpdatedAt, + LastUpdatedAt: from.LastUpdatedAt.UTC(), } } @@ -263,9 +264,9 @@ func toBalancesModels(from []balance) []models.Balance { func toBalanceModels(from balance) models.Balance { return models.Balance{ AccountID: from.AccountID, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), Asset: from.Asset, Balance: from.Balance, - LastUpdatedAt: from.LastUpdatedAt, + LastUpdatedAt: from.LastUpdatedAt.UTC(), } } diff --git a/components/payments/internal/storage/balances_test.go b/components/payments/internal/storage/balances_test.go new file mode 100644 index 0000000000..687f83f36f --- /dev/null +++ b/components/payments/internal/storage/balances_test.go @@ -0,0 +1,480 @@ +package storage + +import ( + "context" + "math/big" + "testing" + + "github.com/formancehq/go-libs/bun/bunpaginate" + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/go-libs/pointer" + "github.com/formancehq/go-libs/time" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + defaultBalances = []models.Balance{ + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-60 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + { + AccountID: defaultAccounts[1].ID, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-30 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(1000), + }, + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-55 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(150), + }, + } + + defaultBalances2 = []models.Balance{ + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-59 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-59 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-31 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-31 * time.Minute).UTC().Time, + Asset: "DKK/2", + Balance: big.NewInt(1000), + }, + } +) + +func upsertBalances(t *testing.T, ctx context.Context, storage Storage, balances []models.Balance) { + require.NoError(t, storage.BalancesUpsert(ctx, balances)) +} + +func TestBalancesUpsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertBalances(t, ctx, store, defaultBalances) + upsertBalances(t, ctx, store, defaultBalances2) + + t.Run("insert balances with same asset and same balance", func(t *testing.T) { + b := models.Balance{ + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-20 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-20 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + } + + upsertBalances(t, ctx, store, []models.Balance{b}) + + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + AccountID: pointer.For(defaultAccounts[2].ID), + Asset: "USD/2", + }).WithPageSize(15), + ) + + expectedBalances := []models.Balance{ + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-59 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-20 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + } + + balances, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, balances.Data, 1) + require.Equal(t, expectedBalances, balances.Data) + }) + + t.Run("insert balances same asset different balance", func(t *testing.T) { + b := models.Balance{ + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-20 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-20 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(200), + } + + upsertBalances(t, ctx, store, []models.Balance{b}) + + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + AccountID: pointer.For(defaultAccounts[0].ID), + Asset: "USD/2", + }).WithPageSize(15), + ) + + expectedBalances := []models.Balance{ + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-20 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-20 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(200), + }, + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-20 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + } + + balances, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, balances.Data, 2) + require.Equal(t, expectedBalances, balances.Data) + }) + + t.Run("insert balances with new asset", func(t *testing.T) { + b := models.Balance{ + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-10 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-10 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(200), + } + + upsertBalances(t, ctx, store, []models.Balance{b}) + + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + AccountID: pointer.For(defaultAccounts[2].ID), + }).WithPageSize(15), + ) + + expectedBalances := []models.Balance{ + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-10 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-10 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(200), + }, + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-31 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-31 * time.Minute).UTC().Time, + Asset: "DKK/2", + Balance: big.NewInt(1000), + }, + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-59 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-20 * time.Minute).UTC().Time, // Because on the first function it was modified + Asset: "USD/2", + Balance: big.NewInt(100), + }, + } + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 3) + require.Equal(t, expectedBalances, cursor.Data) + }) +} + +func TestBalancesDeleteFromConnectorID(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertBalances(t, ctx, store, defaultBalances) + upsertBalances(t, ctx, store, defaultBalances2) + + t.Run("delete balances from unknown connector id", func(t *testing.T) { + err := store.BalancesDeleteFromConnectorID(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + }) + require.NoError(t, err) + }) + + t.Run("delete balances from known connector id", func(t *testing.T) { + err := store.BalancesDeleteFromConnectorID(ctx, defaultConnector.ID) + require.NoError(t, err) + + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + AccountID: pointer.For(defaultAccounts[0].ID), + }).WithPageSize(15), + ) + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + }) +} + +func TestBalancesList(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertBalances(t, ctx, store, defaultBalances) + upsertBalances(t, ctx, store, defaultBalances2) + + t.Run("list balances with account id", func(t *testing.T) { + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + AccountID: pointer.For(defaultAccounts[0].ID), + }).WithPageSize(15), + ) + + expectedBalances := []models.Balance{ + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-55 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(150), + }, + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-60 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + } + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + require.Equal(t, expectedBalances, cursor.Data) + }) + + t.Run("list balances with asset 1", func(t *testing.T) { + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + Asset: "USD/2", + }).WithPageSize(15), + ) + + expectedBalances := []models.Balance{ + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-59 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-59 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-60 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + } + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + require.Equal(t, expectedBalances, cursor.Data) + }) + + t.Run("list balances with asset 2", func(t *testing.T) { + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + Asset: "DKK/2", + }).WithPageSize(15), + ) + + expectedBalances := []models.Balance{ + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-31 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-31 * time.Minute).UTC().Time, + Asset: "DKK/2", + Balance: big.NewInt(1000), + }, + } + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Equal(t, expectedBalances, cursor.Data) + }) + + t.Run("list balances with from", func(t *testing.T) { + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + From: now.Add(-40 * time.Minute).UTC().Time, + }).WithPageSize(15), + ) + + expectedBalances := []models.Balance{ + { + AccountID: defaultAccounts[1].ID, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-30 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(1000), + }, + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-31 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-31 * time.Minute).UTC().Time, + Asset: "DKK/2", + Balance: big.NewInt(1000), + }, + } + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + require.Equal(t, expectedBalances, cursor.Data) + }) + + t.Run("list balances with from 2", func(t *testing.T) { + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + From: now.Add(-20 * time.Minute).UTC().Time, + }).WithPageSize(15), + ) + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list balances with to", func(t *testing.T) { + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + To: now.Add(-40 * time.Minute).UTC().Time, + }).WithPageSize(15), + ) + + expectedBalances := []models.Balance{ + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-55 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(150), + }, + { + AccountID: defaultAccounts[2].ID, + CreatedAt: now.Add(-59 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-59 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-60 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + } + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 3) + require.False(t, cursor.HasMore) + require.Equal(t, expectedBalances, cursor.Data) + }) + + t.Run("list balances with to 2", func(t *testing.T) { + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + To: now.Add(-70 * time.Minute).UTC().Time, + }).WithPageSize(15), + ) + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list balances test cursor", func(t *testing.T) { + q := NewListBalancesQuery( + bunpaginate.NewPaginatedQueryOptions(BalanceQuery{ + AccountID: pointer.For(defaultAccounts[0].ID), + }).WithPageSize(1), + ) + + expectedBalances1 := []models.Balance{ + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-55 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(150), + }, + } + + cursor, err := store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, expectedBalances1, cursor.Data) + + expectedBalances2 := []models.Balance{ + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-60 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + } + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Next) + require.NotEmpty(t, cursor.Previous) + require.Equal(t, expectedBalances2, cursor.Data) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.BalancesList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, expectedBalances1, cursor.Data) + }) +} diff --git a/components/payments/internal/storage/bank_accounts.go b/components/payments/internal/storage/bank_accounts.go index 58a2b77b56..40237c8906 100644 --- a/components/payments/internal/storage/bank_accounts.go +++ b/components/payments/internal/storage/bank_accounts.go @@ -38,30 +38,48 @@ type bankAccount struct { RelatedAccounts []*bankAccountRelatedAccount `bun:"rel:has-many,join:id=bank_account_id"` } -func (s *store) BankAccountsUpsert(ctx context.Context, bankAccount models.BankAccount) error { +func (s *store) BankAccountsUpsert(ctx context.Context, ba models.BankAccount) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return e("begin transaction", err) } defer tx.Rollback() - toInsert := fromBankAccountModels(bankAccount) + toInsert := fromBankAccountModels(ba) // Insert or update the bank account - _, err = tx.NewInsert(). + var idsUpdated []uuid.UUID + err = tx.NewInsert(). Model(&toInsert). + Column("id", "created_at", "name", "country", "metadata"). On("CONFLICT (id) DO NOTHING"). - Exec(ctx) + Returning("id"). + Scan(ctx, &idsUpdated) if err != nil { return e("insert bank account", err) } - // Insert or update the related accounts - _, err = tx.NewInsert(). - Model(&toInsert.RelatedAccounts). - On("CONFLICT (bank_account_id, account_id) DO NOTHING"). - Exec(ctx) - if err != nil { - return e("insert related accounts", err) + if len(idsUpdated) > 0 { + _, err = tx.NewUpdate(). + Model((*bankAccount)(nil)). + Set("account_number = pgp_sym_encrypt(?::TEXT, ?, ?)", toInsert.AccountNumber, s.configEncryptionKey, encryptionOptions). + Set("iban = pgp_sym_encrypt(?::TEXT, ?, ?)", toInsert.IBAN, s.configEncryptionKey, encryptionOptions). + Set("swift_bic_code = pgp_sym_encrypt(?::TEXT, ?, ?)", toInsert.SwiftBicCode, s.configEncryptionKey, encryptionOptions). + Where("id = ?", toInsert.ID). + Exec(ctx) + if err != nil { + return e("update bank account", err) + } + } + + if len(toInsert.RelatedAccounts) > 0 { + // Insert or update the related accounts + _, err = tx.NewInsert(). + Model(&toInsert.RelatedAccounts). + On("CONFLICT (bank_account_id, account_id) DO NOTHING"). + Exec(ctx) + if err != nil { + return e("insert related accounts", err) + } } return e("commit transaction", tx.Commit()) @@ -108,9 +126,12 @@ func (s *store) BankAccountsGet(ctx context.Context, id uuid.UUID, expand bool) var account bankAccount query := s.db.NewSelect(). Model(&account). + Column("id", "created_at", "name", "country", "metadata"). Relation("RelatedAccounts") - if !expand { - query = query.Column("id", "created_at", "name", "country", "metadata") + if expand { + query = query.ColumnExpr("pgp_sym_decrypt(account_number, ?, ?) AS decrypted_account_number", s.configEncryptionKey, encryptionOptions). + ColumnExpr("pgp_sym_decrypt(iban, ?, ?) AS decrypted_iban", s.configEncryptionKey, encryptionOptions). + ColumnExpr("pgp_sym_decrypt(swift_bic_code, ?, ?) AS decrypted_swift_bic_code", s.configEncryptionKey, encryptionOptions) } err := query.Where("id = ?", id).Scan(ctx) if err != nil { @@ -239,7 +260,7 @@ func (s *store) BankAccountsDeleteRelatedAccountFromConnectorID(ctx context.Cont func fromBankAccountModels(from models.BankAccount) bankAccount { ba := bankAccount{ ID: from.ID, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), Name: from.Name, Country: from.Country, Metadata: from.Metadata, @@ -269,7 +290,7 @@ func fromBankAccountModels(from models.BankAccount) bankAccount { func toBankAccountModels(from bankAccount) models.BankAccount { ba := models.BankAccount{ ID: from.ID, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), Name: from.Name, Country: from.Country, Metadata: from.Metadata, @@ -301,7 +322,7 @@ func fromBankAccountRelatedAccountModels(from models.BankAccountRelatedAccount) BankAccountID: from.BankAccountID, AccountID: from.AccountID, ConnectorID: from.ConnectorID, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), } } @@ -310,6 +331,6 @@ func toBankAccountRelatedAccountModels(from bankAccountRelatedAccount) models.Ba BankAccountID: from.BankAccountID, AccountID: from.AccountID, ConnectorID: from.ConnectorID, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), } } diff --git a/components/payments/internal/storage/bank_accounts_test.go b/components/payments/internal/storage/bank_accounts_test.go new file mode 100644 index 0000000000..3224acc4bf --- /dev/null +++ b/components/payments/internal/storage/bank_accounts_test.go @@ -0,0 +1,715 @@ +package storage + +import ( + "context" + "testing" + "time" + + "github.com/formancehq/go-libs/bun/bunpaginate" + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/go-libs/pointer" + "github.com/formancehq/go-libs/query" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + defaultBankAccount = models.BankAccount{ + ID: uuid.New(), + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Name: "test1", + AccountNumber: pointer.For("12345678"), + Country: pointer.For("US"), + Metadata: map[string]string{ + "foo": "bar", + }, + } + + bcID2 = uuid.New() + defaultBankAccount2 = models.BankAccount{ + ID: bcID2, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + Name: "test2", + IBAN: pointer.For("DE89370400440532013000"), + SwiftBicCode: pointer.For("COBADEFFXXX"), + Country: pointer.For("DE"), + Metadata: map[string]string{ + "foo2": "bar2", + }, + RelatedAccounts: []models.BankAccountRelatedAccount{ + { + BankAccountID: bcID2, + AccountID: defaultAccounts[0].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + }, + }, + } + + // No metadata + defaultBankAccount3 = models.BankAccount{ + ID: uuid.New(), + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + Name: "test1", + AccountNumber: pointer.For("12345678"), + Country: pointer.For("US"), + } +) + +func upsertBankAccount(t *testing.T, ctx context.Context, storage Storage, bankAccounts models.BankAccount) { + require.NoError(t, storage.BankAccountsUpsert(ctx, bankAccounts)) +} + +func TestBankAccountsUpsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertBankAccount(t, ctx, store, defaultBankAccount) + upsertBankAccount(t, ctx, store, defaultBankAccount2) + + t.Run("upsert with same id", func(t *testing.T) { + ba := models.BankAccount{ + ID: defaultBankAccount.ID, + CreatedAt: now.UTC().Time, + Name: "changed", + AccountNumber: pointer.For("987654321"), + Country: pointer.For("CA"), + Metadata: map[string]string{ + "changed": "changed", + }, + } + + require.NoError(t, store.BankAccountsUpsert(ctx, ba)) + + actual, err := store.BankAccountsGet(ctx, ba.ID, true) + require.NoError(t, err) + // Should not update the bank account + compareBankAccounts(t, defaultBankAccount, *actual) + }) + + t.Run("unknown connector id", func(t *testing.T) { + ba := models.BankAccount{ + ID: uuid.New(), + CreatedAt: now.UTC().Time, + Name: "foo", + AccountNumber: pointer.For("12345678"), + Country: pointer.For("US"), + Metadata: map[string]string{ + "foo": "bar", + }, + RelatedAccounts: []models.BankAccountRelatedAccount{ + { + BankAccountID: uuid.New(), + AccountID: defaultAccounts[0].ID, + ConnectorID: models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + }, + CreatedAt: now.UTC().Time, + }, + }, + } + + require.Error(t, store.BankAccountsUpsert(ctx, ba)) + b, err := store.BankAccountsGet(ctx, ba.ID, true) + require.Error(t, err) + require.Nil(t, b) + }) +} + +func TestBankAccountsUpdateMetadata(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertBankAccount(t, ctx, store, defaultBankAccount) + upsertBankAccount(t, ctx, store, defaultBankAccount2) + upsertBankAccount(t, ctx, store, defaultBankAccount3) + + t.Run("update metadata", func(t *testing.T) { + metadata := map[string]string{ + "test1": "test2", + "test3": "test4", + } + + acc := defaultBankAccount + for k, v := range metadata { + acc.Metadata[k] = v + } + + require.NoError(t, store.BankAccountsUpdateMetadata(ctx, defaultBankAccount.ID, metadata)) + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount.ID, true) + require.NoError(t, err) + compareBankAccounts(t, acc, *actual) + }) + + t.Run("update same metadata", func(t *testing.T) { + metadata := map[string]string{ + "foo2": "bar3", + } + + acc := models.BankAccount{ + ID: bcID2, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + Name: "test2", + IBAN: pointer.For("DE89370400440532013000"), + SwiftBicCode: pointer.For("COBADEFFXXX"), + Country: pointer.For("DE"), + Metadata: map[string]string{ + "foo2": "bar2", + }, + RelatedAccounts: []models.BankAccountRelatedAccount{ + { + BankAccountID: bcID2, + AccountID: defaultAccounts[0].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + }, + }, + } + for k, v := range metadata { + acc.Metadata[k] = v + } + + require.NoError(t, store.BankAccountsUpdateMetadata(ctx, defaultBankAccount2.ID, metadata)) + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount2.ID, true) + require.NoError(t, err) + compareBankAccounts(t, acc, *actual) + }) + + t.Run("update metadata of bank accounts with nil map", func(t *testing.T) { + metadata := map[string]string{ + "test1": "test2", + "test3": "test4", + } + + acc := defaultBankAccount3 + acc.Metadata = make(map[string]string) + for k, v := range metadata { + acc.Metadata[k] = v + } + + require.NoError(t, store.BankAccountsUpdateMetadata(ctx, defaultBankAccount3.ID, metadata)) + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount3.ID, true) + require.NoError(t, err) + compareBankAccounts(t, acc, *actual) + }) +} + +func TestBankAccountsGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertBankAccount(t, ctx, store, defaultBankAccount) + upsertBankAccount(t, ctx, store, defaultBankAccount2) + upsertBankAccount(t, ctx, store, defaultBankAccount3) + + t.Run("get bank account without related accounts", func(t *testing.T) { + actual, err := store.BankAccountsGet(ctx, defaultBankAccount.ID, true) + require.NoError(t, err) + compareBankAccounts(t, defaultBankAccount, *actual) + }) + + t.Run("get bank account without metadata", func(t *testing.T) { + actual, err := store.BankAccountsGet(ctx, defaultBankAccount3.ID, true) + require.NoError(t, err) + compareBankAccounts(t, defaultBankAccount3, *actual) + }) + + t.Run("get bank account with related accounts", func(t *testing.T) { + actual, err := store.BankAccountsGet(ctx, defaultBankAccount2.ID, true) + require.NoError(t, err) + compareBankAccounts(t, defaultBankAccount2, *actual) + }) + + t.Run("get unknown bank account", func(t *testing.T) { + actual, err := store.BankAccountsGet(ctx, uuid.New(), true) + require.Error(t, err) + require.Nil(t, actual) + }) + + t.Run("get bank account with expand to false", func(t *testing.T) { + acc := models.BankAccount{ + ID: defaultBankAccount.ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Name: "test1", + Country: pointer.For("US"), + Metadata: map[string]string{ + "foo": "bar", + }, + } + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount.ID, false) + require.NoError(t, err) + compareBankAccounts(t, acc, *actual) + }) + + t.Run("get bank account with expand to false 2", func(t *testing.T) { + acc := models.BankAccount{ + ID: bcID2, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + Name: "test2", + Country: pointer.For("DE"), + Metadata: map[string]string{ + "foo2": "bar2", + }, + RelatedAccounts: []models.BankAccountRelatedAccount{ + { + BankAccountID: bcID2, + AccountID: defaultAccounts[0].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + }, + }, + } + + actual, err := store.BankAccountsGet(ctx, bcID2, false) + require.NoError(t, err) + compareBankAccounts(t, acc, *actual) + }) +} + +func TestBankAccountsList(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + d1 := models.BankAccount{ + ID: defaultBankAccount.ID, + CreatedAt: defaultBankAccount.CreatedAt, + Name: defaultBankAccount.Name, + Country: defaultBankAccount.Country, + Metadata: defaultBankAccount.Metadata, + } + + d2 := models.BankAccount{ + ID: defaultBankAccount2.ID, + CreatedAt: defaultBankAccount2.CreatedAt, + Name: defaultBankAccount2.Name, + Country: defaultBankAccount2.Country, + Metadata: defaultBankAccount2.Metadata, + RelatedAccounts: defaultBankAccount2.RelatedAccounts, + } + _ = d2 + + d3 := models.BankAccount{ + ID: defaultBankAccount3.ID, + CreatedAt: defaultBankAccount3.CreatedAt, + Name: defaultBankAccount3.Name, + Country: defaultBankAccount3.Country, + } + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertBankAccount(t, ctx, store, defaultBankAccount) + upsertBankAccount(t, ctx, store, defaultBankAccount2) + upsertBankAccount(t, ctx, store, defaultBankAccount3) + + t.Run("list bank accounts by name", func(t *testing.T) { + q := NewListBankAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(BankAccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "test1")), + ) + + cursor, err := store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + compareBankAccounts(t, d3, cursor.Data[0]) + compareBankAccounts(t, d1, cursor.Data[1]) + }) + + t.Run("list bank accounts by name 2", func(t *testing.T) { + q := NewListBankAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(BankAccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "test2")), + ) + + cursor, err := store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + compareBankAccounts(t, d2, cursor.Data[0]) + }) + + t.Run("list bank accounts by unknown name", func(t *testing.T) { + q := NewListBankAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(BankAccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "unknown")), + ) + + cursor, err := store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + }) + + t.Run("list bank accounts by country", func(t *testing.T) { + q := NewListBankAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(BankAccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("country", "US")), + ) + + cursor, err := store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + compareBankAccounts(t, d3, cursor.Data[0]) + compareBankAccounts(t, d1, cursor.Data[1]) + }) + + t.Run("list bank accounts by country 2", func(t *testing.T) { + q := NewListBankAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(BankAccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("country", "DE")), + ) + + cursor, err := store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + compareBankAccounts(t, d2, cursor.Data[0]) + }) + + t.Run("list bank accounts by unknown country", func(t *testing.T) { + q := NewListBankAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(BankAccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("country", "unknown")), + ) + + cursor, err := store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + }) + + t.Run("list bank accounts by metadata", func(t *testing.T) { + q := NewListBankAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(BankAccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("metadata[foo]", "bar")), + ) + + cursor, err := store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + compareBankAccounts(t, d1, cursor.Data[0]) + }) + + t.Run("list bank accounts by unknown metadata", func(t *testing.T) { + q := NewListBankAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(BankAccountQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("metadata[unknown]", "bar")), + ) + + cursor, err := store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + }) + + t.Run("list bank accounts test cursor", func(t *testing.T) { + q := NewListBankAccountsQuery( + bunpaginate.NewPaginatedQueryOptions(BankAccountQuery{}). + WithPageSize(1), + ) + + cursor, err := store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + compareBankAccounts(t, d2, cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + compareBankAccounts(t, d3, cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.Empty(t, cursor.Next) + compareBankAccounts(t, d1, cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + compareBankAccounts(t, d3, cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.BankAccountsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + compareBankAccounts(t, d2, cursor.Data[0]) + }) +} + +func TestBankAccountsAddRelatedAccount(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertBankAccount(t, ctx, store, defaultBankAccount) + upsertBankAccount(t, ctx, store, defaultBankAccount2) + upsertBankAccount(t, ctx, store, defaultBankAccount3) + + t.Run("add related account when empty", func(t *testing.T) { + acc := models.BankAccountRelatedAccount{ + BankAccountID: defaultBankAccount.ID, + AccountID: defaultAccounts[0].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.UTC().Time, + } + + ba := defaultBankAccount + ba.RelatedAccounts = append(ba.RelatedAccounts, acc) + + require.NoError(t, store.BankAccountsAddRelatedAccount(ctx, acc)) + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount.ID, true) + require.NoError(t, err) + compareBankAccounts(t, ba, *actual) + }) + + t.Run("add related account when not empty", func(t *testing.T) { + acc := models.BankAccountRelatedAccount{ + BankAccountID: defaultBankAccount2.ID, + AccountID: defaultAccounts[1].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.UTC().Time, + } + + ba := defaultBankAccount2 + ba.RelatedAccounts = append(ba.RelatedAccounts, acc) + + require.NoError(t, store.BankAccountsAddRelatedAccount(ctx, acc)) + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount2.ID, true) + require.NoError(t, err) + compareBankAccounts(t, ba, *actual) + }) + + t.Run("add related account with unknown bank account", func(t *testing.T) { + acc := models.BankAccountRelatedAccount{ + BankAccountID: uuid.New(), + AccountID: defaultAccounts[1].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.UTC().Time, + } + + require.Error(t, store.BankAccountsAddRelatedAccount(ctx, acc)) + }) + + t.Run("add related account with unknown account", func(t *testing.T) { + acc := models.BankAccountRelatedAccount{ + BankAccountID: defaultBankAccount2.ID, + AccountID: models.AccountID{ + Reference: "unknown", + ConnectorID: defaultConnector.ID, + }, + ConnectorID: defaultConnector.ID, + CreatedAt: now.UTC().Time, + } + + require.Error(t, store.BankAccountsAddRelatedAccount(ctx, acc)) + }) + + t.Run("add related account with unknown connector", func(t *testing.T) { + acc := models.BankAccountRelatedAccount{ + BankAccountID: defaultBankAccount2.ID, + AccountID: defaultAccounts[2].ID, + ConnectorID: models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + }, + CreatedAt: now.UTC().Time, + } + + require.Error(t, store.BankAccountsAddRelatedAccount(ctx, acc)) + }) + + t.Run("add related account with existing related account", func(t *testing.T) { + acc := models.BankAccountRelatedAccount{ + BankAccountID: defaultBankAccount3.ID, + AccountID: defaultAccounts[0].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + } + + ba := defaultBankAccount3 + ba.RelatedAccounts = append(ba.RelatedAccounts, acc) + + require.NoError(t, store.BankAccountsAddRelatedAccount(ctx, acc)) + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount3.ID, true) + require.NoError(t, err) + compareBankAccounts(t, ba, *actual) + + require.NoError(t, store.BankAccountsAddRelatedAccount(ctx, acc)) + + actual, err = store.BankAccountsGet(ctx, defaultBankAccount3.ID, true) + require.NoError(t, err) + compareBankAccounts(t, ba, *actual) + }) +} + +func TestBankAccountsDeleteRelatedAccountFromConnectorID(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertBankAccount(t, ctx, store, defaultBankAccount) + upsertBankAccount(t, ctx, store, defaultBankAccount2) + upsertBankAccount(t, ctx, store, defaultBankAccount3) + + t.Run("delete related account with unknown connector", func(t *testing.T) { + require.NoError(t, store.BankAccountsDeleteRelatedAccountFromConnectorID(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + })) + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount2.ID, true) + require.NoError(t, err) + compareBankAccounts(t, defaultBankAccount2, *actual) + }) + + t.Run("delete related account with another connector id", func(t *testing.T) { + require.NoError(t, store.BankAccountsDeleteRelatedAccountFromConnectorID(ctx, defaultConnector2.ID)) + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount2.ID, true) + require.NoError(t, err) + compareBankAccounts(t, defaultBankAccount2, *actual) + }) + + t.Run("delete related account", func(t *testing.T) { + require.NoError(t, store.BankAccountsDeleteRelatedAccountFromConnectorID(ctx, defaultConnector.ID)) + + ba := defaultBankAccount2 + ba.RelatedAccounts = nil + + actual, err := store.BankAccountsGet(ctx, defaultBankAccount2.ID, true) + require.NoError(t, err) + compareBankAccounts(t, ba, *actual) + }) +} + +func compareBankAccounts(t *testing.T, expected, actual models.BankAccount) { + require.Equal(t, expected.ID, actual.ID) + require.Equal(t, expected.CreatedAt, actual.CreatedAt) + require.Equal(t, expected.Name, actual.Name) + + require.Equal(t, len(expected.Metadata), len(actual.Metadata)) + for k, v := range expected.Metadata { + require.Equal(t, v, actual.Metadata[k]) + } + for k, v := range actual.Metadata { + require.Equal(t, v, expected.Metadata[k]) + } + + switch { + case expected.AccountNumber != nil && actual.AccountNumber != nil: + require.Equal(t, *expected.AccountNumber, *actual.AccountNumber) + case expected.AccountNumber == nil && actual.AccountNumber == nil: + // Nothing to do + default: + require.Fail(t, "AccountNumber mismatch") + } + + switch { + case expected.IBAN != nil && actual.IBAN != nil: + require.Equal(t, *expected.IBAN, *actual.IBAN) + case expected.IBAN == nil && actual.IBAN == nil: + // Nothing to do + default: + require.Fail(t, "IBAN mismatch") + } + + switch { + case expected.SwiftBicCode != nil && actual.SwiftBicCode != nil: + require.Equal(t, *expected.SwiftBicCode, *actual.SwiftBicCode) + case expected.SwiftBicCode == nil && actual.SwiftBicCode == nil: + // Nothing to do + default: + require.Fail(t, "SwiftBicCode mismatch") + } + + switch { + case expected.Country != nil && actual.Country != nil: + require.Equal(t, *expected.Country, *actual.Country) + case expected.Country == nil && actual.Country == nil: + // Nothing to do + default: + require.Fail(t, "Country mismatch") + } + + require.Equal(t, len(expected.RelatedAccounts), len(actual.RelatedAccounts)) + for i := range expected.RelatedAccounts { + require.Equal(t, expected.RelatedAccounts[i], actual.RelatedAccounts[i]) + } +} diff --git a/components/payments/internal/storage/connectors.go b/components/payments/internal/storage/connectors.go index ddf9c04d11..7b0e4d3b1f 100644 --- a/components/payments/internal/storage/connectors.go +++ b/components/payments/internal/storage/connectors.go @@ -40,7 +40,7 @@ func (s *store) ConnectorsInstall(ctx context.Context, c models.Connector) error toInsert := connector{ ID: c.ID, Name: c.Name, - CreatedAt: c.CreatedAt, + CreatedAt: c.CreatedAt.UTC(), Provider: c.Provider, } @@ -88,7 +88,7 @@ func (s *store) ConnectorsGet(ctx context.Context, id models.ConnectorID) (*mode return &models.Connector{ ID: connector.ID, Name: connector.Name, - CreatedAt: connector.CreatedAt, + CreatedAt: connector.CreatedAt.UTC(), Provider: connector.Provider, Config: connector.DecryptedConfig, }, nil @@ -155,7 +155,7 @@ func (s *store) ConnectorsList(ctx context.Context, q ListConnectorsQuery) (*bun connectors = append(connectors, models.Connector{ ID: c.ID, Name: c.Name, - CreatedAt: c.CreatedAt, + CreatedAt: c.CreatedAt.UTC(), Provider: c.Provider, Config: c.DecryptedConfig, }) diff --git a/components/payments/internal/storage/connectors_test.go b/components/payments/internal/storage/connectors_test.go new file mode 100644 index 0000000000..4b57e50203 --- /dev/null +++ b/components/payments/internal/storage/connectors_test.go @@ -0,0 +1,282 @@ +package storage + +import ( + "context" + "testing" + + "github.com/formancehq/go-libs/bun/bunpaginate" + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/go-libs/query" + "github.com/formancehq/go-libs/time" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + now = time.Now() + defaultConnector = models.Connector{ + ID: models.ConnectorID{ + Reference: uuid.New(), + Provider: "default", + }, + Name: "default", + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Provider: "default", + Config: []byte(`{}`), + } + + defaultConnector2 = models.Connector{ + ID: models.ConnectorID{ + Reference: uuid.New(), + Provider: "default2", + }, + Name: "default2", + CreatedAt: now.Add(-45 * time.Minute).UTC().Time, + Provider: "default2", + Config: []byte(`{}`), + } + + defaultConnector3 = models.Connector{ + ID: models.ConnectorID{ + Reference: uuid.New(), + Provider: "default", + }, + Name: "default3", + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + Provider: "default", + Config: []byte(`{}`), + } +) + +func upsertConnector(t *testing.T, ctx context.Context, storage Storage, connector models.Connector) { + require.NoError(t, storage.ConnectorsInstall(ctx, connector)) +} + +func TestConnectorsInstall(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertConnector(t, ctx, store, defaultConnector2) + + t.Run("same id upsert", func(t *testing.T) { + c := models.Connector{ + ID: defaultConnector.ID, + Name: "test changed", + CreatedAt: time.Now().UTC().Time, + Provider: "test", + Config: []byte(`{}`), + } + + require.NoError(t, store.ConnectorsInstall(ctx, c)) + + connector, err := store.ConnectorsGet(ctx, c.ID) + require.NoError(t, err) + require.NotNil(t, connector) + require.Equal(t, defaultConnector, *connector) + }) + + t.Run("unique same upsert", func(t *testing.T) { + c := models.Connector{ + ID: models.ConnectorID{ + Reference: uuid.New(), + Provider: "test", + }, + Name: "default", + CreatedAt: now.Add(-23 * time.Minute).UTC().Time, + Provider: "test", + Config: []byte(`{}`), + } + + require.Error(t, store.ConnectorsInstall(ctx, c)) + }) +} + +func TestConnectorsUninstall(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertConnector(t, ctx, store, defaultConnector2) + + t.Run("uninstall default connector", func(t *testing.T) { + require.NoError(t, store.ConnectorsUninstall(ctx, defaultConnector.ID)) + + connector, err := store.ConnectorsGet(ctx, defaultConnector.ID) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + require.Nil(t, connector) + }) + + t.Run("uninstall unknown connector", func(t *testing.T) { + id := models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + + require.NoError(t, store.ConnectorsUninstall(ctx, id)) + }) +} + +func TestConnectorsGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertConnector(t, ctx, store, defaultConnector2) + + t.Run("get connector", func(t *testing.T) { + connector, err := store.ConnectorsGet(ctx, defaultConnector.ID) + require.NoError(t, err) + require.NotNil(t, connector) + require.Equal(t, defaultConnector, *connector) + }) + + t.Run("get unknown connector", func(t *testing.T) { + id := models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + + connector, err := store.ConnectorsGet(ctx, id) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + require.Nil(t, connector) + }) +} + +func TestConnectorsList(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertConnector(t, ctx, store, defaultConnector2) + upsertConnector(t, ctx, store, defaultConnector3) + + t.Run("list connectors by name", func(t *testing.T) { + q := NewListConnectorsQuery( + bunpaginate.NewPaginatedQueryOptions(ConnectorQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "default")), + ) + + cursor, err := store.ConnectorsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Next) + require.Empty(t, cursor.Previous) + require.Equal(t, defaultConnector, cursor.Data[0]) + }) + + t.Run("list connectors by unknown name", func(t *testing.T) { + q := NewListConnectorsQuery( + bunpaginate.NewPaginatedQueryOptions(ConnectorQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "unknown")), + ) + + cursor, err := store.ConnectorsList(ctx, q) + require.NoError(t, err) + require.Empty(t, cursor.Data) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Next) + require.Empty(t, cursor.Previous) + }) + + t.Run("list connectors by provider", func(t *testing.T) { + q := NewListConnectorsQuery( + bunpaginate.NewPaginatedQueryOptions(ConnectorQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("provider", "default")), + ) + + cursor, err := store.ConnectorsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Next) + require.Empty(t, cursor.Previous) + require.Equal(t, defaultConnector3, cursor.Data[0]) + require.Equal(t, defaultConnector, cursor.Data[1]) + }) + + t.Run("list connectors by unknown provider", func(t *testing.T) { + q := NewListConnectorsQuery( + bunpaginate.NewPaginatedQueryOptions(ConnectorQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("provider", "unknown")), + ) + + cursor, err := store.ConnectorsList(ctx, q) + require.NoError(t, err) + require.Empty(t, cursor.Data) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Next) + require.Empty(t, cursor.Previous) + }) + + t.Run("list connectors test cursor", func(t *testing.T) { + q := NewListConnectorsQuery( + bunpaginate.NewPaginatedQueryOptions(ConnectorQuery{}). + WithPageSize(1), + ) + + cursor, err := store.ConnectorsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Next) + require.Empty(t, cursor.Previous) + require.Equal(t, defaultConnector3, cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.ConnectorsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Next) + require.NotEmpty(t, cursor.Previous) + require.Equal(t, defaultConnector2, cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.ConnectorsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Next) + require.NotEmpty(t, cursor.Previous) + require.Equal(t, defaultConnector, cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.ConnectorsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Next) + require.NotEmpty(t, cursor.Previous) + require.Equal(t, defaultConnector2, cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.ConnectorsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Next) + require.Empty(t, cursor.Previous) + require.Equal(t, defaultConnector3, cursor.Data[0]) + }) +} diff --git a/components/payments/internal/storage/main_test.go b/components/payments/internal/storage/main_test.go index 02d459423e..4c7bc23f80 100644 --- a/components/payments/internal/storage/main_test.go +++ b/components/payments/internal/storage/main_test.go @@ -1,43 +1,58 @@ package storage -// func TestMain(m *testing.M) { -// if err := pgtesting.CreatePostgresServer(); err != nil { -// logging.Error(err) -// os.Exit(1) -// } - -// code := m.Run() -// if err := pgtesting.DestroyPostgresServer(); err != nil { -// logging.Error(err) -// } -// os.Exit(code) -// } - -// func newStore(t *testing.T) Storage { -// t.Helper() - -// pgServer := pgtesting.NewPostgresDatabase(t) - -// config, err := pgx.ParseConfig(pgServer.ConnString()) -// require.NoError(t, err) - -// key := make([]byte, 64) -// _, err = rand.Read(key) -// require.NoError(t, err) - -// db := bun.NewDB(stdlib.OpenDB(*config), pgdialect.New()) -// t.Cleanup(func() { -// _ = db.Close() -// }) - -// // TODO(polo): add migrations -// // err = migrationstorage.Migrate(context.Background(), db) -// // require.NoError(t, err) - -// store := newStorage( -// db, -// string(key), -// ) - -// return store -// } +import ( + "context" + "crypto/rand" + "database/sql" + "os" + "testing" + + "github.com/formancehq/go-libs/bun/bunconnect" + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/go-libs/testing/docker" + "github.com/formancehq/go-libs/testing/platform/pgtesting" + "github.com/formancehq/go-libs/testing/utils" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" +) + +var ( + srv *pgtesting.PostgresServer + bunDB *bun.DB +) + +func TestMain(m *testing.M) { + utils.WithTestMain(func(t *utils.TestingTForMain) int { + srv = pgtesting.CreatePostgresServer(t, docker.NewPool(t, logging.Testing())) + + db, err := sql.Open("postgres", srv.GetDSN()) + if err != nil { + logging.Error(err) + os.Exit(1) + } + + bunDB = bun.NewDB(db, pgdialect.New()) + + return m.Run() + }) +} + +func newStore(t *testing.T) Storage { + t.Helper() + ctx := logging.TestingContext() + + pgServer := srv.NewDatabase(t) + + db, err := bunconnect.OpenSQLDB(ctx, pgServer.ConnectionOptions()) + require.NoError(t, err) + + key := make([]byte, 64) + _, err = rand.Read(key) + require.NoError(t, err) + + err = Migrate(context.Background(), db) + require.NoError(t, err) + + return newStorage(db, string(key)) +} diff --git a/components/payments/internal/storage/migrations/0-init-schema.sql b/components/payments/internal/storage/migrations/0-init-schema.sql index 20c87f2628..66466cfbf4 100644 --- a/components/payments/internal/storage/migrations/0-init-schema.sql +++ b/components/payments/internal/storage/migrations/0-init-schema.sql @@ -166,7 +166,7 @@ create table if not exists pools ( ); create unique index pools_unique_name on pools (name); -create table if not exists pools_related_accounts ( +create table if not exists pool_accounts ( -- Mandatory fields pool_id uuid not null, account_id varchar not null, @@ -174,12 +174,12 @@ create table if not exists pools_related_accounts ( -- Primary key primary key (pool_id, account_id) ); -alter table pools_related_accounts - add constraint pools_related_accounts_pool_id_fk foreign key (pool_id) +alter table pool_accounts + add constraint pool_accounts_pool_id_fk foreign key (pool_id) references pools (id) on delete cascade; -alter table pools_related_accounts - add constraint pools_related_accounts_account_id_fk foreign key (account_id) +alter table pool_accounts + add constraint pool_accounts_account_id_fk foreign key (account_id) references accounts (id) on delete cascade; @@ -252,6 +252,10 @@ alter table workflows_instances add constraint workflows_instances_connector_id_fk foreign key (connector_id) references connectors (id) on delete cascade; +alter table workflows_instances + add constraint workflows_instances_schedule_id_fk foreign key (schedule_id, connector_id) + references schedules (id, connector_id) + on delete cascade; -- Webhook configs create table if not exists webhooks_configs ( diff --git a/components/payments/internal/storage/payments.go b/components/payments/internal/storage/payments.go index 25a0fd5765..a4577440de 100644 --- a/components/payments/internal/storage/payments.go +++ b/components/payments/internal/storage/payments.go @@ -65,11 +65,18 @@ type paymentAdjustment struct { func (s *store) PaymentsUpsert(ctx context.Context, payments []models.Payment) error { paymentsToInsert := make([]payment, 0, len(payments)) adjustmentsToInsert := make([]paymentAdjustment, 0) + paymentsRefunded := make([]payment, 0) for _, p := range payments { paymentsToInsert = append(paymentsToInsert, fromPaymentModels(p)) for _, a := range p.Adjustments { adjustmentsToInsert = append(adjustmentsToInsert, fromPaymentAdjustmentModels(a)) + switch a.Status { + case models.PAYMENT_STATUS_REFUNDED: + res := fromPaymentModels(p) + res.Amount = a.Amount + paymentsRefunded = append(paymentsRefunded, res) + } } } @@ -77,21 +84,37 @@ func (s *store) PaymentsUpsert(ctx context.Context, payments []models.Payment) e if err != nil { return errors.Wrap(err, "failed to create transaction") } + defer tx.Rollback() - _, err = tx.NewInsert(). - Model(&paymentsToInsert). - On("CONFLICT (id) DO NOTHING"). - Exec(ctx) - if err != nil { - return e("failed to insert payments", err) + if len(paymentsToInsert) > 0 { + _, err = tx.NewInsert(). + Model(&paymentsToInsert). + On("CONFLICT (id) DO NOTHING"). + Exec(ctx) + if err != nil { + return e("failed to insert payments", err) + } } - _, err = tx.NewInsert(). - Model(&adjustmentsToInsert). - On("CONFLICT (id) DO NOTHING"). - Exec(ctx) - if err != nil { - return e("failed to insert adjustments", err) + if len(paymentsRefunded) > 0 { + _, err = tx.NewInsert(). + Model(&paymentsRefunded). + On("CONFLICT (id) DO UPDATE"). + Set("amount = payment.amount - EXCLUDED.amount"). + Exec(ctx) + if err != nil { + return e("failed to update payment", err) + } + } + + if len(adjustmentsToInsert) > 0 { + _, err = tx.NewInsert(). + Model(&adjustmentsToInsert). + On("CONFLICT (id) DO NOTHING"). + Exec(ctx) + if err != nil { + return e("failed to insert adjustments", err) + } } return e("failed to commit transactions", tx.Commit()) @@ -160,7 +183,11 @@ func (s *store) PaymentsGet(ctx context.Context, id models.PaymentID) (*models.P adjustments = append(adjustments, toPaymentAdjustmentModels(a)) } - res := toPaymentModels(payment, adjustments[len(adjustments)-1].Status) + status := models.PAYMENT_STATUS_PENDING + if len(adjustments) > 0 { + status = adjustments[len(adjustments)-1].Status + } + res := toPaymentModels(payment, status) res.Adjustments = adjustments return &res, nil } @@ -282,7 +309,7 @@ func fromPaymentModels(from models.Payment) payment { ID: from.ID, ConnectorID: from.ConnectorID, Reference: from.Reference, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), Type: from.Type, InitialAmount: from.InitialAmount, Amount: from.Amount, @@ -300,7 +327,7 @@ func toPaymentModels(payment payment, status models.PaymentStatus) models.Paymen ConnectorID: payment.ConnectorID, InitialAmount: payment.InitialAmount, Reference: payment.Reference, - CreatedAt: payment.CreatedAt, + CreatedAt: payment.CreatedAt.UTC(), Type: payment.Type, Amount: payment.Amount, Asset: payment.Asset, @@ -316,7 +343,7 @@ func fromPaymentAdjustmentModels(from models.PaymentAdjustment) paymentAdjustmen return paymentAdjustment{ ID: from.ID, PaymentID: from.PaymentID, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), Status: from.Status, Amount: from.Amount, Asset: from.Asset, @@ -329,7 +356,7 @@ func toPaymentAdjustmentModels(from paymentAdjustment) models.PaymentAdjustment return models.PaymentAdjustment{ ID: from.ID, PaymentID: from.PaymentID, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), Status: from.Status, Amount: from.Amount, Asset: from.Asset, diff --git a/components/payments/internal/storage/payments_test.go b/components/payments/internal/storage/payments_test.go new file mode 100644 index 0000000000..9ba519c828 --- /dev/null +++ b/components/payments/internal/storage/payments_test.go @@ -0,0 +1,898 @@ +package storage + +import ( + "context" + "math/big" + "testing" + "time" + + "github.com/formancehq/go-libs/bun/bunpaginate" + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/go-libs/pointer" + "github.com/formancehq/go-libs/query" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + pID1 = models.PaymentID{ + PaymentReference: models.PaymentReference{ + Reference: "test1", + Type: models.PAYMENT_TYPE_TRANSFER, + }, + ConnectorID: defaultConnector.ID, + } + + pid2 = models.PaymentID{ + PaymentReference: models.PaymentReference{ + Reference: "test2", + Type: models.PAYMENT_TYPE_PAYIN, + }, + ConnectorID: defaultConnector.ID, + } + + pid3 = models.PaymentID{ + PaymentReference: models.PaymentReference{ + Reference: "test3", + Type: models.PAYMENT_TYPE_PAYOUT, + }, + ConnectorID: defaultConnector.ID, + } + + defaultPayments = []models.Payment{ + { + ID: pID1, + ConnectorID: defaultConnector.ID, + Reference: "test1", + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Type: models.PAYMENT_TYPE_TRANSFER, + InitialAmount: big.NewInt(100), + Amount: big.NewInt(100), + Asset: "USD/2", + Scheme: models.PAYMENT_SCHEME_OTHER, + SourceAccountID: &defaultAccounts[0].ID, + DestinationAccountID: &defaultAccounts[1].ID, + Metadata: map[string]string{ + "key1": "value1", + }, + Adjustments: []models.PaymentAdjustment{ + { + ID: models.PaymentAdjustmentID{ + PaymentID: pID1, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_SUCCEEDED, + }, + PaymentID: pID1, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_SUCCEEDED, + Amount: big.NewInt(100), + Asset: pointer.For("USD/2"), + Raw: []byte(`{}`), + }, + }, + }, + { + ID: pid2, + ConnectorID: defaultConnector.ID, + Reference: "test2", + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + Type: models.PAYMENT_TYPE_PAYIN, + InitialAmount: big.NewInt(200), + Amount: big.NewInt(200), + Asset: "EUR/2", + Scheme: models.PAYMENT_SCHEME_OTHER, + DestinationAccountID: &defaultAccounts[0].ID, + Adjustments: []models.PaymentAdjustment{ + { + ID: models.PaymentAdjustmentID{ + PaymentID: pid2, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_FAILED, + }, + PaymentID: pid2, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_FAILED, + Amount: big.NewInt(200), + Asset: pointer.For("EUR/2"), + Raw: []byte(`{}`), + }, + }, + }, + { + ID: pid3, + ConnectorID: defaultConnector.ID, + Reference: "test3", + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + Type: models.PAYMENT_TYPE_PAYOUT, + InitialAmount: big.NewInt(300), + Amount: big.NewInt(300), + Asset: "DKK/2", + Scheme: models.PAYMENT_SCHEME_A2A, + SourceAccountID: &defaultAccounts[1].ID, + Adjustments: []models.PaymentAdjustment{ + { + ID: models.PaymentAdjustmentID{ + PaymentID: pid3, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_PENDING, + }, + PaymentID: pid3, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_PENDING, + Amount: big.NewInt(300), + Asset: pointer.For("DKK/2"), + Raw: []byte(`{}`), + }, + }, + }, + } +) + +func upsertPayments(t *testing.T, ctx context.Context, storage Storage, payments []models.Payment) { + require.NoError(t, storage.PaymentsUpsert(ctx, payments)) +} + +func TestPaymentsUpsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPayments(t, ctx, store, defaultPayments) + + t.Run("upsert with unknown connector", func(t *testing.T) { + connector := models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + p := defaultPayments[0] + p.ID = models.PaymentID{ + PaymentReference: models.PaymentReference{ + Reference: "test4", + Type: models.PAYMENT_TYPE_PAYOUT, + }, + ConnectorID: connector, + } + p.ConnectorID = connector + + err := store.PaymentsUpsert(ctx, []models.Payment{p}) + require.Error(t, err) + }) + + t.Run("upsert with same id", func(t *testing.T) { + p := defaultPayments[2] + p.Amount = big.NewInt(200) + p.Scheme = models.PAYMENT_SCHEME_A2A + upsertPayments(t, ctx, store, []models.Payment{p}) + + // should not have changed + actual, err := store.PaymentsGet(ctx, p.ID) + require.NoError(t, err) + + comparePayments(t, defaultPayments[2], *actual) + }) + + t.Run("upsert with different adjustments", func(t *testing.T) { + p := models.Payment{ + ID: pid3, + ConnectorID: defaultConnector.ID, + Reference: "test3", + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + Type: models.PAYMENT_TYPE_PAYOUT, + InitialAmount: big.NewInt(300), + Amount: big.NewInt(300), + Asset: "DKK/2", + Scheme: models.PAYMENT_SCHEME_A2A, + SourceAccountID: &defaultAccounts[1].ID, + Adjustments: []models.PaymentAdjustment{ + { + ID: models.PaymentAdjustmentID{ + PaymentID: pid3, + CreatedAt: now.Add(-45 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_SUCCEEDED, + }, + PaymentID: pid3, + CreatedAt: now.Add(-45 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_SUCCEEDED, + Amount: big.NewInt(300), + Asset: pointer.For("DKK/2"), + Metadata: map[string]string{}, + Raw: []byte(`{}`), + }, + { + ID: models.PaymentAdjustmentID{ + PaymentID: pid3, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_PENDING, + }, + PaymentID: pid3, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_PENDING, + Amount: big.NewInt(300), + Asset: pointer.For("DKK/2"), + Raw: []byte(`{}`), + }, + }, + } + + upsertPayments(t, ctx, store, []models.Payment{p}) + + actual, err := store.PaymentsGet(ctx, p.ID) + require.NoError(t, err) + comparePayments(t, p, *actual) + }) + + t.Run("upsert with refund", func(t *testing.T) { + p := models.Payment{ + ID: pID1, + ConnectorID: defaultConnector.ID, + Reference: "test1", + CreatedAt: now.Add(-20 * time.Minute).UTC().Time, + Type: models.PAYMENT_TYPE_TRANSFER, + InitialAmount: big.NewInt(100), + Amount: big.NewInt(100), + Asset: "USD/2", + Scheme: models.PAYMENT_SCHEME_OTHER, + Adjustments: []models.PaymentAdjustment{ + { + ID: models.PaymentAdjustmentID{ + PaymentID: pID1, + CreatedAt: now.Add(-20 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_REFUNDED, + }, + PaymentID: pID1, + CreatedAt: now.Add(-20 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_REFUNDED, + Amount: big.NewInt(50), + Asset: pointer.For("USD/2"), + Raw: []byte(`{}`), + }, + }, + } + + upsertPayments(t, ctx, store, []models.Payment{p}) + + actual, err := store.PaymentsGet(ctx, p.ID) + require.NoError(t, err) + + expectedPayments := models.Payment{ + ID: pID1, + ConnectorID: defaultConnector.ID, + Reference: "test1", + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Type: models.PAYMENT_TYPE_TRANSFER, + InitialAmount: big.NewInt(100), + Amount: big.NewInt(50), + Asset: "USD/2", + Scheme: models.PAYMENT_SCHEME_OTHER, + Status: models.PAYMENT_STATUS_REFUNDED, + SourceAccountID: &defaultAccounts[0].ID, + DestinationAccountID: &defaultAccounts[1].ID, + Metadata: map[string]string{ + "key1": "value1", + }, + Adjustments: []models.PaymentAdjustment{ + { + ID: models.PaymentAdjustmentID{ + PaymentID: pID1, + CreatedAt: now.Add(-20 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_REFUNDED, + }, + PaymentID: pID1, + CreatedAt: now.Add(-20 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_REFUNDED, + Amount: big.NewInt(50), + Asset: pointer.For("USD/2"), + Raw: []byte(`{}`), + }, + { + ID: models.PaymentAdjustmentID{ + PaymentID: pID1, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_SUCCEEDED, + }, + PaymentID: pID1, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Status: models.PAYMENT_STATUS_SUCCEEDED, + Amount: big.NewInt(100), + Asset: pointer.For("USD/2"), + Raw: []byte(`{}`), + }, + }, + } + + comparePayments(t, expectedPayments, *actual) + }) +} + +func TestPaymentsUpdateMetadata(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPayments(t, ctx, store, defaultPayments) + + t.Run("update metadata of unknown payment", func(t *testing.T) { + require.Error(t, store.PaymentsUpdateMetadata(ctx, models.PaymentID{ + PaymentReference: models.PaymentReference{Reference: "unknown", Type: models.PAYMENT_TYPE_TRANSFER}, + ConnectorID: defaultConnector.ID, + }, map[string]string{})) + }) + + t.Run("update existing metadata", func(t *testing.T) { + metadata := map[string]string{ + "key1": "changed", + } + + require.NoError(t, store.PaymentsUpdateMetadata(ctx, defaultPayments[0].ID, metadata)) + + actual, err := store.PaymentsGet(ctx, defaultPayments[0].ID) + require.NoError(t, err) + require.Equal(t, len(metadata), len(actual.Metadata)) + for k, v := range metadata { + _, ok := actual.Metadata[k] + require.True(t, ok) + require.Equal(t, v, actual.Metadata[k]) + } + }) + + t.Run("add new metadata", func(t *testing.T) { + metadata := map[string]string{ + "key2": "value2", + "key3": "value3", + } + + require.NoError(t, store.PaymentsUpdateMetadata(ctx, defaultPayments[1].ID, metadata)) + + actual, err := store.PaymentsGet(ctx, defaultPayments[1].ID) + require.NoError(t, err) + require.Equal(t, len(metadata), len(actual.Metadata)) + for k, v := range metadata { + _, ok := actual.Metadata[k] + require.True(t, ok) + require.Equal(t, v, actual.Metadata[k]) + } + }) +} + +func TestPaymentsGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPayments(t, ctx, store, defaultPayments) + + t.Run("get unknown payment", func(t *testing.T) { + _, err := store.PaymentsGet(ctx, models.PaymentID{ + PaymentReference: models.PaymentReference{Reference: "unknown", Type: models.PAYMENT_TYPE_TRANSFER}, + ConnectorID: defaultConnector.ID, + }) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + }) + + t.Run("get existing payments", func(t *testing.T) { + for _, p := range defaultPayments { + actual, err := store.PaymentsGet(ctx, p.ID) + require.NoError(t, err) + comparePayments(t, p, *actual) + } + }) +} + +func TestPaymentsDeleteFromConnectorID(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPayments(t, ctx, store, defaultPayments) + + t.Run("delete from unknown connector", func(t *testing.T) { + require.NoError(t, store.PaymentsDeleteFromConnectorID(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + })) + + for _, p := range defaultPayments { + actual, err := store.PaymentsGet(ctx, p.ID) + require.NoError(t, err) + comparePayments(t, p, *actual) + } + }) + + t.Run("delete from existing connector", func(t *testing.T) { + require.NoError(t, store.PaymentsDeleteFromConnectorID(ctx, defaultConnector.ID)) + + for _, p := range defaultPayments { + _, err := store.PaymentsGet(ctx, p.ID) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + } + }) +} + +func TestPaymentsList(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPayments(t, ctx, store, defaultPayments) + + dps := []models.Payment{ + { + ID: pID1, + ConnectorID: defaultConnector.ID, + Reference: "test1", + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + Type: models.PAYMENT_TYPE_TRANSFER, + InitialAmount: big.NewInt(100), + Amount: big.NewInt(100), + Asset: "USD/2", + Scheme: models.PAYMENT_SCHEME_OTHER, + Status: models.PAYMENT_STATUS_SUCCEEDED, + SourceAccountID: &defaultAccounts[0].ID, + DestinationAccountID: &defaultAccounts[1].ID, + Metadata: map[string]string{ + "key1": "value1", + }, + }, + { + ID: pid2, + ConnectorID: defaultConnector.ID, + Reference: "test2", + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + Type: models.PAYMENT_TYPE_PAYIN, + InitialAmount: big.NewInt(200), + Amount: big.NewInt(200), + Asset: "EUR/2", + Scheme: models.PAYMENT_SCHEME_OTHER, + Status: models.PAYMENT_STATUS_FAILED, + DestinationAccountID: &defaultAccounts[0].ID, + }, + { + ID: pid3, + ConnectorID: defaultConnector.ID, + Reference: "test3", + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + Type: models.PAYMENT_TYPE_PAYOUT, + InitialAmount: big.NewInt(300), + Amount: big.NewInt(300), + Asset: "DKK/2", + Scheme: models.PAYMENT_SCHEME_A2A, + Status: models.PAYMENT_STATUS_PENDING, + SourceAccountID: &defaultAccounts[1].ID, + }, + } + + t.Run("list payments by reference", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("reference", "test1")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + comparePayments(t, dps[0], cursor.Data[0]) + }) + + t.Run("list payments by unknown reference", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("reference", "unknown")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by connector_id", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("connector_id", defaultConnector.ID.String())), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 3) + require.False(t, cursor.HasMore) + comparePayments(t, dps[1], cursor.Data[0]) + comparePayments(t, dps[2], cursor.Data[1]) + comparePayments(t, dps[0], cursor.Data[2]) + }) + + t.Run("list payments by unknown connector_id", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("connector_id", "unknown")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by type", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("type", "PAYOUT")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + comparePayments(t, dps[2], cursor.Data[0]) + }) + + t.Run("list payments by unknown type", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("type", "UNKNOWN")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by asset", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("asset", "EUR/2")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + comparePayments(t, dps[1], cursor.Data[0]) + }) + + t.Run("list payments by unknown asset", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("asset", "UNKNOWN")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by scheme", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("scheme", "OTHER")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 2) + require.False(t, cursor.HasMore) + comparePayments(t, dps[1], cursor.Data[0]) + comparePayments(t, dps[0], cursor.Data[1]) + }) + + t.Run("list payments by unknown scheme", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("scheme", "UNKNOWN")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by status", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("status", "PENDING")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + comparePayments(t, dps[2], cursor.Data[0]) + }) + + t.Run("list payments by unknown status", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("status", "UNKNOWN")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by source account id", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("source_account_id", defaultAccounts[0].ID.String())), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + comparePayments(t, dps[0], cursor.Data[0]) + }) + + t.Run("list payments by unknown source account id", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("source_account_id", "unknown")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by destination account id", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("destination_account_id", defaultAccounts[0].ID.String())), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + comparePayments(t, dps[1], cursor.Data[0]) + }) + + t.Run("list payments by unknown destination account id", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("destination_account_id", "unknown")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by amount", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("amount", 200)), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + comparePayments(t, dps[1], cursor.Data[0]) + }) + + t.Run("list payments by unknown amount", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("amount", 0)), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by initial_amount", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("initial_amount", 300)), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + comparePayments(t, dps[2], cursor.Data[0]) + }) + + t.Run("list payments by unknown initial_amount", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("initial_amount", 0)), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by metadata", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("metadata[key1]", "value1")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + comparePayments(t, dps[0], cursor.Data[0]) + }) + + t.Run("list payments by unknown metadata", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("metadata[key1]", "unknown")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments by unknown metadata 2", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("metadata[unknown]", "unknown")), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 0) + require.False(t, cursor.HasMore) + }) + + t.Run("list payments test cursor", func(t *testing.T) { + q := NewListPaymentsQuery( + bunpaginate.NewPaginatedQueryOptions(PaymentQuery{}). + WithPageSize(1), + ) + + cursor, err := store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + comparePayments(t, dps[1], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + comparePayments(t, dps[2], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.Empty(t, cursor.Next) + comparePayments(t, dps[0], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + comparePayments(t, dps[2], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.PaymentsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + comparePayments(t, dps[1], cursor.Data[0]) + }) +} + +func comparePayments(t *testing.T, expected, actual models.Payment) { + require.Equal(t, expected.ID, actual.ID) + require.Equal(t, expected.ConnectorID, actual.ConnectorID) + require.Equal(t, expected.Reference, actual.Reference) + require.Equal(t, expected.CreatedAt, actual.CreatedAt) + require.Equal(t, expected.Type, actual.Type) + require.Equal(t, expected.InitialAmount, actual.InitialAmount) + require.Equal(t, expected.Amount, actual.Amount) + require.Equal(t, expected.Asset, actual.Asset) + require.Equal(t, expected.Scheme, actual.Scheme) + + switch { + case expected.SourceAccountID == nil: + require.Nil(t, actual.SourceAccountID) + default: + require.NotNil(t, actual.SourceAccountID) + require.Equal(t, *expected.SourceAccountID, *actual.SourceAccountID) + } + + switch { + case expected.DestinationAccountID == nil: + require.Nil(t, actual.DestinationAccountID) + default: + require.NotNil(t, actual.DestinationAccountID) + require.Equal(t, *expected.DestinationAccountID, *actual.DestinationAccountID) + } + + require.Equal(t, len(expected.Metadata), len(actual.Metadata)) + for k, v := range expected.Metadata { + _, ok := actual.Metadata[k] + require.True(t, ok) + require.Equal(t, v, actual.Metadata[k]) + } + + require.Equal(t, len(expected.Adjustments), len(actual.Adjustments)) + for i := range expected.Adjustments { + comparePaymentAdjustments(t, expected.Adjustments[i], actual.Adjustments[i]) + } +} + +func comparePaymentAdjustments(t *testing.T, expected, actual models.PaymentAdjustment) { + require.Equal(t, expected.ID, actual.ID) + require.Equal(t, expected.PaymentID, actual.PaymentID) + require.Equal(t, expected.CreatedAt, actual.CreatedAt) + require.Equal(t, expected.Status, actual.Status) + require.Equal(t, expected.Amount, actual.Amount) + require.Equal(t, expected.Asset, actual.Asset) +} diff --git a/components/payments/internal/storage/pools.go b/components/payments/internal/storage/pools.go index 6655c9c21f..13b63ffced 100644 --- a/components/payments/internal/storage/pools.go +++ b/components/payments/internal/storage/pools.go @@ -203,7 +203,7 @@ func fromPoolModel(from models.Pool) (pool, []poolAccounts) { p := pool{ ID: from.ID, Name: from.Name, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), } var accounts []poolAccounts @@ -229,7 +229,7 @@ func toPoolModel(from pool) models.Pool { return models.Pool{ ID: from.ID, Name: from.Name, - CreatedAt: from.CreatedAt, + CreatedAt: from.CreatedAt.UTC(), PoolAccounts: accounts, } } diff --git a/components/payments/internal/storage/pools_test.go b/components/payments/internal/storage/pools_test.go new file mode 100644 index 0000000000..a5be5db247 --- /dev/null +++ b/components/payments/internal/storage/pools_test.go @@ -0,0 +1,374 @@ +package storage + +import ( + "context" + "testing" + "time" + + "github.com/formancehq/go-libs/bun/bunpaginate" + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/go-libs/query" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + poolID1 = uuid.New() + poolID2 = uuid.New() + poolID3 = uuid.New() + defaultPools = []models.Pool{ + { + ID: poolID1, + Name: "test1", + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + PoolAccounts: []models.PoolAccounts{ + { + PoolID: poolID1, + AccountID: defaultAccounts[0].ID, + }, + { + PoolID: poolID1, + AccountID: defaultAccounts[1].ID, + }, + }, + }, + { + ID: poolID2, + Name: "test2", + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + PoolAccounts: []models.PoolAccounts{ + { + PoolID: poolID2, + AccountID: defaultAccounts[2].ID, + }, + }, + }, + { + ID: poolID3, + Name: "test3", + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + PoolAccounts: []models.PoolAccounts{ + { + PoolID: poolID3, + AccountID: defaultAccounts[2].ID, + }, + }, + }, + } +) + +func upsertPool(t *testing.T, ctx context.Context, storage Storage, pool models.Pool) { + require.NoError(t, storage.PoolsUpsert(ctx, pool)) +} + +func TestPoolsUpsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPool(t, ctx, store, defaultPools[0]) + upsertPool(t, ctx, store, defaultPools[1]) + + t.Run("upsert with same name", func(t *testing.T) { + poolID3 := uuid.New() + p := models.Pool{ + ID: poolID3, + Name: "test1", + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + PoolAccounts: []models.PoolAccounts{ + { + PoolID: poolID3, + AccountID: defaultAccounts[2].ID, + }, + }, + } + + err := store.PoolsUpsert(ctx, p) + require.Error(t, err) + }) + + t.Run("upsert with same id", func(t *testing.T) { + upsertPool(t, ctx, store, defaultPools[1]) + + actual, err := store.PoolsGet(ctx, defaultPools[1].ID) + require.NoError(t, err) + require.Equal(t, defaultPools[1], *actual) + }) + + t.Run("upsert with same id but more related accounts", func(t *testing.T) { + p := defaultPools[0] + p.PoolAccounts = append(p.PoolAccounts, models.PoolAccounts{ + PoolID: p.ID, + AccountID: defaultAccounts[2].ID, + }) + + upsertPool(t, ctx, store, p) + + actual, err := store.PoolsGet(ctx, defaultPools[0].ID) + require.NoError(t, err) + require.Equal(t, p, *actual) + }) + + t.Run("upsert with same id, but wrong related account pool id", func(t *testing.T) { + p := defaultPools[0] + p.PoolAccounts = append(p.PoolAccounts, models.PoolAccounts{ + PoolID: uuid.New(), + AccountID: defaultAccounts[2].ID, + }) + + err := store.PoolsUpsert(ctx, p) + require.Error(t, err) + }) + + t.Run("upsert with same id, but wrong related account account id", func(t *testing.T) { + p := defaultPools[0] + p.PoolAccounts = append(p.PoolAccounts, models.PoolAccounts{ + PoolID: p.ID, + AccountID: models.AccountID{ + Reference: "unknown", + ConnectorID: defaultConnector.ID, + }, + }) + + err := store.PoolsUpsert(ctx, p) + require.Error(t, err) + }) +} + +func TestPoolsGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPool(t, ctx, store, defaultPools[0]) + upsertPool(t, ctx, store, defaultPools[1]) + upsertPool(t, ctx, store, defaultPools[2]) + + t.Run("get existing pool", func(t *testing.T) { + for _, p := range defaultPools { + actual, err := store.PoolsGet(ctx, p.ID) + require.NoError(t, err) + require.Equal(t, p, *actual) + } + }) + + t.Run("get non-existing pool", func(t *testing.T) { + p, err := store.PoolsGet(ctx, uuid.New()) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + require.Nil(t, p) + }) +} + +func TestPoolsDelete(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPool(t, ctx, store, defaultPools[0]) + upsertPool(t, ctx, store, defaultPools[1]) + upsertPool(t, ctx, store, defaultPools[2]) + + t.Run("delete unknown pool", func(t *testing.T) { + require.NoError(t, store.PoolsDelete(ctx, uuid.New())) + for _, p := range defaultPools { + actual, err := store.PoolsGet(ctx, p.ID) + require.NoError(t, err) + require.Equal(t, p, *actual) + } + }) + + t.Run("delete existing pool", func(t *testing.T) { + require.NoError(t, store.PoolsDelete(ctx, defaultPools[0].ID)) + + _, err := store.PoolsGet(ctx, defaultPools[0].ID) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + + actual, err := store.PoolsGet(ctx, defaultPools[1].ID) + require.NoError(t, err) + require.Equal(t, defaultPools[1], *actual) + }) +} + +func TestPoolsAddAccount(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPool(t, ctx, store, defaultPools[0]) + upsertPool(t, ctx, store, defaultPools[1]) + + t.Run("add unknown account to pool", func(t *testing.T) { + err := store.PoolsAddAccount(ctx, defaultPools[0].ID, models.AccountID{ + Reference: "unknown", + ConnectorID: defaultConnector.ID, + }) + require.Error(t, err) + }) + + t.Run("add account to unknown pool", func(t *testing.T) { + err := store.PoolsAddAccount(ctx, uuid.New(), defaultAccounts[0].ID) + require.Error(t, err) + }) + + t.Run("add account to pool", func(t *testing.T) { + require.NoError(t, store.PoolsAddAccount(ctx, defaultPools[0].ID, defaultAccounts[2].ID)) + + p := defaultPools[0] + p.PoolAccounts = append(p.PoolAccounts, models.PoolAccounts{ + PoolID: p.ID, + AccountID: defaultAccounts[2].ID, + }) + + actual, err := store.PoolsGet(ctx, defaultPools[0].ID) + require.NoError(t, err) + require.Equal(t, p, *actual) + }) +} + +func TestPoolsRemoveAccount(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPool(t, ctx, store, defaultPools[0]) + upsertPool(t, ctx, store, defaultPools[1]) + + t.Run("remove unknown account from pool", func(t *testing.T) { + require.NoError(t, store.PoolsRemoveAccount(ctx, defaultPools[0].ID, models.AccountID{ + Reference: "unknown", + ConnectorID: defaultConnector.ID, + })) + }) + + t.Run("remove account from unknown pool", func(t *testing.T) { + require.NoError(t, store.PoolsRemoveAccount(ctx, uuid.New(), defaultAccounts[0].ID)) + }) + + t.Run("remove account from pool", func(t *testing.T) { + require.NoError(t, store.PoolsRemoveAccount(ctx, defaultPools[0].ID, defaultAccounts[1].ID)) + + p := defaultPools[0] + p.PoolAccounts = p.PoolAccounts[:1] + + actual, err := store.PoolsGet(ctx, defaultPools[0].ID) + require.NoError(t, err) + require.Equal(t, p, *actual) + }) +} + +func TestPoolsList(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertAccounts(t, ctx, store, defaultAccounts) + upsertPool(t, ctx, store, defaultPools[0]) + upsertPool(t, ctx, store, defaultPools[1]) + upsertPool(t, ctx, store, defaultPools[2]) + + t.Run("list pools by name", func(t *testing.T) { + q := NewListPoolsQuery( + bunpaginate.NewPaginatedQueryOptions(PoolQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "test1")), + ) + + cursor, err := store.PoolsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + require.Equal(t, []models.Pool{defaultPools[0]}, cursor.Data) + }) + + t.Run("list pools by unknown name", func(t *testing.T) { + q := NewListPoolsQuery( + bunpaginate.NewPaginatedQueryOptions(PoolQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("name", "unknown")), + ) + + cursor, err := store.PoolsList(ctx, q) + require.NoError(t, err) + require.Empty(t, cursor.Data) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + }) + + t.Run("list pools test cursor", func(t *testing.T) { + q := NewListPoolsQuery( + bunpaginate.NewPaginatedQueryOptions(PoolQuery{}). + WithPageSize(1), + ) + + cursor, err := store.PoolsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, []models.Pool{defaultPools[1]}, cursor.Data) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.PoolsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, []models.Pool{defaultPools[2]}, cursor.Data) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.PoolsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.False(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.Empty(t, cursor.Next) + require.Equal(t, []models.Pool{defaultPools[0]}, cursor.Data) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.PoolsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, []models.Pool{defaultPools[2]}, cursor.Data) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.PoolsList(ctx, q) + require.NoError(t, err) + require.Len(t, cursor.Data, 1) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, []models.Pool{defaultPools[1]}, cursor.Data) + }) +} diff --git a/components/payments/internal/storage/schedules.go b/components/payments/internal/storage/schedules.go index 6509ed54b9..f1bb53343c 100644 --- a/components/payments/internal/storage/schedules.go +++ b/components/payments/internal/storage/schedules.go @@ -128,7 +128,7 @@ func fromScheduleModel(s models.Schedule) schedule { return schedule{ ID: s.ID, ConnectorID: s.ConnectorID, - CreatedAt: s.CreatedAt, + CreatedAt: s.CreatedAt.UTC(), } } @@ -136,6 +136,6 @@ func toScheduleModel(s schedule) models.Schedule { return models.Schedule{ ID: s.ID, ConnectorID: s.ConnectorID, - CreatedAt: s.CreatedAt, + CreatedAt: s.CreatedAt.UTC(), } } diff --git a/components/payments/internal/storage/schedules_test.go b/components/payments/internal/storage/schedules_test.go new file mode 100644 index 0000000000..5837ca5078 --- /dev/null +++ b/components/payments/internal/storage/schedules_test.go @@ -0,0 +1,238 @@ +package storage + +import ( + "context" + "testing" + "time" + + "github.com/formancehq/go-libs/bun/bunpaginate" + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/go-libs/query" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + defaultSchedules = []models.Schedule{ + { + ID: "test1", + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + }, + { + ID: "test2", + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + }, + { + ID: "test3", + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + }, + } +) + +func upsertSchedule(t *testing.T, ctx context.Context, storage Storage, schedule models.Schedule) { + require.NoError(t, storage.SchedulesUpsert(ctx, schedule)) +} + +func TestSchedulesUpsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertSchedule(t, ctx, store, defaultSchedules[0]) + upsertSchedule(t, ctx, store, defaultSchedules[1]) + upsertSchedule(t, ctx, store, defaultSchedules[2]) + + t.Run("upsert with same id", func(t *testing.T) { + sch := models.Schedule{ + ID: "test1", + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-90 * time.Minute).UTC().Time, + } + + require.NoError(t, store.SchedulesUpsert(ctx, sch)) + + actual, err := store.SchedulesGet(ctx, sch.ID, sch.ConnectorID) + require.NoError(t, err) + require.Equal(t, defaultSchedules[0], *actual) + }) + + t.Run("upsert with unknown connector id", func(t *testing.T) { + sch := models.Schedule{ + ID: "test4", + ConnectorID: models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + }, + CreatedAt: now.Add(-90 * time.Minute).UTC().Time, + } + + require.Error(t, store.SchedulesUpsert(ctx, sch)) + }) +} + +func TestSchedulesDeleteFromConnectorID(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertSchedule(t, ctx, store, defaultSchedules[0]) + upsertSchedule(t, ctx, store, defaultSchedules[1]) + upsertSchedule(t, ctx, store, defaultSchedules[2]) + + t.Run("delete schedules from unknown connector id", func(t *testing.T) { + require.NoError(t, store.SchedulesDeleteFromConnectorID(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + })) + + for _, sch := range defaultSchedules { + actual, err := store.SchedulesGet(ctx, sch.ID, sch.ConnectorID) + require.NoError(t, err) + require.Equal(t, sch, *actual) + } + }) + + t.Run("delete schedules", func(t *testing.T) { + require.NoError(t, store.SchedulesDeleteFromConnectorID(ctx, defaultConnector.ID)) + + for _, sch := range defaultSchedules { + _, err := store.SchedulesGet(ctx, sch.ID, sch.ConnectorID) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + } + }) +} + +func TestSchedulesGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertSchedule(t, ctx, store, defaultSchedules[0]) + upsertSchedule(t, ctx, store, defaultSchedules[1]) + upsertSchedule(t, ctx, store, defaultSchedules[2]) + + t.Run("get schedule", func(t *testing.T) { + actual, err := store.SchedulesGet(ctx, defaultSchedules[0].ID, defaultSchedules[0].ConnectorID) + require.NoError(t, err) + require.Equal(t, defaultSchedules[0], *actual) + }) + + t.Run("get unknown schedule", func(t *testing.T) { + _, err := store.SchedulesGet(ctx, "unknown", defaultConnector.ID) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + }) +} + +func TestSchedulesList(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertSchedule(t, ctx, store, defaultSchedules[0]) + upsertSchedule(t, ctx, store, defaultSchedules[1]) + upsertSchedule(t, ctx, store, defaultSchedules[2]) + + t.Run("list schedules by connector id", func(t *testing.T) { + q := NewListSchedulesQuery( + bunpaginate.NewPaginatedQueryOptions(ScheduleQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("connector_id", defaultConnector.ID)), + ) + + cursor, err := store.SchedulesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 3, len(cursor.Data)) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + require.Equal(t, []models.Schedule{defaultSchedules[1], defaultSchedules[2], defaultSchedules[0]}, cursor.Data) + }) + + t.Run("list schedules by unknown connector id", func(t *testing.T) { + q := NewListSchedulesQuery( + bunpaginate.NewPaginatedQueryOptions(ScheduleQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("connector_id", models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + }), + ), + ) + + cursor, err := store.SchedulesList(ctx, q) + require.NoError(t, err) + require.Empty(t, cursor.Data) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + }) + + t.Run("list schedules test cursor", func(t *testing.T) { + q := NewListSchedulesQuery( + bunpaginate.NewPaginatedQueryOptions(ScheduleQuery{}). + WithPageSize(1), + ) + + cursor, err := store.SchedulesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, []models.Schedule{defaultSchedules[1]}, cursor.Data) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.SchedulesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, []models.Schedule{defaultSchedules[2]}, cursor.Data) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.SchedulesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.False(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.Empty(t, cursor.Next) + require.Equal(t, []models.Schedule{defaultSchedules[0]}, cursor.Data) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.SchedulesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, []models.Schedule{defaultSchedules[2]}, cursor.Data) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.SchedulesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, []models.Schedule{defaultSchedules[1]}, cursor.Data) + }) +} diff --git a/components/payments/internal/storage/states_test.go b/components/payments/internal/storage/states_test.go new file mode 100644 index 0000000000..f716285239 --- /dev/null +++ b/components/payments/internal/storage/states_test.go @@ -0,0 +1,151 @@ +package storage + +import ( + "context" + "encoding/json" + "testing" + + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + defaultStates = []models.State{ + { + ID: models.StateID{ + Reference: "test1", + ConnectorID: defaultConnector.ID, + }, + ConnectorID: defaultConnector.ID, + State: []byte(`{}`), + }, + { + ID: models.StateID{ + Reference: "test2", + ConnectorID: defaultConnector.ID, + }, + ConnectorID: defaultConnector.ID, + State: []byte(`{"foo":"bar"}`), + }, + { + ID: models.StateID{ + Reference: "test3", + ConnectorID: defaultConnector.ID, + }, + ConnectorID: defaultConnector.ID, + State: []byte(`{"foo3":"bar3"}`), + }, + } +) + +func upsertState(t *testing.T, ctx context.Context, storage Storage, state models.State) { + require.NoError(t, storage.StatesUpsert(ctx, state)) +} + +func TestStatesUpsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, state := range defaultStates { + upsertState(t, ctx, store, state) + } + + t.Run("upsert with unknown connector id", func(t *testing.T) { + c := models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + s := models.State{ + ID: models.StateID{ + Reference: "test4", + ConnectorID: c, + }, + ConnectorID: c, + State: []byte(`{}`), + } + + require.Error(t, store.StatesUpsert(ctx, s)) + }) + + t.Run("upsert with same id", func(t *testing.T) { + s := models.State{ + ID: defaultStates[0].ID, + ConnectorID: defaultConnector.ID, + State: json.RawMessage(`{"foo":"bar"}`), + } + + upsertState(t, ctx, store, s) + + // Should update the state + state, err := store.StatesGet(ctx, s.ID) + require.NoError(t, err) + require.Equal(t, s, state) + }) +} + +func TestStatesGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, state := range defaultStates { + upsertState(t, ctx, store, state) + } + + t.Run("get state", func(t *testing.T) { + for _, state := range defaultStates { + s, err := store.StatesGet(ctx, state.ID) + require.NoError(t, err) + require.Equal(t, state, s) + } + }) + + t.Run("get state with unknown id", func(t *testing.T) { + _, err := store.StatesGet(ctx, models.StateID{ + Reference: "unknown", + ConnectorID: defaultConnector.ID, + }) + require.Error(t, err) + }) +} + +func TestDeleteStatesFromConnectorID(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, state := range defaultStates { + upsertState(t, ctx, store, state) + } + + t.Run("delete states with unknown connector id", func(t *testing.T) { + require.NoError(t, store.StatesDeleteFromConnectorID(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + })) + + for _, state := range defaultStates { + s, err := store.StatesGet(ctx, state.ID) + require.NoError(t, err) + require.Equal(t, state, s) + } + }) + + t.Run("delete states", func(t *testing.T) { + require.NoError(t, store.StatesDeleteFromConnectorID(ctx, defaultConnector.ID)) + + for _, state := range defaultStates { + _, err := store.StatesGet(ctx, state.ID) + require.Error(t, err) + } + }) +} diff --git a/components/payments/internal/storage/storage.go b/components/payments/internal/storage/storage.go index b9cf09474a..75bcdf59f7 100644 --- a/components/payments/internal/storage/storage.go +++ b/components/payments/internal/storage/storage.go @@ -11,6 +11,9 @@ import ( ) type Storage interface { + // Close closes the storage. + Close() error + // Accounts AccountsUpsert(ctx context.Context, accounts []models.Account) error AccountsGet(ctx context.Context, id models.AccountID) (*models.Account, error) @@ -19,7 +22,7 @@ type Storage interface { // Balances BalancesUpsert(ctx context.Context, balances []models.Balance) error - BalancesDeleteForConnectorID(ctx context.Context, connectorID models.ConnectorID) error + BalancesDeleteFromConnectorID(ctx context.Context, connectorID models.ConnectorID) error BalancesList(ctx context.Context, q ListBalancesQuery) (*bunpaginate.Cursor[models.Balance], error) BalancesGetAt(ctx context.Context, accountID models.AccountID, at time.Time) ([]*models.Balance, error) @@ -70,15 +73,18 @@ type Storage interface { // Webhooks Configs WebhooksConfigsUpsert(ctx context.Context, webhooksConfigs []models.WebhookConfig) error + WebhooksConfigsGet(ctx context.Context, name string, connectorID models.ConnectorID) (*models.WebhookConfig, error) WebhooksConfigsDeleteFromConnectorID(ctx context.Context, connectorID models.ConnectorID) error // Webhooks WebhooksInsert(ctx context.Context, webhook models.Webhook) error + WebhooksGet(ctx context.Context, id string) (models.Webhook, error) WebhooksDeleteFromConnectorID(ctx context.Context, connectorID models.ConnectorID) error // Workflow Instances InstancesUpsert(ctx context.Context, instance models.Instance) error InstancesUpdate(ctx context.Context, instance models.Instance) error + InstancesGet(ctx context.Context, id string, scheduleID string, connectorID models.ConnectorID) (*models.Instance, error) InstancesList(ctx context.Context, q ListInstancesQuery) (*bunpaginate.Cursor[models.Instance], error) InstancesDeleteFromConnectorID(ctx context.Context, connectorID models.ConnectorID) error } @@ -93,3 +99,7 @@ type store struct { func newStorage(db *bun.DB, configEncryptionKey string) Storage { return &store{db: db, configEncryptionKey: configEncryptionKey} } + +func (s *store) Close() error { + return s.db.Close() +} diff --git a/components/payments/internal/storage/tasks_test.go b/components/payments/internal/storage/tasks_test.go new file mode 100644 index 0000000000..618b8119be --- /dev/null +++ b/components/payments/internal/storage/tasks_test.go @@ -0,0 +1,160 @@ +package storage + +import ( + "context" + "testing" + + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + defaultTasksTree = models.Tasks{ + { + TaskType: models.TASK_FETCH_ACCOUNTS, + Name: "fetch_accounts", + Periodically: true, + NextTasks: []models.TaskTree{ + { + TaskType: models.TASK_FETCH_PAYMENTS, + Name: "fetch_payments", + Periodically: true, + NextTasks: []models.TaskTree{}, + }, + { + TaskType: models.TASK_FETCH_BALANCES, + Name: "fetch_balances", + Periodically: true, + NextTasks: []models.TaskTree{}, + }, + }, + }, + { + TaskType: models.TASK_FETCH_EXTERNAL_ACCOUNTS, + Name: "fetch_beneficiaries", + Periodically: true, + NextTasks: []models.TaskTree{}, + }, + } + + defaultTasksTree2 = models.Tasks{ + { + TaskType: models.TASK_FETCH_ACCOUNTS, + Name: "fetch_accounts", + Periodically: true, + NextTasks: []models.TaskTree{ + { + TaskType: models.TASK_FETCH_BALANCES, + Name: "fetch_balances", + Periodically: true, + NextTasks: []models.TaskTree{}, + }, + { + TaskType: models.TASK_FETCH_PAYMENTS, + Name: "fetch_payments", + Periodically: true, + NextTasks: []models.TaskTree{}, + }, + { + TaskType: models.TASK_FETCH_EXTERNAL_ACCOUNTS, + Name: "fetch_recipients", + Periodically: true, + NextTasks: []models.TaskTree{}, + }, + }, + }, + } +) + +func upsertTasksTree(t *testing.T, ctx context.Context, storage Storage, connectorID models.ConnectorID, tasksTree []models.TaskTree) { + require.NoError(t, storage.TasksUpsert(ctx, connectorID, tasksTree)) +} + +func TestTasksUpsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertTasksTree(t, ctx, store, defaultConnector.ID, defaultTasksTree) + + t.Run("upsert with unknown connector id", func(t *testing.T) { + require.Error(t, store.TasksUpsert(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + }, defaultTasksTree2)) + }) + + t.Run("upsert with same connector id", func(t *testing.T) { + upsertTasksTree(t, ctx, store, defaultConnector.ID, defaultTasksTree2) + + tasks, err := store.TasksGet(ctx, defaultConnector.ID) + require.NoError(t, err) + require.Equal(t, defaultTasksTree2, *tasks) + }) +} + +func TestTasksGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertTasksTree(t, ctx, store, defaultConnector.ID, defaultTasksTree) + + t.Run("get tasks", func(t *testing.T) { + tasks, err := store.TasksGet(ctx, defaultConnector.ID) + require.NoError(t, err) + require.Equal(t, defaultTasksTree, *tasks) + }) + + t.Run("get tasks with unknown connector id", func(t *testing.T) { + _, err := store.TasksGet(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + }) + require.Error(t, err) + }) +} + +func TestTasksDeleteFromConnectorID(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertConnector(t, ctx, store, defaultConnector2) + upsertTasksTree(t, ctx, store, defaultConnector.ID, defaultTasksTree) + upsertTasksTree(t, ctx, store, defaultConnector2.ID, defaultTasksTree2) + + t.Run("delete tasks with unknown connector id", func(t *testing.T) { + require.NoError(t, store.TasksDeleteFromConnectorID(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + })) + + tasks, err := store.TasksGet(ctx, defaultConnector.ID) + require.NoError(t, err) + require.Equal(t, defaultTasksTree, *tasks) + + tasks, err = store.TasksGet(ctx, defaultConnector2.ID) + require.NoError(t, err) + require.Equal(t, defaultTasksTree2, *tasks) + }) + + t.Run("delete tasks", func(t *testing.T) { + require.NoError(t, store.TasksDeleteFromConnectorID(ctx, defaultConnector.ID)) + + _, err := store.TasksGet(ctx, defaultConnector.ID) + require.Error(t, err) + + tasks, err := store.TasksGet(ctx, defaultConnector2.ID) + require.NoError(t, err) + require.Equal(t, defaultTasksTree2, *tasks) + }) +} diff --git a/components/payments/internal/storage/utils_test.go b/components/payments/internal/storage/utils_test.go new file mode 100644 index 0000000000..e8d034a8fd --- /dev/null +++ b/components/payments/internal/storage/utils_test.go @@ -0,0 +1,27 @@ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMetadataRegexp(t *testing.T) { + t.Parallel() + + t.Run("valid tests", func(t *testing.T) { + t.Parallel() + require.True(t, metadataRegex.MatchString("metadata[foo]")) + require.True(t, metadataRegex.MatchString("metadata[foo_bar]")) + require.True(t, metadataRegex.MatchString("metadata[foo/bar]")) + require.True(t, metadataRegex.MatchString("metadata[foo.bar]")) + }) + + t.Run("invalid tests", func(t *testing.T) { + t.Parallel() + + require.False(t, metadataRegex.MatchString("metadata[foo")) + require.False(t, metadataRegex.MatchString("metadata/foo")) + require.False(t, metadataRegex.MatchString("metadata.foo")) + }) +} diff --git a/components/payments/internal/storage/webhooks.go b/components/payments/internal/storage/webhooks.go index 1eac74f47f..c4d974d8bc 100644 --- a/components/payments/internal/storage/webhooks.go +++ b/components/payments/internal/storage/webhooks.go @@ -34,6 +34,19 @@ func (s *store) WebhooksInsert(ctx context.Context, webhook models.Webhook) erro return nil } +func (s *store) WebhooksGet(ctx context.Context, id string) (models.Webhook, error) { + var w webhook + err := s.db.NewSelect(). + Model(&w). + Where("id = ?", id). + Scan(ctx) + if err != nil { + return models.Webhook{}, e("get webhook", err) + } + + return toWebhookModels(w), nil +} + func (s *store) WebhooksDeleteFromConnectorID(ctx context.Context, connectorID models.ConnectorID) error { _, err := s.db.NewDelete(). Model((*webhook)(nil)). @@ -55,3 +68,13 @@ func fromWebhookModels(from models.Webhook) webhook { Body: from.Body, } } + +func toWebhookModels(from webhook) models.Webhook { + return models.Webhook{ + ID: from.ID, + ConnectorID: from.ConnectorID, + Headers: from.Headers, + QueryValues: from.QueryValues, + Body: from.Body, + } +} diff --git a/components/payments/internal/storage/webhooks_configs.go b/components/payments/internal/storage/webhooks_configs.go index e4a2f13614..7c5f583ea0 100644 --- a/components/payments/internal/storage/webhooks_configs.go +++ b/components/payments/internal/storage/webhooks_configs.go @@ -30,6 +30,19 @@ func (s *store) WebhooksConfigsUpsert(ctx context.Context, webhooksConfigs []mod return nil } +func (s *store) WebhooksConfigsGet(ctx context.Context, name string, connectorID models.ConnectorID) (*models.WebhookConfig, error) { + var webhookConfig webhookConfig + err := s.db.NewSelect(). + Model(&webhookConfig). + Where("name = ? AND connector_id = ?", name, connectorID). + Scan(ctx) + if err != nil { + return nil, e("get webhook config", err) + } + + return toWebhookConfigModel(webhookConfig), nil +} + func (s *store) WebhooksConfigsDeleteFromConnectorID(ctx context.Context, connectorID models.ConnectorID) error { _, err := s.db.NewDelete(). Model((*webhookConfig)(nil)). @@ -58,3 +71,11 @@ func fromWebhooksConfigsModels(from []models.WebhookConfig) []webhookConfig { return to } + +func toWebhookConfigModel(from webhookConfig) *models.WebhookConfig { + return &models.WebhookConfig{ + Name: from.Name, + ConnectorID: from.ConnectorID, + URLPath: from.URLPath, + } +} diff --git a/components/payments/internal/storage/webhooks_configs_test.go b/components/payments/internal/storage/webhooks_configs_test.go new file mode 100644 index 0000000000..95b0382a29 --- /dev/null +++ b/components/payments/internal/storage/webhooks_configs_test.go @@ -0,0 +1,129 @@ +package storage + +import ( + "context" + "testing" + + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + defaultWebhooksConfigs = []models.WebhookConfig{ + { + Name: "test1", + ConnectorID: defaultConnector.ID, + URLPath: "/test1", + }, + { + Name: "test2", + ConnectorID: defaultConnector.ID, + URLPath: "/test2", + }, + { + Name: "test3", + ConnectorID: defaultConnector.ID, + URLPath: "/test3", + }, + } +) + +func upsertWebhookConfigs(t *testing.T, ctx context.Context, storage Storage, webhookConfigs []models.WebhookConfig) { + require.NoError(t, storage.WebhooksConfigsUpsert(ctx, webhookConfigs)) +} + +func TestWebhooksConfigsUpsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertWebhookConfigs(t, ctx, store, defaultWebhooksConfigs) + + t.Run("same name and connector id insert", func(t *testing.T) { + w := models.WebhookConfig{ + Name: "test1", + ConnectorID: defaultConnector.ID, + URLPath: "/test3", + } + + require.NoError(t, store.WebhooksConfigsUpsert(ctx, []models.WebhookConfig{w})) + + actual, err := store.WebhooksConfigsGet(ctx, w.Name, w.ConnectorID) + require.NoError(t, err) + require.Equal(t, defaultWebhooksConfigs[0], *actual) + }) + + t.Run("unknown connector id", func(t *testing.T) { + w := models.WebhookConfig{ + Name: "test1", + ConnectorID: models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + }, + URLPath: "/test3", + } + + require.Error(t, store.WebhooksConfigsUpsert(ctx, []models.WebhookConfig{w})) + }) +} + +func TestWebhooksConfigsGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertWebhookConfigs(t, ctx, store, defaultWebhooksConfigs) + + t.Run("get webhook config", func(t *testing.T) { + for _, w := range defaultWebhooksConfigs { + actual, err := store.WebhooksConfigsGet(ctx, w.Name, w.ConnectorID) + require.NoError(t, err) + require.Equal(t, w, *actual) + } + }) + + t.Run("unknown webhook config", func(t *testing.T) { + _, err := store.WebhooksConfigsGet(ctx, "unknown", defaultConnector.ID) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + }) +} + +func TestWebhooksConfigsDeleteFromConnectorID(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + upsertWebhookConfigs(t, ctx, store, defaultWebhooksConfigs) + + t.Run("delete webhooks configs from unknown connector id", func(t *testing.T) { + require.NoError(t, store.WebhooksConfigsDeleteFromConnectorID(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + })) + + for _, w := range defaultWebhooksConfigs { + actual, err := store.WebhooksConfigsGet(ctx, w.Name, w.ConnectorID) + require.NoError(t, err) + require.Equal(t, w, *actual) + } + }) + + t.Run("delete webhooks configs from connector id", func(t *testing.T) { + require.NoError(t, store.WebhooksConfigsDeleteFromConnectorID(ctx, defaultConnector.ID)) + + for _, w := range defaultWebhooksConfigs { + _, err := store.WebhooksConfigsGet(ctx, w.Name, w.ConnectorID) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + } + }) +} diff --git a/components/payments/internal/storage/webhooks_test.go b/components/payments/internal/storage/webhooks_test.go new file mode 100644 index 0000000000..bfa511e291 --- /dev/null +++ b/components/payments/internal/storage/webhooks_test.go @@ -0,0 +1,138 @@ +package storage + +import ( + "context" + "testing" + + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + defaultWebhooks = []models.Webhook{ + { + ID: "test1", + ConnectorID: defaultConnector.ID, + QueryValues: map[string][]string{ + "foo": {"bar"}, + }, + Headers: map[string][]string{ + "foo2": {"bar2"}, + }, + Body: []byte(`{}`), + }, + { + ID: "test2", + ConnectorID: defaultConnector.ID, + QueryValues: map[string][]string{ + "foo3": {"bar3"}, + }, + Headers: map[string][]string{ + "foo4": {"bar4"}, + }, + Body: []byte(`{}`), + }, + } +) + +func upsertWebhook(t *testing.T, ctx context.Context, storage Storage, webhook models.Webhook) { + require.NoError(t, storage.WebhooksInsert(ctx, webhook)) +} + +func TestWebhooksInsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, webhook := range defaultWebhooks { + upsertWebhook(t, ctx, store, webhook) + } + + t.Run("same id upsert", func(t *testing.T) { + webhook := defaultWebhooks[0] + webhook.QueryValues = map[string][]string{ + "changed": {"changed"}, + } + + require.NoError(t, store.WebhooksInsert(ctx, webhook)) + + // should not have been changed + actual, err := store.WebhooksGet(ctx, webhook.ID) + require.NoError(t, err) + require.Equal(t, defaultWebhooks[0], actual) + }) + + t.Run("unknown connector id", func(t *testing.T) { + webhook := defaultWebhooks[0] + webhook.ID = "unknown" + webhook.ConnectorID = models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + + require.Error(t, store.WebhooksInsert(ctx, webhook)) + }) +} + +func TestWebhooksGet(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, webhook := range defaultWebhooks { + upsertWebhook(t, ctx, store, webhook) + } + + t.Run("get webhook", func(t *testing.T) { + for _, webhook := range defaultWebhooks { + actual, err := store.WebhooksGet(ctx, webhook.ID) + require.NoError(t, err) + require.Equal(t, webhook, actual) + } + }) + + t.Run("get unknown webhook", func(t *testing.T) { + _, err := store.WebhooksGet(ctx, "unknown") + require.Error(t, err) + }) +} + +func TestWebhooksDeleteFromConnectorID(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, webhook := range defaultWebhooks { + upsertWebhook(t, ctx, store, webhook) + } + + t.Run("delete unknown connector id", func(t *testing.T) { + require.NoError(t, store.WebhooksDeleteFromConnectorID(ctx, models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + })) + + for _, webhook := range defaultWebhooks { + actual, err := store.WebhooksGet(ctx, webhook.ID) + require.NoError(t, err) + require.Equal(t, webhook, actual) + } + }) + + t.Run("delete webhooks", func(t *testing.T) { + require.NoError(t, store.WebhooksDeleteFromConnectorID(ctx, defaultConnector.ID)) + + for _, webhook := range defaultWebhooks { + _, err := store.WebhooksGet(ctx, webhook.ID) + require.Error(t, err) + } + }) +} diff --git a/components/payments/internal/storage/workflow_instances.go b/components/payments/internal/storage/workflow_instances.go index e1ed817079..58a0a372ef 100644 --- a/components/payments/internal/storage/workflow_instances.go +++ b/components/payments/internal/storage/workflow_instances.go @@ -6,6 +6,7 @@ import ( "time" "github.com/formancehq/go-libs/bun/bunpaginate" + "github.com/formancehq/go-libs/pointer" "github.com/formancehq/go-libs/query" "github.com/formancehq/payments/internal/models" "github.com/pkg/errors" @@ -94,6 +95,21 @@ func (s *store) instancesQueryContext(qb query.Builder) (string, []any, error) { })) } +func (s *store) InstancesGet(ctx context.Context, id string, scheduleID string, connectorID models.ConnectorID) (*models.Instance, error) { + var i instance + err := s.db.NewSelect(). + Model(&i). + Where("id = ?", id). + Where("schedule_id = ?", scheduleID). + Where("connector_id = ?", connectorID). + Scan(ctx) + if err != nil { + return nil, e("failed to fetch instance", err) + } + + return pointer.For(toInstanceModel(i)), nil +} + func (s *store) InstancesList(ctx context.Context, q ListInstancesQuery) (*bunpaginate.Cursor[models.Instance], error) { var ( where string @@ -139,26 +155,37 @@ func (s *store) InstancesList(ctx context.Context, q ListInstancesQuery) (*bunpa func fromInstanceModel(from models.Instance) instance { return instance{ - ID: from.ID, - ScheduleID: from.ScheduleID, - ConnectorID: from.ConnectorID, - CreatedAt: from.CreatedAt, - UpdatedAt: from.UpdatedAt, - Terminated: from.Terminated, - TerminatedAt: from.TerminatedAt, - Error: from.Error, + ID: from.ID, + ScheduleID: from.ScheduleID, + ConnectorID: from.ConnectorID, + CreatedAt: from.CreatedAt.UTC(), + UpdatedAt: from.UpdatedAt.UTC(), + Terminated: from.Terminated, + TerminatedAt: func() *time.Time { + if from.TerminatedAt == nil { + return nil + } + return pointer.For(from.TerminatedAt.UTC()) + }(), + Error: from.Error, } } func toInstanceModel(from instance) models.Instance { return models.Instance{ - ID: from.ID, - ScheduleID: from.ScheduleID, - ConnectorID: from.ConnectorID, - CreatedAt: from.CreatedAt, - UpdatedAt: from.UpdatedAt, - Terminated: from.Terminated, - TerminatedAt: from.TerminatedAt, - Error: from.Error, + ID: from.ID, + ScheduleID: from.ScheduleID, + ConnectorID: from.ConnectorID, + CreatedAt: from.CreatedAt.UTC(), + UpdatedAt: from.UpdatedAt.UTC(), + Terminated: from.Terminated, + TerminatedAt: func() *time.Time { + if from.TerminatedAt == nil { + return nil + } + + return pointer.For(from.TerminatedAt.UTC()) + }(), + Error: from.Error, } } diff --git a/components/payments/internal/storage/workflow_instances_test.go b/components/payments/internal/storage/workflow_instances_test.go new file mode 100644 index 0000000000..fdfd9813ad --- /dev/null +++ b/components/payments/internal/storage/workflow_instances_test.go @@ -0,0 +1,318 @@ +package storage + +import ( + "context" + "testing" + "time" + + "github.com/formancehq/go-libs/bun/bunpaginate" + "github.com/formancehq/go-libs/logging" + "github.com/formancehq/go-libs/pointer" + "github.com/formancehq/go-libs/query" + "github.com/formancehq/payments/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + defaultWorkflowInstances = []models.Instance{ + { + ID: "test1", + ScheduleID: defaultSchedules[0].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + UpdatedAt: now.Add(-60 * time.Minute).UTC().Time, + Terminated: false, + }, + { + ID: "test2", + ScheduleID: defaultSchedules[0].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + UpdatedAt: now.Add(-30 * time.Minute).UTC().Time, + Terminated: false, + }, + { + ID: "test3", + ScheduleID: defaultSchedules[2].ID, + ConnectorID: defaultConnector.ID, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + UpdatedAt: now.Add(-55 * time.Minute).UTC().Time, + Terminated: true, + TerminatedAt: pointer.For(now.UTC().Time), + Error: pointer.For("test error"), + }, + } +) + +func upsertInstance(t *testing.T, ctx context.Context, storage Storage, instance models.Instance) { + require.NoError(t, storage.InstancesUpsert(ctx, instance)) +} + +func TestInstancesUpsert(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, schedule := range defaultSchedules { + upsertSchedule(t, ctx, store, schedule) + } + for _, instance := range defaultWorkflowInstances { + upsertInstance(t, ctx, store, instance) + } + + t.Run("same id upsert", func(t *testing.T) { + instance := defaultWorkflowInstances[0] + instance.Terminated = true + instance.TerminatedAt = pointer.For(now.UTC().Time) + instance.Error = pointer.For("test error") + + upsertInstance(t, ctx, store, instance) + + actual, err := store.InstancesGet(ctx, instance.ID, instance.ScheduleID, instance.ConnectorID) + require.NoError(t, err) + require.Equal(t, defaultWorkflowInstances[0], *actual) + }) + + t.Run("unknown connector id", func(t *testing.T) { + instance := defaultWorkflowInstances[0] + instance.ConnectorID = models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + + err := store.InstancesUpsert(ctx, instance) + require.Error(t, err) + }) + + t.Run("unknown schedule id", func(t *testing.T) { + instance := defaultWorkflowInstances[0] + instance.ScheduleID = uuid.New().String() + + err := store.InstancesUpsert(ctx, instance) + require.Error(t, err) + }) +} + +func TestInstancesUpdate(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, schedule := range defaultSchedules { + upsertSchedule(t, ctx, store, schedule) + } + for _, instance := range defaultWorkflowInstances { + upsertInstance(t, ctx, store, instance) + } + + t.Run("update instance error", func(t *testing.T) { + instance := defaultWorkflowInstances[0] + instance.Error = pointer.For("test error") + instance.Terminated = true + instance.TerminatedAt = pointer.For(now.UTC().Time) + + err := store.InstancesUpdate(ctx, instance) + require.NoError(t, err) + + actual, err := store.InstancesGet(ctx, instance.ID, instance.ScheduleID, instance.ConnectorID) + require.NoError(t, err) + require.Equal(t, instance, *actual) + }) + + t.Run("update instance already on error", func(t *testing.T) { + instance := defaultWorkflowInstances[2] + instance.Error = pointer.For("test error2") + instance.Terminated = true + instance.TerminatedAt = pointer.For(now.UTC().Time) + + err := store.InstancesUpdate(ctx, instance) + require.NoError(t, err) + + actual, err := store.InstancesGet(ctx, instance.ID, instance.ScheduleID, instance.ConnectorID) + require.NoError(t, err) + require.Equal(t, instance, *actual) + }) +} + +func TestInstancesDeleteFromConnectorID(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, schedule := range defaultSchedules { + upsertSchedule(t, ctx, store, schedule) + } + for _, instance := range defaultWorkflowInstances { + upsertInstance(t, ctx, store, instance) + } + + t.Run("delete instances from unknown connector", func(t *testing.T) { + unknownConnectorID := models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + } + + require.NoError(t, store.InstancesDeleteFromConnectorID(ctx, unknownConnectorID)) + + for _, instance := range defaultWorkflowInstances { + actual, err := store.InstancesGet(ctx, instance.ID, instance.ScheduleID, instance.ConnectorID) + require.NoError(t, err) + require.Equal(t, instance, *actual) + } + }) + + t.Run("delete instances from default connector", func(t *testing.T) { + require.NoError(t, store.InstancesDeleteFromConnectorID(ctx, defaultConnector.ID)) + + for _, instance := range defaultWorkflowInstances { + _, err := store.InstancesGet(ctx, instance.ID, instance.ScheduleID, instance.ConnectorID) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + } + }) +} + +func TestInstancesList(t *testing.T) { + t.Parallel() + + ctx := logging.TestingContext() + store := newStore(t) + + upsertConnector(t, ctx, store, defaultConnector) + for _, schedule := range defaultSchedules { + upsertSchedule(t, ctx, store, schedule) + } + for _, instance := range defaultWorkflowInstances { + upsertInstance(t, ctx, store, instance) + } + + t.Run("list instances by schedule_id", func(t *testing.T) { + q := NewListInstancesQuery( + bunpaginate.NewPaginatedQueryOptions(InstanceQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("schedule_id", defaultSchedules[0].ID)), + ) + + cursor, err := store.InstancesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 2, len(cursor.Data)) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + require.Equal(t, defaultWorkflowInstances[1], cursor.Data[0]) + require.Equal(t, defaultWorkflowInstances[0], cursor.Data[1]) + }) + + t.Run("list instances by unknown schedule_id", func(t *testing.T) { + q := NewListInstancesQuery( + bunpaginate.NewPaginatedQueryOptions(InstanceQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("schedule_id", uuid.New().String())), + ) + + cursor, err := store.InstancesList(ctx, q) + require.NoError(t, err) + require.Empty(t, cursor.Data) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + }) + + t.Run("list instances by connector_id", func(t *testing.T) { + q := NewListInstancesQuery( + bunpaginate.NewPaginatedQueryOptions(InstanceQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("connector_id", defaultConnector.ID)), + ) + + cursor, err := store.InstancesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 3, len(cursor.Data)) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + require.Equal(t, defaultWorkflowInstances[1], cursor.Data[0]) + require.Equal(t, defaultWorkflowInstances[2], cursor.Data[1]) + require.Equal(t, defaultWorkflowInstances[0], cursor.Data[2]) + }) + + t.Run("list instances by unknown connector_id", func(t *testing.T) { + q := NewListInstancesQuery( + bunpaginate.NewPaginatedQueryOptions(InstanceQuery{}). + WithPageSize(15). + WithQueryBuilder(query.Match("connector_id", models.ConnectorID{ + Reference: uuid.New(), + Provider: "unknown", + })), + ) + + cursor, err := store.InstancesList(ctx, q) + require.NoError(t, err) + require.Empty(t, cursor.Data) + require.False(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.Empty(t, cursor.Next) + }) + + t.Run("list instances test cursor", func(t *testing.T) { + q := NewListInstancesQuery( + bunpaginate.NewPaginatedQueryOptions(InstanceQuery{}). + WithPageSize(1), + ) + + cursor, err := store.InstancesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, defaultWorkflowInstances[1], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.InstancesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, defaultWorkflowInstances[2], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Next, &q) + require.NoError(t, err) + cursor, err = store.InstancesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.False(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.Empty(t, cursor.Next) + require.Equal(t, defaultWorkflowInstances[0], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.InstancesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.True(t, cursor.HasMore) + require.NotEmpty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, defaultWorkflowInstances[2], cursor.Data[0]) + + err = bunpaginate.UnmarshalCursor(cursor.Previous, &q) + require.NoError(t, err) + cursor, err = store.InstancesList(ctx, q) + require.NoError(t, err) + require.Equal(t, 1, len(cursor.Data)) + require.True(t, cursor.HasMore) + require.Empty(t, cursor.Previous) + require.NotEmpty(t, cursor.Next) + require.Equal(t, defaultWorkflowInstances[1], cursor.Data[0]) + }) +}