Skip to content

Commit

Permalink
sqldb+invoices: synchronize SQL invoice updater behavior with KV version
Browse files Browse the repository at this point in the history
Previously SQL invoice updater ignored the set ID hint when updating an
AMP invoice resulting in update subscriptions returning all of the AMP
state as well as all AMP HTLCs. This commit synchornizes behavior with
the KV implementation such that we now only return relevant AMP state
and HTLCs when updating an AMP invoice.
  • Loading branch information
bhandras authored and Roasbeef committed Sep 6, 2024
1 parent b3dc3ed commit 9298133
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 18 deletions.
38 changes: 34 additions & 4 deletions invoices/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strconv"
"time"

"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
Expand Down Expand Up @@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
GetInvoice(ctx context.Context,
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)

GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
error)

GetInvoiceFeatures(ctx context.Context,
invoiceID int64) ([]sqlc.InvoiceFeature, error)

Expand Down Expand Up @@ -343,16 +347,31 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
params.SetID = ref.SetID()[:]
}

rows, err := db.GetInvoice(ctx, params)
var (
rows []sqlc.Invoice
err error
)

// We need to split the query based on how we intend to look up the
// invoice. If only the set ID is given then we want to have an exact
// match on the set ID. If other fields are given, we want to match on
// those fields and the set ID but with a less strict join condition.
if params.Hash == nil && params.PaymentAddr == nil &&
params.SetID != nil {

rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
} else {
rows, err = db.GetInvoice(ctx, params)
}
switch {
case len(rows) == 0:
return nil, ErrInvoiceNotFound

case len(rows) > 1:
// In case the reference is ambiguous, meaning it matches more
// than one invoice, we'll return an error.
return nil, fmt.Errorf("ambiguous invoice ref: %s",
ref.String())
return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
ref.String(), spew.Sdump(rows))

case err != nil:
return nil, fmt.Errorf("unable to fetch invoice: %w", err)
Expand Down Expand Up @@ -1308,13 +1327,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
// invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback.
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
_ *SetID, callback InvoiceUpdateCallback) (
setID *SetID, callback InvoiceUpdateCallback) (
*Invoice, error) {

var updatedInvoice *Invoice

txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
if setID != nil {
// Make sure to use the set ID if this is an AMP update.
var setIDBytes [32]byte
copy(setIDBytes[:], setID[:])
ref.setID = &setIDBytes

// If we're updating an AMP invoice, we'll also only
// need to fetch the HTLCs for the given set ID.
ref.refModifier = HtlcSetOnlyModifier
}

invoice, err := i.fetchInvoice(ctx, db, ref)
if err != nil {
return err
Expand Down
70 changes: 60 additions & 10 deletions sqldb/sqlc/invoices.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions sqldb/sqlc/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions sqldb/sqlc/queries/invoices.sql
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ WHERE invoice_id = $1;
-- name: GetInvoice :many
SELECT i.*
FROM invoices i
LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id
LEFT JOIN amp_sub_invoices a
ON i.id = a.invoice_id
AND (
a.set_id = sqlc.narg('set_id') OR sqlc.narg('set_id') IS NULL
)
WHERE (
i.id = sqlc.narg('add_index') OR
sqlc.narg('add_index') IS NULL
Expand All @@ -39,13 +43,16 @@ WHERE (
) AND (
i.payment_addr = sqlc.narg('payment_addr') OR
sqlc.narg('payment_addr') IS NULL
) AND (
a.set_id = sqlc.narg('set_id') OR
sqlc.narg('set_id') IS NULL
)
GROUP BY i.id
LIMIT 2;

-- name: GetInvoiceBySetID :many
SELECT i.*
FROM invoices i
INNER JOIN amp_sub_invoices a
ON i.id = a.invoice_id AND a.set_id = $1;

-- name: FilterInvoices :many
SELECT
invoices.*
Expand Down

0 comments on commit 9298133

Please sign in to comment.