diff --git a/integration_tests/hodl_invoice_test.go b/integration_tests/hodl_invoice_test.go index 3e58e28c..7edcf884 100644 --- a/integration_tests/hodl_invoice_test.go +++ b/integration_tests/hodl_invoice_test.go @@ -152,7 +152,7 @@ func (suite *HodlInvoiceSuite) TestHodlInvoice() { } assert.Equal(suite.T(), int64(userFundingSats), userBalance) - invoices, err := suite.service.InvoicesFor(context.Background(), userId, common.InvoiceTypeOutgoing) + invoices, err := invoicesFor(suite.service, userId, common.InvoiceTypeOutgoing) if err != nil { fmt.Printf("Error when getting invoices %v\n", err.Error()) } diff --git a/integration_tests/internal_payment_test.go b/integration_tests/internal_payment_test.go index dcc9c7f0..6590097f 100644 --- a/integration_tests/internal_payment_test.go +++ b/integration_tests/internal_payment_test.go @@ -394,7 +394,7 @@ func (suite *PaymentTestSuite) TestInternalPaymentFail() { _ = suite.createPayInvoiceReqError(bobInvoice.PayReq, suite.aliceToken) userId := getUserIdFromToken(suite.aliceToken) - invoices, err := suite.service.InvoicesFor(context.Background(), userId, common.InvoiceTypeOutgoing) + invoices, err := invoicesFor(suite.service, userId, common.InvoiceTypeOutgoing) if err != nil { fmt.Printf("Error when getting invoices %v\n", err.Error()) } diff --git a/integration_tests/keysend_failure_test.go b/integration_tests/keysend_failure_test.go index 25659c7a..dd5bfca5 100644 --- a/integration_tests/keysend_failure_test.go +++ b/integration_tests/keysend_failure_test.go @@ -102,7 +102,7 @@ func (suite *KeySendFailureTestSuite) TestKeysendPayment() { } assert.Equal(suite.T(), int64(aliceFundingSats), aliceBalance) - invoices, err := suite.service.InvoicesFor(context.Background(), userId, common.InvoiceTypeOutgoing) + invoices, err := invoicesFor(suite.service, userId, common.InvoiceTypeOutgoing) if err != nil { fmt.Printf("Error when getting invoices %v\n", err.Error()) } diff --git a/integration_tests/payment_failure_async_test.go b/integration_tests/payment_failure_async_test.go index edc8a1f5..60fed507 100644 --- a/integration_tests/payment_failure_async_test.go +++ b/integration_tests/payment_failure_async_test.go @@ -119,7 +119,7 @@ func (suite *PaymentTestAsyncErrorsSuite) TestExternalAsyncFailingInvoice() { } assert.Equal(suite.T(), int64(userFundingSats), userBalance) - invoices, err := suite.service.InvoicesFor(context.Background(), userId, common.InvoiceTypeOutgoing) + invoices, err := invoicesFor(suite.service, userId, common.InvoiceTypeOutgoing) if err != nil { fmt.Printf("Error when getting invoices %v\n", err.Error()) } diff --git a/integration_tests/payment_failure_test.go b/integration_tests/payment_failure_test.go index eff222a9..b73ec94c 100644 --- a/integration_tests/payment_failure_test.go +++ b/integration_tests/payment_failure_test.go @@ -130,7 +130,7 @@ func (suite *PaymentTestErrorsSuite) TestExternalFailingInvoice() { userId := getUserIdFromToken(suite.userToken) - invoices, err := suite.service.InvoicesFor(context.Background(), userId, common.InvoiceTypeOutgoing) + invoices, err := invoicesFor(suite.service, userId, common.InvoiceTypeOutgoing) if err != nil { fmt.Printf("Error when getting invoices %v\n", err.Error()) } diff --git a/integration_tests/util.go b/integration_tests/util.go index 2fe3e7d7..c353abe1 100644 --- a/integration_tests/util.go +++ b/integration_tests/util.go @@ -10,8 +10,10 @@ import ( "os" "time" + "github.com/getAlby/lndhub.go/common" "github.com/getAlby/lndhub.go/db" "github.com/getAlby/lndhub.go/db/migrations" + "github.com/getAlby/lndhub.go/db/models" "github.com/getAlby/lndhub.go/lib" "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" @@ -126,6 +128,22 @@ func getUserIdFromToken(token string) int64 { return int64(claims["id"].(float64)) } +// since svc.invoicesFor excludes erroneous invoices, this is used for testing +func invoicesFor(svc *service.LndhubService, userId int64, invoiceType string) ([]models.Invoice, error) { + var invoices []models.Invoice + + query := svc.DB.NewSelect().Model(&invoices).Where("user_id = ?", userId) + if invoiceType != "" { + query.Where("type = ? AND state <> ?", invoiceType, common.InvoiceStateInitialized) + } + query.OrderExpr("id DESC").Limit(100) + err := query.Scan(context.Background()) + if err != nil { + return nil, err + } + return invoices, nil +} + func createUsers(svc *service.LndhubService, usersToCreate int) (logins []ExpectedCreateUserResponseBody, tokens []string, err error) { logins = []ExpectedCreateUserResponseBody{} tokens = []string{} diff --git a/lib/service/user.go b/lib/service/user.go index 62a5606a..8ef3c1cc 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -268,7 +268,7 @@ func (svc *LndhubService) InvoicesFor(ctx context.Context, userId int64, invoice query := svc.DB.NewSelect().Model(&invoices).Where("user_id = ?", userId) if invoiceType != "" { - query.Where("type = ? AND state <> ?", invoiceType, common.InvoiceStateInitialized) + query.Where("type = ? AND state NOT IN(?, ?)", invoiceType, common.InvoiceStateInitialized, common.InvoiceStateError) } query.OrderExpr("id DESC").Limit(100) err := query.Scan(ctx)