• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15205630088

23 May 2025 08:14AM UTC coverage: 57.45% (-11.5%) from 68.996%
15205630088

Pull #9784

github

web-flow
Merge f8b9f36a3 into c52a6ddeb
Pull Request #9784: [wip] lnwallet+walletrpc: add SubmitPackage and related RPC call

47 of 96 new or added lines in 5 files covered. (48.96%)

30087 existing lines in 459 files now uncovered.

95586 of 166380 relevant lines covered (57.45%)

0.61 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
const (
24
        // defaultQueryPaginationLimit is used in the LIMIT clause of the SQL
25
        // queries to limit the number of rows returned.
26
        defaultQueryPaginationLimit = 100
27

28
        // invoiceProgressLogInterval is the interval used to rate-limit the
29
        // logging output of invoice processing.
30
        invoiceProgressLogInterval = 30 * time.Second
31
)
32

33
// SQLInvoiceQueries is an interface that defines the set of operations that can
34
// be executed against the invoice SQL database.
35
type SQLInvoiceQueries interface { //nolint:interfacebloat
36
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
37
                error)
38

39
        // TODO(bhandras): remove this once migrations have been separated out.
40
        InsertMigratedInvoice(ctx context.Context,
41
                arg sqlc.InsertMigratedInvoiceParams) (int64, error)
42

43
        InsertInvoiceFeature(ctx context.Context,
44
                arg sqlc.InsertInvoiceFeatureParams) error
45

46
        InsertInvoiceHTLC(ctx context.Context,
47
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
48

49
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
50
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
51

52
        FilterInvoices(ctx context.Context,
53
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
54

55
        GetInvoice(ctx context.Context,
56
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
57

58
        GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
59
                error)
60

61
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
62
                error)
63

64
        GetInvoiceFeatures(ctx context.Context,
65
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
66

67
        GetInvoiceHTLCCustomRecords(ctx context.Context,
68
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
69

70
        GetInvoiceHTLCs(ctx context.Context,
71
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
72

73
        UpdateInvoiceState(ctx context.Context,
74
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
75

76
        UpdateInvoiceAmountPaid(ctx context.Context,
77
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
78

79
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
80

81
        UpdateInvoiceHTLC(ctx context.Context,
82
                arg sqlc.UpdateInvoiceHTLCParams) error
83

84
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
85
                sql.Result, error)
86

87
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
88

89
        // AMP sub invoice specific methods.
90
        UpsertAMPSubInvoice(ctx context.Context,
91
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
92

93
        // TODO(bhandras): remove this once migrations have been separated out.
94
        InsertAMPSubInvoice(ctx context.Context,
95
                arg sqlc.InsertAMPSubInvoiceParams) error
96

97
        UpdateAMPSubInvoiceState(ctx context.Context,
98
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
99

100
        InsertAMPSubInvoiceHTLC(ctx context.Context,
101
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
102

103
        FetchAMPSubInvoices(ctx context.Context,
104
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
105
                error)
106

107
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
108
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
109
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
110

111
        FetchSettledAMPSubInvoices(ctx context.Context,
112
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
113
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
114

115
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
116
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
117
                error)
118

119
        // Invoice events specific methods.
120
        OnInvoiceCreated(ctx context.Context,
121
                arg sqlc.OnInvoiceCreatedParams) error
122

123
        OnInvoiceCanceled(ctx context.Context,
124
                arg sqlc.OnInvoiceCanceledParams) error
125

126
        OnInvoiceSettled(ctx context.Context,
127
                arg sqlc.OnInvoiceSettledParams) error
128

129
        OnAMPSubInvoiceCreated(ctx context.Context,
130
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
131

132
        OnAMPSubInvoiceCanceled(ctx context.Context,
133
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
134

135
        OnAMPSubInvoiceSettled(ctx context.Context,
136
                arg sqlc.OnAMPSubInvoiceSettledParams) error
137

138
        // Migration specific methods.
139
        // TODO(bhandras): remove this once migrations have been separated out.
140
        InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
141
                arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
142

143
        SetKVInvoicePaymentHash(ctx context.Context,
144
                arg sqlc.SetKVInvoicePaymentHashParams) error
145

146
        GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
147
                []byte, error)
148

149
        ClearKVInvoiceHashIndex(ctx context.Context) error
150
}
151

152
var _ InvoiceDB = (*SQLStore)(nil)
153

154
// SQLInvoiceQueriesTxOptions carries the database transaction options that
// are understood when executing SQLInvoiceQueries.
type SQLInvoiceQueriesTxOptions struct {
	// readOnly is true when a read only transaction is requested. The zero
	// value therefore requests a read/write transaction.
	readOnly bool
}

// ReadOnly reports whether the transaction should be read only.
//
// NOTE: This implements the TxOptions.
func (o *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
	return o.readOnly
}

// NewSQLInvoiceQueryReadTx returns a transaction option set that requests a
// read only transaction.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
	var opts SQLInvoiceQueriesTxOptions
	opts.readOnly = true

	return opts
}
×
174

175
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
176
// of batched database operations. It combines the plain query set with the
// transaction executor that SQLStore uses (via ExecTx) to run a closure
// against a SQLInvoiceQueries instance.
177
type BatchedSQLInvoiceQueries interface {
178
        // SQLInvoiceQueries provides the individual invoice queries.
        SQLInvoiceQueries
179

180
        // BatchedTx provides the transactional execution of closures that
        // receive a SQLInvoiceQueries.
        sqldb.BatchedTx[SQLInvoiceQueries]
181
}
182

183
// SQLStore represents a storage backend. It is constructed with NewSQLStore
// and runs its queries through the batched SQL invoice query executor.
184
type SQLStore struct {
185
        // db is the query backend all transactions are executed against.
        db    BatchedSQLInvoiceQueries
186
        // clock is the clock handed to NewSQLStore.
        // NOTE(review): its use is not visible in this part of the file.
        clock clock.Clock
187
        // opts holds the store options, such as the pagination limit used
        // by paginated queries.
        opts  SQLStoreOptions
188
}
189

190
// SQLStoreOptions holds the options for the SQL store.
191
type SQLStoreOptions struct {
192
        paginationLimit int
193
}
194

195
// defaultSQLStoreOptions returns the default options for the SQL store.
UNCOV
196
func defaultSQLStoreOptions() SQLStoreOptions {
×
UNCOV
197
        return SQLStoreOptions{
×
UNCOV
198
                paginationLimit: defaultQueryPaginationLimit,
×
UNCOV
199
        }
×
UNCOV
200
}
×
201

202
// SQLStoreOption is a functional option that can be used to optionally modify
203
// the behavior of the SQL store.
204
type SQLStoreOption func(*SQLStoreOptions)
205

206
// WithPaginationLimit sets the pagination limit for the SQL store queries that
207
// paginate results.
UNCOV
208
func WithPaginationLimit(limit int) SQLStoreOption {
×
UNCOV
209
        return func(o *SQLStoreOptions) {
×
UNCOV
210
                o.paginationLimit = limit
×
UNCOV
211
        }
×
212
}
213

214
// NewSQLStore creates a new SQLStore instance given a open
215
// BatchedSQLInvoiceQueries storage backend.
216
func NewSQLStore(db BatchedSQLInvoiceQueries,
UNCOV
217
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
×
UNCOV
218

×
UNCOV
219
        opts := defaultSQLStoreOptions()
×
UNCOV
220
        for _, applyOption := range options {
×
UNCOV
221
                applyOption(&opts)
×
UNCOV
222
        }
×
223

UNCOV
224
        return &SQLStore{
×
UNCOV
225
                db:    db,
×
UNCOV
226
                clock: clock,
×
UNCOV
227
                opts:  opts,
×
UNCOV
228
        }
×
229
}
230

231
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
UNCOV
232
        sqlc.InsertInvoiceParams, error) {
×
UNCOV
233

×
UNCOV
234
        // Precompute the payment request hash so we can use it in the query.
×
UNCOV
235
        var paymentRequestHash []byte
×
UNCOV
236
        if len(invoice.PaymentRequest) > 0 {
×
UNCOV
237
                h := sha256.New()
×
UNCOV
238
                h.Write(invoice.PaymentRequest)
×
UNCOV
239
                paymentRequestHash = h.Sum(nil)
×
UNCOV
240
        }
×
241

UNCOV
242
        params := sqlc.InsertInvoiceParams{
×
UNCOV
243
                Hash:       paymentHash[:],
×
UNCOV
244
                AmountMsat: int64(invoice.Terms.Value),
×
UNCOV
245
                CltvDelta: sqldb.SQLInt32(
×
UNCOV
246
                        invoice.Terms.FinalCltvDelta,
×
UNCOV
247
                ),
×
UNCOV
248
                Expiry: int32(invoice.Terms.Expiry.Seconds()),
×
UNCOV
249
                // Note: keysend invoices don't have a payment request.
×
UNCOV
250
                PaymentRequest: sqldb.SQLStr(string(
×
UNCOV
251
                        invoice.PaymentRequest),
×
UNCOV
252
                ),
×
UNCOV
253
                PaymentRequestHash: paymentRequestHash,
×
UNCOV
254
                State:              int16(invoice.State),
×
UNCOV
255
                AmountPaidMsat:     int64(invoice.AmtPaid),
×
UNCOV
256
                IsAmp:              invoice.IsAMP(),
×
UNCOV
257
                IsHodl:             invoice.HodlInvoice,
×
UNCOV
258
                IsKeysend:          invoice.IsKeysend(),
×
UNCOV
259
                CreatedAt:          invoice.CreationDate.UTC(),
×
UNCOV
260
        }
×
UNCOV
261

×
UNCOV
262
        if invoice.Memo != nil {
×
UNCOV
263
                // Store the memo as a nullable string in the database. Note
×
UNCOV
264
                // that for compatibility reasons, we store the value as a valid
×
UNCOV
265
                // string even if it's empty.
×
UNCOV
266
                params.Memo = sql.NullString{
×
UNCOV
267
                        String: string(invoice.Memo),
×
UNCOV
268
                        Valid:  true,
×
UNCOV
269
                }
×
UNCOV
270
        }
×
271

272
        // Some invoices may not have a preimage, like in the case of HODL
273
        // invoices.
UNCOV
274
        if invoice.Terms.PaymentPreimage != nil {
×
UNCOV
275
                preimage := *invoice.Terms.PaymentPreimage
×
UNCOV
276
                if preimage == UnknownPreimage {
×
277
                        return sqlc.InsertInvoiceParams{},
×
278
                                errors.New("cannot use all-zeroes preimage")
×
279
                }
×
UNCOV
280
                params.Preimage = preimage[:]
×
281
        }
282

283
        // Some non MPP payments may have the default (invalid) value.
UNCOV
284
        if invoice.Terms.PaymentAddr != BlankPayAddr {
×
UNCOV
285
                params.PaymentAddr = invoice.Terms.PaymentAddr[:]
×
UNCOV
286
        }
×
287

UNCOV
288
        return params, nil
×
289
}
290

291
// AddInvoice inserts the targeted invoice into the database. If the invoice has
292
// *any* payment hashes which already exists within the database, then the
293
// insertion will be aborted and rejected due to the strict policy banning any
294
// duplicate payment hashes.
295
//
296
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
297
func (i *SQLStore) AddInvoice(ctx context.Context,
UNCOV
298
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
×
UNCOV
299

×
UNCOV
300
        // Make sure this is a valid invoice before trying to store it in our
×
UNCOV
301
        // DB.
×
UNCOV
302
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
×
UNCOV
303
                return 0, err
×
UNCOV
304
        }
×
305

UNCOV
306
        var (
×
UNCOV
307
                writeTxOpts SQLInvoiceQueriesTxOptions
×
UNCOV
308
                invoiceID   int64
×
UNCOV
309
        )
×
UNCOV
310

×
UNCOV
311
        insertInvoiceParams, err := makeInsertInvoiceParams(
×
UNCOV
312
                newInvoice, paymentHash,
×
UNCOV
313
        )
×
UNCOV
314
        if err != nil {
×
315
                return 0, err
×
316
        }
×
317

UNCOV
318
        err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
×
UNCOV
319
                var err error
×
UNCOV
320
                invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
×
UNCOV
321
                if err != nil {
×
UNCOV
322
                        return fmt.Errorf("unable to insert invoice: %w", err)
×
UNCOV
323
                }
×
324

325
                // TODO(positiveblue): if invocies do not have custom features
326
                // maybe just store the "invoice type" and populate the features
327
                // based on that.
UNCOV
328
                for feature := range newInvoice.Terms.Features.Features() {
×
UNCOV
329
                        params := sqlc.InsertInvoiceFeatureParams{
×
UNCOV
330
                                InvoiceID: invoiceID,
×
UNCOV
331
                                Feature:   int32(feature),
×
UNCOV
332
                        }
×
UNCOV
333

×
UNCOV
334
                        err := db.InsertInvoiceFeature(ctx, params)
×
UNCOV
335
                        if err != nil {
×
336
                                return fmt.Errorf("unable to insert invoice "+
×
337
                                        "feature(%v): %w", feature, err)
×
338
                        }
×
339
                }
340

341
                // Finally add a new event for this invoice.
UNCOV
342
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
×
UNCOV
343
                        AddedAt:   newInvoice.CreationDate.UTC(),
×
UNCOV
344
                        InvoiceID: invoiceID,
×
UNCOV
345
                })
×
UNCOV
346
        }, func() {})
×
UNCOV
347
        if err != nil {
×
UNCOV
348
                mappedSQLErr := sqldb.MapSQLError(err)
×
UNCOV
349
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
×
UNCOV
350
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
UNCOV
351
                        // Add context to unique constraint errors.
×
UNCOV
352
                        return 0, ErrDuplicateInvoice
×
UNCOV
353
                }
×
354

355
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
×
356
                        paymentHash, err)
×
357
        }
358

UNCOV
359
        newInvoice.AddIndex = uint64(invoiceID)
×
UNCOV
360

×
UNCOV
361
        return newInvoice.AddIndex, nil
×
362
}
363

364
// getInvoiceByRef fetches the invoice with the given reference. The reference
365
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
366
func getInvoiceByRef(ctx context.Context,
UNCOV
367
        db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
×
UNCOV
368

×
UNCOV
369
        // If the reference is empty, we can't look up the invoice.
×
UNCOV
370
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
×
UNCOV
371
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
UNCOV
372
        }
×
373

374
        // If the reference is a hash only, we can look up the invoice directly
375
        // by the payment hash which is faster.
UNCOV
376
        if ref.IsHashOnly() {
×
UNCOV
377
                invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
×
UNCOV
378
                if errors.Is(err, sql.ErrNoRows) {
×
UNCOV
379
                        return sqlc.Invoice{}, ErrInvoiceNotFound
×
UNCOV
380
                }
×
381

UNCOV
382
                return invoice, err
×
383
        }
384

385
        // Otherwise the reference may include more fields, so we'll need to
386
        // assemble the query parameters based on the fields that are set.
UNCOV
387
        var params sqlc.GetInvoiceParams
×
UNCOV
388

×
UNCOV
389
        if ref.PayHash() != nil {
×
UNCOV
390
                params.Hash = ref.PayHash()[:]
×
UNCOV
391
        }
×
392

393
        // Newer invoices (0.11 and up) are indexed by payment address in
394
        // addition to payment hash, but pre 0.8 invoices do not have one at
395
        // all. Only allow lookups for payment address if it is not a blank
396
        // payment address, which is a special-cased value for legacy keysend
397
        // invoices.
UNCOV
398
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
×
UNCOV
399
                params.PaymentAddr = ref.PayAddr()[:]
×
UNCOV
400
        }
×
401

402
        // If the reference has a set ID we'll fetch the invoice which has the
403
        // corresponding AMP sub invoice.
UNCOV
404
        if ref.SetID() != nil {
×
UNCOV
405
                params.SetID = ref.SetID()[:]
×
UNCOV
406
        }
×
407

UNCOV
408
        var (
×
UNCOV
409
                rows []sqlc.Invoice
×
UNCOV
410
                err  error
×
UNCOV
411
        )
×
UNCOV
412

×
UNCOV
413
        // We need to split the query based on how we intend to look up the
×
UNCOV
414
        // invoice. If only the set ID is given then we want to have an exact
×
UNCOV
415
        // match on the set ID. If other fields are given, we want to match on
×
UNCOV
416
        // those fields and the set ID but with a less strict join condition.
×
UNCOV
417
        if params.Hash == nil && params.PaymentAddr == nil &&
×
UNCOV
418
                params.SetID != nil {
×
UNCOV
419

×
UNCOV
420
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
×
UNCOV
421
        } else {
×
UNCOV
422
                rows, err = db.GetInvoice(ctx, params)
×
UNCOV
423
        }
×
424

UNCOV
425
        switch {
×
UNCOV
426
        case len(rows) == 0:
×
UNCOV
427
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
428

429
        case len(rows) > 1:
×
430
                // In case the reference is ambiguous, meaning it matches more
×
431
                // than        one invoice, we'll return an error.
×
432
                return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
×
433
                        "%s: %s", ref.String(), spew.Sdump(rows))
×
434

435
        case err != nil:
×
436
                return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
×
437
                        err)
×
438
        }
439

UNCOV
440
        return rows[0], nil
×
441
}
442

443
// fetchInvoice fetches the common invoice data and the AMP state for the
444
// invoice with the given reference.
445
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
UNCOV
446
        *Invoice, error) {
×
UNCOV
447

×
UNCOV
448
        // Fetch the invoice from the database.
×
UNCOV
449
        sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
×
UNCOV
450
        if err != nil {
×
UNCOV
451
                return nil, err
×
UNCOV
452
        }
×
453

UNCOV
454
        var (
×
UNCOV
455
                setID         *[32]byte
×
UNCOV
456
                fetchAmpHtlcs bool
×
UNCOV
457
        )
×
UNCOV
458

×
UNCOV
459
        // Now that we got the invoice itself, fetch the HTLCs as requested by
×
UNCOV
460
        // the modifier.
×
UNCOV
461
        switch ref.Modifier() {
×
UNCOV
462
        case DefaultModifier:
×
UNCOV
463
                // By default we'll fetch all AMP HTLCs.
×
UNCOV
464
                setID = nil
×
UNCOV
465
                fetchAmpHtlcs = true
×
466

UNCOV
467
        case HtlcSetOnlyModifier:
×
UNCOV
468
                // In this case we'll fetch all AMP HTLCs for the specified set
×
UNCOV
469
                // id.
×
UNCOV
470
                if ref.SetID() == nil {
×
471
                        return nil, fmt.Errorf("set ID is required to use " +
×
472
                                "the HTLC set only modifier")
×
473
                }
×
474

UNCOV
475
                setID = ref.SetID()
×
UNCOV
476
                fetchAmpHtlcs = true
×
477

UNCOV
478
        case HtlcSetBlankModifier:
×
UNCOV
479
                // No need to fetch any HTLCs.
×
UNCOV
480
                setID = nil
×
UNCOV
481
                fetchAmpHtlcs = false
×
482

483
        default:
×
484
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
×
485
                        ref.Modifier())
×
486
        }
487

488
        // Fetch the rest of the invoice data and fill the invoice struct.
UNCOV
489
        _, invoice, err := fetchInvoiceData(
×
UNCOV
490
                ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
×
UNCOV
491
        )
×
UNCOV
492
        if err != nil {
×
493
                return nil, err
×
494
        }
×
495

UNCOV
496
        return invoice, nil
×
497
}
498

499
// fetchAmpState fetches the AMP state for the invoice with the given ID.
500
// Optional setID can be provided to fetch the state for a specific AMP HTLC
501
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
502
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
503
// well.
504
//
505
//nolint:funlen
506
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
507
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
UNCOV
508
        HTLCSet, error) {
×
UNCOV
509

×
UNCOV
510
        var paramSetID []byte
×
UNCOV
511
        if setID != nil {
×
UNCOV
512
                paramSetID = setID[:]
×
UNCOV
513
        }
×
514

515
        // First fetch all the AMP sub invoices for this invoice or the one
516
        // matching the provided set ID.
UNCOV
517
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
×
UNCOV
518
                ctx, sqlc.FetchAMPSubInvoicesParams{
×
UNCOV
519
                        InvoiceID: invoiceID,
×
UNCOV
520
                        SetID:     paramSetID,
×
UNCOV
521
                },
×
UNCOV
522
        )
×
UNCOV
523
        if err != nil {
×
524
                return nil, nil, err
×
525
        }
×
526

UNCOV
527
        ampState := make(map[SetID]InvoiceStateAMP)
×
UNCOV
528
        for _, row := range ampInvoiceRows {
×
UNCOV
529
                var rowSetID [32]byte
×
UNCOV
530

×
UNCOV
531
                if len(row.SetID) != 32 {
×
532
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
×
533
                                len(row.SetID))
×
534
                }
×
535

UNCOV
536
                var settleDate time.Time
×
UNCOV
537
                if row.SettledAt.Valid {
×
UNCOV
538
                        settleDate = row.SettledAt.Time.Local()
×
UNCOV
539
                }
×
540

UNCOV
541
                copy(rowSetID[:], row.SetID)
×
UNCOV
542
                ampState[rowSetID] = InvoiceStateAMP{
×
UNCOV
543
                        State:       HtlcState(row.State),
×
UNCOV
544
                        SettleIndex: uint64(row.SettleIndex.Int64),
×
UNCOV
545
                        SettleDate:  settleDate,
×
UNCOV
546
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
×
UNCOV
547
                }
×
548
        }
549

UNCOV
550
        if !fetchHtlcs {
×
UNCOV
551
                return ampState, nil, nil
×
UNCOV
552
        }
×
553

UNCOV
554
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
UNCOV
555
        if err != nil {
×
556
                return nil, nil, fmt.Errorf("unable to get custom records for "+
×
557
                        "invoice HTLCs: %w", err)
×
558
        }
×
559

UNCOV
560
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
×
UNCOV
561
        for _, row := range customRecordRows {
×
UNCOV
562
                if _, ok := customRecords[row.HtlcID]; !ok {
×
UNCOV
563
                        customRecords[row.HtlcID] = make(record.CustomSet)
×
UNCOV
564
                }
×
565

UNCOV
566
                value := row.Value
×
UNCOV
567
                if value == nil {
×
568
                        value = []byte{}
×
569
                }
×
570

UNCOV
571
                customRecords[row.HtlcID][uint64(row.Key)] = value
×
572
        }
573

574
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
575
        // provided set ID.
UNCOV
576
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
×
UNCOV
577
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
×
UNCOV
578
                        InvoiceID: invoiceID,
×
UNCOV
579
                        SetID:     paramSetID,
×
UNCOV
580
                },
×
UNCOV
581
        )
×
UNCOV
582
        if err != nil {
×
583
                return nil, nil, err
×
584
        }
×
585

UNCOV
586
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
×
UNCOV
587
        for _, row := range ampHtlcRows {
×
UNCOV
588
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
UNCOV
589
                if err != nil {
×
590
                        return nil, nil, err
×
591
                }
×
592

UNCOV
593
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
UNCOV
594

×
UNCOV
595
                if row.HtlcID < 0 {
×
596
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
×
597
                                "value: %v", row.HtlcID)
×
598
                }
×
599

UNCOV
600
                htlcID := uint64(row.HtlcID)
×
UNCOV
601

×
UNCOV
602
                circuitKey := CircuitKey{
×
UNCOV
603
                        ChanID: chanID,
×
UNCOV
604
                        HtlcID: htlcID,
×
UNCOV
605
                }
×
UNCOV
606

×
UNCOV
607
                htlc := &InvoiceHTLC{
×
UNCOV
608
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
609
                        AcceptHeight: uint32(row.AcceptHeight),
×
UNCOV
610
                        AcceptTime:   row.AcceptTime.Local(),
×
UNCOV
611
                        Expiry:       uint32(row.ExpiryHeight),
×
UNCOV
612
                        State:        HtlcState(row.State),
×
UNCOV
613
                }
×
UNCOV
614

×
UNCOV
615
                if row.TotalMppMsat.Valid {
×
UNCOV
616
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
×
UNCOV
617
                                row.TotalMppMsat.Int64,
×
UNCOV
618
                        )
×
UNCOV
619
                }
×
620

UNCOV
621
                if row.ResolveTime.Valid {
×
UNCOV
622
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
×
UNCOV
623
                }
×
624

UNCOV
625
                var (
×
UNCOV
626
                        rootShare [32]byte
×
UNCOV
627
                        setID     [32]byte
×
UNCOV
628
                )
×
UNCOV
629

×
UNCOV
630
                if len(row.RootShare) != 32 {
×
631
                        return nil, nil, fmt.Errorf("invalid root share "+
×
632
                                "length: %d", len(row.RootShare))
×
633
                }
×
UNCOV
634
                copy(rootShare[:], row.RootShare)
×
UNCOV
635

×
UNCOV
636
                if len(row.SetID) != 32 {
×
637
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
638
                                len(row.SetID))
×
639
                }
×
UNCOV
640
                copy(setID[:], row.SetID)
×
UNCOV
641

×
UNCOV
642
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
×
643
                        return nil, nil, fmt.Errorf("invalid child index "+
×
644
                                "value: %v", row.ChildIndex)
×
645
                }
×
646

UNCOV
647
                ampRecord := record.NewAMP(
×
UNCOV
648
                        rootShare, setID, uint32(row.ChildIndex),
×
UNCOV
649
                )
×
UNCOV
650

×
UNCOV
651
                htlc.AMP = &InvoiceHtlcAMPData{
×
UNCOV
652
                        Record: *ampRecord,
×
UNCOV
653
                }
×
UNCOV
654

×
UNCOV
655
                if len(row.Hash) != 32 {
×
656
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
×
657
                                len(row.Hash))
×
658
                }
×
UNCOV
659
                copy(htlc.AMP.Hash[:], row.Hash)
×
UNCOV
660

×
UNCOV
661
                if row.Preimage != nil {
×
UNCOV
662
                        preimage, err := lntypes.MakePreimage(row.Preimage)
×
UNCOV
663
                        if err != nil {
×
664
                                return nil, nil, err
×
665
                        }
×
666

UNCOV
667
                        htlc.AMP.Preimage = &preimage
×
668
                }
669

UNCOV
670
                if _, ok := customRecords[row.ID]; ok {
×
UNCOV
671
                        htlc.CustomRecords = customRecords[row.ID]
×
UNCOV
672
                } else {
×
UNCOV
673
                        htlc.CustomRecords = make(record.CustomSet)
×
UNCOV
674
                }
×
675

UNCOV
676
                ampHtlcs[circuitKey] = htlc
×
677
        }
678

UNCOV
679
        if len(ampHtlcs) > 0 {
×
UNCOV
680
                for setID := range ampState {
×
UNCOV
681
                        var amtPaid lnwire.MilliSatoshi
×
UNCOV
682
                        invoiceKeys := make(
×
UNCOV
683
                                map[models.CircuitKey]struct{},
×
UNCOV
684
                        )
×
UNCOV
685

×
UNCOV
686
                        for key, htlc := range ampHtlcs {
×
UNCOV
687
                                if htlc.AMP.Record.SetID() != setID {
×
UNCOV
688
                                        continue
×
689
                                }
690

UNCOV
691
                                invoiceKeys[key] = struct{}{}
×
UNCOV
692

×
UNCOV
693
                                if htlc.State != HtlcStateCanceled {
×
UNCOV
694
                                        amtPaid += htlc.Amt
×
UNCOV
695
                                }
×
696
                        }
697

UNCOV
698
                        setState := ampState[setID]
×
UNCOV
699
                        setState.InvoiceKeys = invoiceKeys
×
UNCOV
700
                        setState.AmtPaid = amtPaid
×
UNCOV
701
                        ampState[setID] = setState
×
702
                }
703
        }
704

UNCOV
705
        return ampState, ampHtlcs, nil
×
706
}
707

708
// LookupInvoice attempts to look up an invoice corresponding the passed in
709
// reference. The reference may be a payment hash, a payment address, or a set
710
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
711
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
712
// error.
713
func (i *SQLStore) LookupInvoice(ctx context.Context,
UNCOV
714
        ref InvoiceRef) (Invoice, error) {
×
UNCOV
715

×
UNCOV
716
        var (
×
UNCOV
717
                invoice *Invoice
×
UNCOV
718
                err     error
×
UNCOV
719
        )
×
UNCOV
720

×
UNCOV
721
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
722
        txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
723
                invoice, err = fetchInvoice(ctx, db, ref)
×
UNCOV
724

×
UNCOV
725
                return err
×
UNCOV
726
        }, func() {})
×
UNCOV
727
        if txErr != nil {
×
UNCOV
728
                return Invoice{}, txErr
×
UNCOV
729
        }
×
730

UNCOV
731
        return *invoice, nil
×
732
}
733

734
// FetchPendingInvoices returns all the invoices that are currently in a
735
// "pending" state. An invoice is pending if it has been created but not yet
736
// settled or canceled.
737
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
UNCOV
738
        map[lntypes.Hash]Invoice, error) {
×
UNCOV
739

×
UNCOV
740
        var invoices map[lntypes.Hash]Invoice
×
UNCOV
741

×
UNCOV
742
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
743
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
744
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
745
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
746
                                PendingOnly: true,
×
UNCOV
747
                                NumOffset:   int32(offset),
×
UNCOV
748
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
749
                                Reverse:     false,
×
UNCOV
750
                        }
×
UNCOV
751

×
UNCOV
752
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
753
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
754
                                return 0, fmt.Errorf("unable to get invoices "+
×
755
                                        "from db: %w", err)
×
756
                        }
×
757

758
                        // Load all the information for the invoices.
UNCOV
759
                        for _, row := range rows {
×
UNCOV
760
                                hash, invoice, err := fetchInvoiceData(
×
UNCOV
761
                                        ctx, db, row, nil, true,
×
UNCOV
762
                                )
×
UNCOV
763
                                if err != nil {
×
764
                                        return 0, err
×
765
                                }
×
766

UNCOV
767
                                invoices[*hash] = *invoice
×
768
                        }
769

UNCOV
770
                        return len(rows), nil
×
771
                }, i.opts.paginationLimit)
UNCOV
772
        }, func() {
×
UNCOV
773
                invoices = make(map[lntypes.Hash]Invoice)
×
UNCOV
774
        })
×
UNCOV
775
        if err != nil {
×
776
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
×
777
                        err)
×
778
        }
×
779

UNCOV
780
        return invoices, nil
×
781
}
782

783
// InvoicesSettledSince can be used by callers to catch up any settled invoices
784
// they missed within the settled invoice time series. We'll return all known
785
// settled invoice that have a settle index higher than the passed idx.
786
//
787
// NOTE: The index starts from 1. As a result we enforce that specifying a value
788
// below the starting index value is a noop.
789
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
UNCOV
790
        []Invoice, error) {
×
UNCOV
791

×
UNCOV
792
        var (
×
UNCOV
793
                invoices       []Invoice
×
UNCOV
794
                start          = time.Now()
×
UNCOV
795
                lastLogTime    = time.Now()
×
UNCOV
796
                processedCount int
×
UNCOV
797
        )
×
UNCOV
798

×
UNCOV
799
        if idx == 0 {
×
UNCOV
800
                return invoices, nil
×
UNCOV
801
        }
×
802

UNCOV
803
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
804
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
805
                err := queryWithLimit(func(offset int) (int, error) {
×
UNCOV
806
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
807
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
808
                                NumOffset:      int32(offset),
×
UNCOV
809
                                NumLimit:       int32(i.opts.paginationLimit),
×
UNCOV
810
                                Reverse:        false,
×
UNCOV
811
                        }
×
UNCOV
812

×
UNCOV
813
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
814
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
815
                                return 0, fmt.Errorf("unable to get invoices "+
×
816
                                        "from db: %w", err)
×
817
                        }
×
818

819
                        // Load all the information for the invoices.
UNCOV
820
                        for _, row := range rows {
×
UNCOV
821
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
822
                                        ctx, db, row, nil, true,
×
UNCOV
823
                                )
×
UNCOV
824
                                if err != nil {
×
825
                                        return 0, fmt.Errorf("unable to fetch "+
×
826
                                                "invoice(id=%d) from db: %w",
×
827
                                                row.ID, err)
×
828
                                }
×
829

UNCOV
830
                                invoices = append(invoices, *invoice)
×
UNCOV
831

×
UNCOV
832
                                processedCount++
×
UNCOV
833
                                if time.Since(lastLogTime) >=
×
UNCOV
834
                                        invoiceProgressLogInterval {
×
835

×
836
                                        log.Debugf("Processed %d settled "+
×
837
                                                "invoices which have a settle "+
×
838
                                                "index greater than %v",
×
839
                                                processedCount, idx)
×
840

×
841
                                        lastLogTime = time.Now()
×
842
                                }
×
843
                        }
844

UNCOV
845
                        return len(rows), nil
×
846
                }, i.opts.paginationLimit)
UNCOV
847
                if err != nil {
×
848
                        return err
×
849
                }
×
850

851
                // Now fetch all the AMP sub invoices that were settled since
852
                // the provided index.
UNCOV
853
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
×
UNCOV
854
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
×
UNCOV
855
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
856
                        },
×
UNCOV
857
                )
×
UNCOV
858
                if err != nil {
×
859
                        return err
×
860
                }
×
861

UNCOV
862
                for _, ampInvoice := range ampInvoices {
×
UNCOV
863
                        // Convert the row to a sqlc.Invoice so we can use the
×
UNCOV
864
                        // existing fetchInvoiceData function.
×
UNCOV
865
                        sqlInvoice := sqlc.Invoice{
×
UNCOV
866
                                ID:             ampInvoice.ID,
×
UNCOV
867
                                Hash:           ampInvoice.Hash,
×
UNCOV
868
                                Preimage:       ampInvoice.Preimage,
×
UNCOV
869
                                SettleIndex:    ampInvoice.AmpSettleIndex,
×
UNCOV
870
                                SettledAt:      ampInvoice.AmpSettledAt,
×
UNCOV
871
                                Memo:           ampInvoice.Memo,
×
UNCOV
872
                                AmountMsat:     ampInvoice.AmountMsat,
×
UNCOV
873
                                CltvDelta:      ampInvoice.CltvDelta,
×
UNCOV
874
                                Expiry:         ampInvoice.Expiry,
×
UNCOV
875
                                PaymentAddr:    ampInvoice.PaymentAddr,
×
UNCOV
876
                                PaymentRequest: ampInvoice.PaymentRequest,
×
UNCOV
877
                                State:          ampInvoice.State,
×
UNCOV
878
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
×
UNCOV
879
                                IsAmp:          ampInvoice.IsAmp,
×
UNCOV
880
                                IsHodl:         ampInvoice.IsHodl,
×
UNCOV
881
                                IsKeysend:      ampInvoice.IsKeysend,
×
UNCOV
882
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
×
UNCOV
883
                        }
×
UNCOV
884

×
UNCOV
885
                        // Fetch the state and HTLCs for this AMP sub invoice.
×
UNCOV
886
                        _, invoice, err := fetchInvoiceData(
×
UNCOV
887
                                ctx, db, sqlInvoice,
×
UNCOV
888
                                (*[32]byte)(ampInvoice.SetID), true,
×
UNCOV
889
                        )
×
UNCOV
890
                        if err != nil {
×
891
                                return fmt.Errorf("unable to fetch "+
×
892
                                        "AMP invoice(id=%d) from db: %w",
×
893
                                        ampInvoice.ID, err)
×
894
                        }
×
895

UNCOV
896
                        invoices = append(invoices, *invoice)
×
UNCOV
897

×
UNCOV
898
                        processedCount++
×
UNCOV
899
                        if time.Since(lastLogTime) >=
×
UNCOV
900
                                invoiceProgressLogInterval {
×
901

×
902
                                log.Debugf("Processed %d settled invoices "+
×
903
                                        "including AMP sub invoices which "+
×
904
                                        "have a settle index greater than %v",
×
905
                                        processedCount, idx)
×
906

×
907
                                lastLogTime = time.Now()
×
908
                        }
×
909
                }
910

UNCOV
911
                return nil
×
UNCOV
912
        }, func() {
×
UNCOV
913
                invoices = nil
×
UNCOV
914
        })
×
UNCOV
915
        if err != nil {
×
916
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
917
                        "index (excluding) %d: %w", idx, err)
×
918
        }
×
919

UNCOV
920
        elapsed := time.Since(start)
×
UNCOV
921
        log.Debugf("Completed scanning for settled invoices starting at "+
×
UNCOV
922
                "index %v: total_processed=%d, found_invoices=%d, elapsed=%v",
×
UNCOV
923
                idx, processedCount, len(invoices),
×
UNCOV
924
                elapsed.Round(time.Millisecond))
×
UNCOV
925

×
UNCOV
926
        return invoices, nil
×
927
}
928

929
// InvoicesAddedSince can be used by callers to seek into the event time series
930
// of all the invoices added in the database. This method will return all
931
// invoices with an add index greater than the specified idx.
932
//
933
// NOTE: The index starts from 1. As a result we enforce that specifying a value
934
// below the starting index value is a noop.
935
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
UNCOV
936
        []Invoice, error) {
×
UNCOV
937

×
UNCOV
938
        var (
×
UNCOV
939
                result         []Invoice
×
UNCOV
940
                start          = time.Now()
×
UNCOV
941
                lastLogTime    = time.Now()
×
UNCOV
942
                processedCount int
×
UNCOV
943
        )
×
UNCOV
944

×
UNCOV
945
        if idx == 0 {
×
UNCOV
946
                return result, nil
×
UNCOV
947
        }
×
948

UNCOV
949
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
950
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
951
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
952
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
953
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
954
                                NumOffset:   int32(offset),
×
UNCOV
955
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
956
                                Reverse:     false,
×
UNCOV
957
                        }
×
UNCOV
958

×
UNCOV
959
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
960
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
961
                                return 0, fmt.Errorf("unable to get invoices "+
×
962
                                        "from db: %w", err)
×
963
                        }
×
964

965
                        // Load all the information for the invoices.
UNCOV
966
                        for _, row := range rows {
×
UNCOV
967
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
968
                                        ctx, db, row, nil, true,
×
UNCOV
969
                                )
×
UNCOV
970
                                if err != nil {
×
971
                                        return 0, err
×
972
                                }
×
973

UNCOV
974
                                result = append(result, *invoice)
×
UNCOV
975

×
UNCOV
976
                                processedCount++
×
UNCOV
977
                                if time.Since(lastLogTime) >=
×
UNCOV
978
                                        invoiceProgressLogInterval {
×
979

×
980
                                        log.Debugf("Processed %d invoices "+
×
981
                                                "which were added since add "+
×
982
                                                "index %v", processedCount, idx)
×
983

×
984
                                        lastLogTime = time.Now()
×
985
                                }
×
986
                        }
987

UNCOV
988
                        return len(rows), nil
×
989
                }, i.opts.paginationLimit)
UNCOV
990
        }, func() {
×
UNCOV
991
                result = nil
×
UNCOV
992
        })
×
993

UNCOV
994
        if err != nil {
×
995
                return nil, fmt.Errorf("unable to get invoices added since "+
×
996
                        "index %d: %w", idx, err)
×
997
        }
×
998

UNCOV
999
        elapsed := time.Since(start)
×
UNCOV
1000
        log.Debugf("Completed scanning for invoices added since index %v: "+
×
UNCOV
1001
                "total_processed=%d, found_invoices=%d, elapsed=%v",
×
UNCOV
1002
                idx, processedCount, len(result),
×
UNCOV
1003
                elapsed.Round(time.Millisecond))
×
UNCOV
1004

×
UNCOV
1005
        return result, nil
×
1006
}
1007

1008
// QueryInvoices allows a caller to query the invoice database for invoices
1009
// within the specified add index range.
1010
func (i *SQLStore) QueryInvoices(ctx context.Context,
UNCOV
1011
        q InvoiceQuery) (InvoiceSlice, error) {
×
UNCOV
1012

×
UNCOV
1013
        var invoices []Invoice
×
UNCOV
1014

×
UNCOV
1015
        if q.NumMaxInvoices == 0 {
×
1016
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
1017
                        "be non-zero")
×
1018
        }
×
1019

UNCOV
1020
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
1021
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1022
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
1023
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
1024
                                NumOffset:   int32(offset),
×
UNCOV
1025
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
1026
                                PendingOnly: q.PendingOnly,
×
UNCOV
1027
                                Reverse:     q.Reversed,
×
UNCOV
1028
                        }
×
UNCOV
1029

×
UNCOV
1030
                        if q.Reversed {
×
UNCOV
1031
                                // If the index offset was not set, we want to
×
UNCOV
1032
                                // fetch from the lastest invoice.
×
UNCOV
1033
                                if q.IndexOffset == 0 {
×
UNCOV
1034
                                        params.AddIndexLet = sqldb.SQLInt64(
×
UNCOV
1035
                                                int64(math.MaxInt64),
×
UNCOV
1036
                                        )
×
UNCOV
1037
                                } else {
×
UNCOV
1038
                                        // The invoice with index offset id must
×
UNCOV
1039
                                        // not be included in the results.
×
UNCOV
1040
                                        params.AddIndexLet = sqldb.SQLInt64(
×
UNCOV
1041
                                                q.IndexOffset - 1,
×
UNCOV
1042
                                        )
×
UNCOV
1043
                                }
×
UNCOV
1044
                        } else {
×
UNCOV
1045
                                // The invoice with index offset id must not be
×
UNCOV
1046
                                // included in the results.
×
UNCOV
1047
                                params.AddIndexGet = sqldb.SQLInt64(
×
UNCOV
1048
                                        q.IndexOffset + 1,
×
UNCOV
1049
                                )
×
UNCOV
1050
                        }
×
1051

UNCOV
1052
                        if q.CreationDateStart != 0 {
×
UNCOV
1053
                                params.CreatedAfter = sqldb.SQLTime(
×
UNCOV
1054
                                        time.Unix(q.CreationDateStart, 0).UTC(),
×
UNCOV
1055
                                )
×
UNCOV
1056
                        }
×
1057

UNCOV
1058
                        if q.CreationDateEnd != 0 {
×
UNCOV
1059
                                // We need to add 1 to the end date as we're
×
UNCOV
1060
                                // checking less than the end date in SQL.
×
UNCOV
1061
                                params.CreatedBefore = sqldb.SQLTime(
×
UNCOV
1062
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
×
UNCOV
1063
                                )
×
UNCOV
1064
                        }
×
1065

UNCOV
1066
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
1067
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1068
                                return 0, fmt.Errorf("unable to get invoices "+
×
1069
                                        "from db: %w", err)
×
1070
                        }
×
1071

1072
                        // Load all the information for the invoices.
UNCOV
1073
                        for _, row := range rows {
×
UNCOV
1074
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
1075
                                        ctx, db, row, nil, true,
×
UNCOV
1076
                                )
×
UNCOV
1077
                                if err != nil {
×
1078
                                        return 0, err
×
1079
                                }
×
1080

UNCOV
1081
                                invoices = append(invoices, *invoice)
×
UNCOV
1082

×
UNCOV
1083
                                if len(invoices) == int(q.NumMaxInvoices) {
×
UNCOV
1084
                                        return 0, nil
×
UNCOV
1085
                                }
×
1086
                        }
1087

UNCOV
1088
                        return len(rows), nil
×
1089
                }, i.opts.paginationLimit)
UNCOV
1090
        }, func() {
×
UNCOV
1091
                invoices = nil
×
UNCOV
1092
        })
×
UNCOV
1093
        if err != nil {
×
1094
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
1095
                        "invoices: %w", err)
×
1096
        }
×
1097

UNCOV
1098
        if len(invoices) == 0 {
×
UNCOV
1099
                return InvoiceSlice{
×
UNCOV
1100
                        InvoiceQuery: q,
×
UNCOV
1101
                }, nil
×
UNCOV
1102
        }
×
1103

1104
        // If we iterated through the add index in reverse order, then
1105
        // we'll need to reverse the slice of invoices to return them in
1106
        // forward order.
UNCOV
1107
        if q.Reversed {
×
UNCOV
1108
                numInvoices := len(invoices)
×
UNCOV
1109
                for i := 0; i < numInvoices/2; i++ {
×
UNCOV
1110
                        reverse := numInvoices - i - 1
×
UNCOV
1111
                        invoices[i], invoices[reverse] =
×
UNCOV
1112
                                invoices[reverse], invoices[i]
×
UNCOV
1113
                }
×
1114
        }
1115

UNCOV
1116
        res := InvoiceSlice{
×
UNCOV
1117
                InvoiceQuery:     q,
×
UNCOV
1118
                Invoices:         invoices,
×
UNCOV
1119
                FirstIndexOffset: invoices[0].AddIndex,
×
UNCOV
1120
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
×
UNCOV
1121
        }
×
UNCOV
1122

×
UNCOV
1123
        return res, nil
×
1124
}
1125

1126
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
1127
// a SQL database as the backend.
1128
type sqlInvoiceUpdater struct {
1129
        db         SQLInvoiceQueries
1130
        ctx        context.Context //nolint:containedctx
1131
        invoice    *Invoice
1132
        updateTime time.Time
1133
}
1134

1135
// AddHtlc adds a new htlc to the invoice.
1136
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
UNCOV
1137
        newHtlc *InvoiceHTLC) error {
×
UNCOV
1138

×
UNCOV
1139
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
×
UNCOV
1140
                s.ctx, sqlc.InsertInvoiceHTLCParams{
×
UNCOV
1141
                        HtlcID: int64(circuitKey.HtlcID),
×
UNCOV
1142
                        ChanID: strconv.FormatUint(
×
UNCOV
1143
                                circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1144
                        ),
×
UNCOV
1145
                        AmountMsat: int64(newHtlc.Amt),
×
UNCOV
1146
                        TotalMppMsat: sql.NullInt64{
×
UNCOV
1147
                                Int64: int64(newHtlc.MppTotalAmt),
×
UNCOV
1148
                                Valid: newHtlc.MppTotalAmt != 0,
×
UNCOV
1149
                        },
×
UNCOV
1150
                        AcceptHeight: int32(newHtlc.AcceptHeight),
×
UNCOV
1151
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
×
UNCOV
1152
                        ExpiryHeight: int32(newHtlc.Expiry),
×
UNCOV
1153
                        State:        int16(newHtlc.State),
×
UNCOV
1154
                        InvoiceID:    int64(s.invoice.AddIndex),
×
UNCOV
1155
                },
×
UNCOV
1156
        )
×
UNCOV
1157
        if err != nil {
×
1158
                return err
×
1159
        }
×
1160

UNCOV
1161
        for key, value := range newHtlc.CustomRecords {
×
UNCOV
1162
                err = s.db.InsertInvoiceHTLCCustomRecord(
×
UNCOV
1163
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
×
UNCOV
1164
                                // TODO(bhandras): schema might be wrong here
×
UNCOV
1165
                                // as the custom record key is an uint64.
×
UNCOV
1166
                                Key:    int64(key),
×
UNCOV
1167
                                Value:  value,
×
UNCOV
1168
                                HtlcID: htlcPrimaryKeyID,
×
UNCOV
1169
                        },
×
UNCOV
1170
                )
×
UNCOV
1171
                if err != nil {
×
1172
                        return err
×
1173
                }
×
1174
        }
1175

UNCOV
1176
        if newHtlc.AMP != nil {
×
UNCOV
1177
                setID := newHtlc.AMP.Record.SetID()
×
UNCOV
1178

×
UNCOV
1179
                upsertResult, err := s.db.UpsertAMPSubInvoice(
×
UNCOV
1180
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
×
UNCOV
1181
                                SetID:     setID[:],
×
UNCOV
1182
                                CreatedAt: s.updateTime.UTC(),
×
UNCOV
1183
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1184
                        },
×
UNCOV
1185
                )
×
UNCOV
1186
                if err != nil {
×
UNCOV
1187
                        mappedSQLErr := sqldb.MapSQLError(err)
×
UNCOV
1188
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
×
UNCOV
1189
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
UNCOV
1190
                                return ErrDuplicateSetID{
×
UNCOV
1191
                                        SetID: setID,
×
UNCOV
1192
                                }
×
UNCOV
1193
                        }
×
1194

1195
                        return err
×
1196
                }
1197

1198
                // If we're just inserting the AMP invoice, we'll get a non
1199
                // zero rows affected count.
UNCOV
1200
                rowsAffected, err := upsertResult.RowsAffected()
×
UNCOV
1201
                if err != nil {
×
1202
                        return err
×
1203
                }
×
UNCOV
1204
                if rowsAffected != 0 {
×
UNCOV
1205
                        // If we're inserting a new AMP invoice, we'll also
×
UNCOV
1206
                        // insert a new invoice event.
×
UNCOV
1207
                        err = s.db.OnAMPSubInvoiceCreated(
×
UNCOV
1208
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
×
UNCOV
1209
                                        AddedAt:   s.updateTime.UTC(),
×
UNCOV
1210
                                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1211
                                        SetID:     setID[:],
×
UNCOV
1212
                                },
×
UNCOV
1213
                        )
×
UNCOV
1214
                        if err != nil {
×
1215
                                return err
×
1216
                        }
×
1217
                }
1218

UNCOV
1219
                rootShare := newHtlc.AMP.Record.RootShare()
×
UNCOV
1220

×
UNCOV
1221
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
×
UNCOV
1222
                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1223
                        SetID:     setID[:],
×
UNCOV
1224
                        HtlcID:    htlcPrimaryKeyID,
×
UNCOV
1225
                        RootShare: rootShare[:],
×
UNCOV
1226
                        ChildIndex: int64(
×
UNCOV
1227
                                newHtlc.AMP.Record.ChildIndex(),
×
UNCOV
1228
                        ),
×
UNCOV
1229
                        Hash: newHtlc.AMP.Hash[:],
×
UNCOV
1230
                }
×
UNCOV
1231

×
UNCOV
1232
                if newHtlc.AMP.Preimage != nil {
×
UNCOV
1233
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
×
UNCOV
1234
                }
×
1235

UNCOV
1236
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
×
UNCOV
1237
                if err != nil {
×
1238
                        return err
×
1239
                }
×
1240
        }
1241

UNCOV
1242
        return nil
×
1243
}
1244

1245
// ResolveHtlc marks an htlc as resolved with the given state.
1246
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
UNCOV
1247
        state HtlcState, resolveTime time.Time) error {
×
UNCOV
1248

×
UNCOV
1249
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
×
UNCOV
1250
                HtlcID: int64(circuitKey.HtlcID),
×
UNCOV
1251
                ChanID: strconv.FormatUint(
×
UNCOV
1252
                        circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1253
                ),
×
UNCOV
1254
                InvoiceID:   int64(s.invoice.AddIndex),
×
UNCOV
1255
                State:       int16(state),
×
UNCOV
1256
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
×
UNCOV
1257
        })
×
UNCOV
1258
}
×
1259

1260
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1261
// identified by the setID.
1262
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
UNCOV
1263
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
×
UNCOV
1264

×
UNCOV
1265
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
×
UNCOV
1266
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
×
UNCOV
1267
                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1268
                        SetID:     setID[:],
×
UNCOV
1269
                        HtlcID:    int64(circuitKey.HtlcID),
×
UNCOV
1270
                        Preimage:  preimage[:],
×
UNCOV
1271
                        ChanID: strconv.FormatUint(
×
UNCOV
1272
                                circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1273
                        ),
×
UNCOV
1274
                },
×
UNCOV
1275
        )
×
UNCOV
1276
        if err != nil {
×
1277
                return err
×
1278
        }
×
1279

UNCOV
1280
        rowsAffected, err := result.RowsAffected()
×
UNCOV
1281
        if err != nil {
×
1282
                return err
×
1283
        }
×
UNCOV
1284
        if rowsAffected == 0 {
×
1285
                return ErrInvoiceNotFound
×
1286
        }
×
1287

UNCOV
1288
        return nil
×
1289
}
1290

1291
// UpdateInvoiceState updates the invoice state to the new state.
1292
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
UNCOV
1293
        newState ContractState, preimage *lntypes.Preimage) error {
×
UNCOV
1294

×
UNCOV
1295
        var (
×
UNCOV
1296
                settleIndex sql.NullInt64
×
UNCOV
1297
                settledAt   sql.NullTime
×
UNCOV
1298
        )
×
UNCOV
1299

×
UNCOV
1300
        switch newState {
×
UNCOV
1301
        case ContractSettled:
×
UNCOV
1302
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
UNCOV
1303
                if err != nil {
×
1304
                        return err
×
1305
                }
×
1306

UNCOV
1307
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
UNCOV
1308

×
UNCOV
1309
                // If the invoice is settled, we'll also update the settle time.
×
UNCOV
1310
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
UNCOV
1311

×
UNCOV
1312
                err = s.db.OnInvoiceSettled(
×
UNCOV
1313
                        s.ctx, sqlc.OnInvoiceSettledParams{
×
UNCOV
1314
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1315
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1316
                        },
×
UNCOV
1317
                )
×
UNCOV
1318
                if err != nil {
×
1319
                        return err
×
1320
                }
×
1321

UNCOV
1322
        case ContractCanceled:
×
UNCOV
1323
                err := s.db.OnInvoiceCanceled(
×
UNCOV
1324
                        s.ctx, sqlc.OnInvoiceCanceledParams{
×
UNCOV
1325
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1326
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1327
                        },
×
UNCOV
1328
                )
×
UNCOV
1329
                if err != nil {
×
1330
                        return err
×
1331
                }
×
1332
        }
1333

UNCOV
1334
        params := sqlc.UpdateInvoiceStateParams{
×
UNCOV
1335
                ID:          int64(s.invoice.AddIndex),
×
UNCOV
1336
                State:       int16(newState),
×
UNCOV
1337
                SettleIndex: settleIndex,
×
UNCOV
1338
                SettledAt:   settledAt,
×
UNCOV
1339
        }
×
UNCOV
1340

×
UNCOV
1341
        if preimage != nil {
×
UNCOV
1342
                params.Preimage = preimage[:]
×
UNCOV
1343
        }
×
1344

UNCOV
1345
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
×
UNCOV
1346
        if err != nil {
×
1347
                return err
×
1348
        }
×
UNCOV
1349
        rowsAffected, err := result.RowsAffected()
×
UNCOV
1350
        if err != nil {
×
1351
                return err
×
1352
        }
×
1353

UNCOV
1354
        if rowsAffected == 0 {
×
1355
                return ErrInvoiceNotFound
×
1356
        }
×
1357

UNCOV
1358
        if settleIndex.Valid {
×
UNCOV
1359
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
×
UNCOV
1360
                s.invoice.SettleDate = s.updateTime
×
UNCOV
1361
        }
×
1362

UNCOV
1363
        return nil
×
1364
}
1365

1366
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
1367
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
UNCOV
1368
        amtPaid lnwire.MilliSatoshi) error {
×
UNCOV
1369

×
UNCOV
1370
        _, err := s.db.UpdateInvoiceAmountPaid(
×
UNCOV
1371
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
×
UNCOV
1372
                        ID:             int64(s.invoice.AddIndex),
×
UNCOV
1373
                        AmountPaidMsat: int64(amtPaid),
×
UNCOV
1374
                },
×
UNCOV
1375
        )
×
UNCOV
1376

×
UNCOV
1377
        return err
×
UNCOV
1378
}
×
1379

1380
// UpdateAmpState updates the state of the AMP sub invoice identified by the
1381
// setID.
1382
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
UNCOV
1383
        newState InvoiceStateAMP, _ models.CircuitKey) error {
×
UNCOV
1384

×
UNCOV
1385
        var (
×
UNCOV
1386
                settleIndex sql.NullInt64
×
UNCOV
1387
                settledAt   sql.NullTime
×
UNCOV
1388
        )
×
UNCOV
1389

×
UNCOV
1390
        switch newState.State {
×
UNCOV
1391
        case HtlcStateSettled:
×
UNCOV
1392
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
UNCOV
1393
                if err != nil {
×
1394
                        return err
×
1395
                }
×
1396

UNCOV
1397
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
UNCOV
1398

×
UNCOV
1399
                // If the invoice is settled, we'll also update the settle time.
×
UNCOV
1400
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
UNCOV
1401

×
UNCOV
1402
                err = s.db.OnAMPSubInvoiceSettled(
×
UNCOV
1403
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
×
UNCOV
1404
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1405
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1406
                                SetID:     setID[:],
×
UNCOV
1407
                        },
×
UNCOV
1408
                )
×
UNCOV
1409
                if err != nil {
×
1410
                        return err
×
1411
                }
×
1412

UNCOV
1413
        case HtlcStateCanceled:
×
UNCOV
1414
                err := s.db.OnAMPSubInvoiceCanceled(
×
UNCOV
1415
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
×
UNCOV
1416
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1417
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1418
                                SetID:     setID[:],
×
UNCOV
1419
                        },
×
UNCOV
1420
                )
×
UNCOV
1421
                if err != nil {
×
1422
                        return err
×
1423
                }
×
1424
        }
1425

UNCOV
1426
        err := s.db.UpdateAMPSubInvoiceState(
×
UNCOV
1427
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
×
UNCOV
1428
                        SetID:       setID[:],
×
UNCOV
1429
                        State:       int16(newState.State),
×
UNCOV
1430
                        SettleIndex: settleIndex,
×
UNCOV
1431
                        SettledAt:   settledAt,
×
UNCOV
1432
                },
×
UNCOV
1433
        )
×
UNCOV
1434
        if err != nil {
×
1435
                return err
×
1436
        }
×
1437

UNCOV
1438
        if settleIndex.Valid {
×
UNCOV
1439
                updatedState := s.invoice.AMPState[setID]
×
UNCOV
1440
                updatedState.SettleIndex = uint64(settleIndex.Int64)
×
UNCOV
1441
                updatedState.SettleDate = s.updateTime.UTC()
×
UNCOV
1442
                s.invoice.AMPState[setID] = updatedState
×
UNCOV
1443
        }
×
1444

UNCOV
1445
        return nil
×
1446
}
1447

1448
// Finalize finalizes the update before it is written to the database. Note that
1449
// we don't use this directly in the SQL implementation, so the function is just
1450
// a stub.
UNCOV
1451
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
×
UNCOV
1452
        return nil
×
UNCOV
1453
}
×
1454

1455
// UpdateInvoice attempts to update an invoice corresponding to the passed
1456
// reference. If an invoice matching the passed reference doesn't exist within
1457
// the database, then the action will fail with  ErrInvoiceNotFound error.
1458
//
1459
// The update is performed inside the same database transaction that fetches the
1460
// invoice and is therefore atomic. The fields to update are controlled by the
1461
// supplied callback.
1462
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
1463
        setID *SetID, callback InvoiceUpdateCallback) (
UNCOV
1464
        *Invoice, error) {
×
UNCOV
1465

×
UNCOV
1466
        var updatedInvoice *Invoice
×
UNCOV
1467

×
UNCOV
1468
        txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
×
UNCOV
1469
        txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1470
                switch {
×
1471
                // For the default case we fetch all HTLCs.
UNCOV
1472
                case setID == nil:
×
UNCOV
1473
                        ref.refModifier = DefaultModifier
×
1474

1475
                // If the setID is the blank but NOT nil, we set the
1476
                // refModifier to HtlcSetBlankModifier to fetch no HTLC for the
1477
                // AMP invoice.
1478
                case *setID == BlankPayAddr:
×
1479
                        ref.refModifier = HtlcSetBlankModifier
×
1480

1481
                // A setID is provided, we use the refModifier to fetch only
1482
                // the HTLCs for the given setID and also make sure we add the
1483
                // setID to the ref.
UNCOV
1484
                default:
×
UNCOV
1485
                        var setIDBytes [32]byte
×
UNCOV
1486
                        copy(setIDBytes[:], setID[:])
×
UNCOV
1487
                        ref.setID = &setIDBytes
×
UNCOV
1488

×
UNCOV
1489
                        // We only fetch the HTLCs for the given setID.
×
UNCOV
1490
                        ref.refModifier = HtlcSetOnlyModifier
×
1491
                }
1492

UNCOV
1493
                invoice, err := fetchInvoice(ctx, db, ref)
×
UNCOV
1494
                if err != nil {
×
UNCOV
1495
                        return err
×
UNCOV
1496
                }
×
1497

UNCOV
1498
                updateTime := i.clock.Now()
×
UNCOV
1499
                updater := &sqlInvoiceUpdater{
×
UNCOV
1500
                        db:         db,
×
UNCOV
1501
                        ctx:        ctx,
×
UNCOV
1502
                        invoice:    invoice,
×
UNCOV
1503
                        updateTime: updateTime,
×
UNCOV
1504
                }
×
UNCOV
1505

×
UNCOV
1506
                payHash := ref.PayHash()
×
UNCOV
1507
                updatedInvoice, err = UpdateInvoice(
×
UNCOV
1508
                        payHash, invoice, updateTime, callback, updater,
×
UNCOV
1509
                )
×
UNCOV
1510

×
UNCOV
1511
                return err
×
UNCOV
1512
        }, func() {})
×
UNCOV
1513
        if txErr != nil {
×
UNCOV
1514
                // If the invoice is already settled, we'll return the
×
UNCOV
1515
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
×
UNCOV
1516
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
×
UNCOV
1517
                        return updatedInvoice, txErr
×
UNCOV
1518
                }
×
1519

UNCOV
1520
                return nil, txErr
×
1521
        }
1522

UNCOV
1523
        return updatedInvoice, nil
×
1524
}
1525

1526
// DeleteInvoice attempts to delete the passed invoices and all their related
1527
// data from the database in one transaction.
1528
func (i *SQLStore) DeleteInvoice(ctx context.Context,
UNCOV
1529
        invoicesToDelete []InvoiceDeleteRef) error {
×
UNCOV
1530

×
UNCOV
1531
        // All the InvoiceDeleteRef instances include the add index of the
×
UNCOV
1532
        // invoice. The rest was added to ensure that the invoices were deleted
×
UNCOV
1533
        // properly in the kv database. When we have fully migrated we can
×
UNCOV
1534
        // remove the rest of the fields.
×
UNCOV
1535
        for _, ref := range invoicesToDelete {
×
UNCOV
1536
                if ref.AddIndex == 0 {
×
1537
                        return fmt.Errorf("unable to delete invoice using a "+
×
1538
                                "ref without AddIndex set: %v", ref)
×
1539
                }
×
1540
        }
1541

UNCOV
1542
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
UNCOV
1543
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1544
                for _, ref := range invoicesToDelete {
×
UNCOV
1545
                        params := sqlc.DeleteInvoiceParams{
×
UNCOV
1546
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
×
UNCOV
1547
                        }
×
UNCOV
1548

×
UNCOV
1549
                        if ref.SettleIndex != 0 {
×
UNCOV
1550
                                params.SettleIndex = sqldb.SQLInt64(
×
UNCOV
1551
                                        ref.SettleIndex,
×
UNCOV
1552
                                )
×
UNCOV
1553
                        }
×
1554

UNCOV
1555
                        if ref.PayHash != lntypes.ZeroHash {
×
UNCOV
1556
                                params.Hash = ref.PayHash[:]
×
UNCOV
1557
                        }
×
1558

UNCOV
1559
                        result, err := db.DeleteInvoice(ctx, params)
×
UNCOV
1560
                        if err != nil {
×
1561
                                return fmt.Errorf("unable to delete "+
×
1562
                                        "invoice(%v): %w", ref.AddIndex, err)
×
1563
                        }
×
UNCOV
1564
                        rowsAffected, err := result.RowsAffected()
×
UNCOV
1565
                        if err != nil {
×
1566
                                return fmt.Errorf("unable to get rows "+
×
1567
                                        "affected: %w", err)
×
1568
                        }
×
UNCOV
1569
                        if rowsAffected == 0 {
×
UNCOV
1570
                                return fmt.Errorf("%w: %v",
×
UNCOV
1571
                                        ErrInvoiceNotFound, ref.AddIndex)
×
UNCOV
1572
                        }
×
1573
                }
1574

UNCOV
1575
                return nil
×
UNCOV
1576
        }, func() {})
×
1577

UNCOV
1578
        if err != nil {
×
UNCOV
1579
                return fmt.Errorf("unable to delete invoices: %w", err)
×
UNCOV
1580
        }
×
1581

UNCOV
1582
        return nil
×
1583
}
1584

1585
// DeleteCanceledInvoices removes all canceled invoices from the database.
UNCOV
1586
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
×
UNCOV
1587
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
UNCOV
1588
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1589
                _, err := db.DeleteCanceledInvoices(ctx)
×
UNCOV
1590
                if err != nil {
×
1591
                        return fmt.Errorf("unable to delete canceled "+
×
1592
                                "invoices: %w", err)
×
1593
                }
×
1594

UNCOV
1595
                return nil
×
UNCOV
1596
        }, func() {})
×
UNCOV
1597
        if err != nil {
×
1598
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1599
        }
×
1600

UNCOV
1601
        return nil
×
1602
}
1603

1604
// fetchInvoiceData fetches additional data for the given invoice. If the
1605
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
1606
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
1607
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
1608
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
1609
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
UNCOV
1610
        *Invoice, error) {
×
UNCOV
1611

×
UNCOV
1612
        // Unmarshal the common data.
×
UNCOV
1613
        hash, invoice, err := unmarshalInvoice(row)
×
UNCOV
1614
        if err != nil {
×
1615
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1616
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1617
        }
×
1618

1619
        // Fetch the invoice features.
UNCOV
1620
        features, err := getInvoiceFeatures(ctx, db, row.ID)
×
UNCOV
1621
        if err != nil {
×
1622
                return nil, nil, err
×
1623
        }
×
1624

UNCOV
1625
        invoice.Terms.Features = features
×
UNCOV
1626

×
UNCOV
1627
        // If this is an AMP invoice, we'll need fetch the AMP state along
×
UNCOV
1628
        // with the HTLCs (if requested).
×
UNCOV
1629
        if invoice.IsAMP() {
×
UNCOV
1630
                invoiceID := int64(invoice.AddIndex)
×
UNCOV
1631
                ampState, ampHtlcs, err := fetchAmpState(
×
UNCOV
1632
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
×
UNCOV
1633
                )
×
UNCOV
1634
                if err != nil {
×
1635
                        return nil, nil, err
×
1636
                }
×
1637

UNCOV
1638
                invoice.AMPState = ampState
×
UNCOV
1639
                invoice.Htlcs = ampHtlcs
×
UNCOV
1640

×
UNCOV
1641
                return hash, invoice, nil
×
1642
        }
1643

1644
        // Otherwise simply fetch the invoice HTLCs.
UNCOV
1645
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
×
UNCOV
1646
        if err != nil {
×
1647
                return nil, nil, err
×
1648
        }
×
1649

UNCOV
1650
        if len(htlcs) > 0 {
×
UNCOV
1651
                invoice.Htlcs = htlcs
×
UNCOV
1652
        }
×
1653

UNCOV
1654
        return hash, invoice, nil
×
1655
}
1656

1657
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1658
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
UNCOV
1659
        invoiceID int64) (*lnwire.FeatureVector, error) {
×
UNCOV
1660

×
UNCOV
1661
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
×
UNCOV
1662
        if err != nil {
×
1663
                return nil, fmt.Errorf("unable to get invoice features: %w",
×
1664
                        err)
×
1665
        }
×
1666

UNCOV
1667
        features := lnwire.EmptyFeatureVector()
×
UNCOV
1668
        for _, feature := range rows {
×
UNCOV
1669
                features.Set(lnwire.FeatureBit(feature.Feature))
×
UNCOV
1670
        }
×
1671

UNCOV
1672
        return features, nil
×
1673
}
1674

1675
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
1676
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
UNCOV
1677
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
×
UNCOV
1678

×
UNCOV
1679
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
×
UNCOV
1680
        if err != nil {
×
1681
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
×
1682
        }
×
1683

1684
        // We have no htlcs to unmarshal.
UNCOV
1685
        if len(htlcRows) == 0 {
×
UNCOV
1686
                return nil, nil
×
UNCOV
1687
        }
×
1688

UNCOV
1689
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
UNCOV
1690
        if err != nil {
×
1691
                return nil, fmt.Errorf("unable to get custom records for "+
×
1692
                        "invoice htlcs: %w", err)
×
1693
        }
×
1694

UNCOV
1695
        cr := make(map[int64]record.CustomSet, len(crRows))
×
UNCOV
1696
        for _, row := range crRows {
×
UNCOV
1697
                if _, ok := cr[row.HtlcID]; !ok {
×
UNCOV
1698
                        cr[row.HtlcID] = make(record.CustomSet)
×
UNCOV
1699
                }
×
1700

UNCOV
1701
                value := row.Value
×
UNCOV
1702
                if value == nil {
×
UNCOV
1703
                        value = []byte{}
×
UNCOV
1704
                }
×
UNCOV
1705
                cr[row.HtlcID][uint64(row.Key)] = value
×
1706
        }
1707

UNCOV
1708
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
×
UNCOV
1709

×
UNCOV
1710
        for _, row := range htlcRows {
×
UNCOV
1711
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
×
UNCOV
1712
                if err != nil {
×
1713
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1714
                                "htlc(%d): %w", row.ID, err)
×
1715
                }
×
1716

UNCOV
1717
                if customRecords, ok := cr[row.ID]; ok {
×
UNCOV
1718
                        htlc.CustomRecords = customRecords
×
UNCOV
1719
                } else {
×
UNCOV
1720
                        htlc.CustomRecords = make(record.CustomSet)
×
UNCOV
1721
                }
×
1722

UNCOV
1723
                htlcs[circuiteKey] = htlc
×
1724
        }
1725

UNCOV
1726
        return htlcs, nil
×
1727
}
1728

1729
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1730
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
UNCOV
1731
        error) {
×
UNCOV
1732

×
UNCOV
1733
        var (
×
UNCOV
1734
                settleIndex    int64
×
UNCOV
1735
                settledAt      time.Time
×
UNCOV
1736
                memo           []byte
×
UNCOV
1737
                paymentRequest []byte
×
UNCOV
1738
                preimage       *lntypes.Preimage
×
UNCOV
1739
                paymentAddr    [32]byte
×
UNCOV
1740
        )
×
UNCOV
1741

×
UNCOV
1742
        hash, err := lntypes.MakeHash(row.Hash)
×
UNCOV
1743
        if err != nil {
×
1744
                return nil, nil, err
×
1745
        }
×
1746

UNCOV
1747
        if row.SettleIndex.Valid {
×
UNCOV
1748
                settleIndex = row.SettleIndex.Int64
×
UNCOV
1749
        }
×
1750

UNCOV
1751
        if row.SettledAt.Valid {
×
UNCOV
1752
                settledAt = row.SettledAt.Time.Local()
×
UNCOV
1753
        }
×
1754

UNCOV
1755
        if row.Memo.Valid {
×
UNCOV
1756
                memo = []byte(row.Memo.String)
×
UNCOV
1757
        }
×
1758

1759
        // Keysend payments will have this field empty.
UNCOV
1760
        if row.PaymentRequest.Valid {
×
UNCOV
1761
                paymentRequest = []byte(row.PaymentRequest.String)
×
UNCOV
1762
        } else {
×
UNCOV
1763
                paymentRequest = []byte{}
×
UNCOV
1764
        }
×
1765

1766
        // We may not have the preimage if this a hodl invoice.
UNCOV
1767
        if row.Preimage != nil {
×
UNCOV
1768
                preimage = &lntypes.Preimage{}
×
UNCOV
1769
                copy(preimage[:], row.Preimage)
×
UNCOV
1770
        }
×
1771

UNCOV
1772
        copy(paymentAddr[:], row.PaymentAddr)
×
UNCOV
1773

×
UNCOV
1774
        var cltvDelta int32
×
UNCOV
1775
        if row.CltvDelta.Valid {
×
UNCOV
1776
                cltvDelta = row.CltvDelta.Int32
×
UNCOV
1777
        }
×
1778

UNCOV
1779
        expiry := time.Duration(row.Expiry) * time.Second
×
UNCOV
1780

×
UNCOV
1781
        invoice := &Invoice{
×
UNCOV
1782
                SettleIndex:    uint64(settleIndex),
×
UNCOV
1783
                SettleDate:     settledAt,
×
UNCOV
1784
                Memo:           memo,
×
UNCOV
1785
                PaymentRequest: paymentRequest,
×
UNCOV
1786
                CreationDate:   row.CreatedAt.Local(),
×
UNCOV
1787
                Terms: ContractTerm{
×
UNCOV
1788
                        FinalCltvDelta:  cltvDelta,
×
UNCOV
1789
                        Expiry:          expiry,
×
UNCOV
1790
                        PaymentPreimage: preimage,
×
UNCOV
1791
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
1792
                        PaymentAddr:     paymentAddr,
×
UNCOV
1793
                },
×
UNCOV
1794
                AddIndex:    uint64(row.ID),
×
UNCOV
1795
                State:       ContractState(row.State),
×
UNCOV
1796
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
×
UNCOV
1797
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
×
UNCOV
1798
                AMPState:    AMPInvoiceState{},
×
UNCOV
1799
                HodlInvoice: row.IsHodl,
×
UNCOV
1800
        }
×
UNCOV
1801

×
UNCOV
1802
        return &hash, invoice, nil
×
1803
}
1804

1805
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
1806
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
UNCOV
1807
        *InvoiceHTLC, error) {
×
UNCOV
1808

×
UNCOV
1809
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
UNCOV
1810
        if err != nil {
×
1811
                return CircuitKey{}, nil, err
×
1812
        }
×
1813

UNCOV
1814
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
UNCOV
1815

×
UNCOV
1816
        if row.HtlcID < 0 {
×
1817
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
×
1818
                        "value: %v", row.HtlcID)
×
1819
        }
×
1820

UNCOV
1821
        htlcID := uint64(row.HtlcID)
×
UNCOV
1822

×
UNCOV
1823
        circuitKey := CircuitKey{
×
UNCOV
1824
                ChanID: chanID,
×
UNCOV
1825
                HtlcID: htlcID,
×
UNCOV
1826
        }
×
UNCOV
1827

×
UNCOV
1828
        htlc := &InvoiceHTLC{
×
UNCOV
1829
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
1830
                AcceptHeight: uint32(row.AcceptHeight),
×
UNCOV
1831
                AcceptTime:   row.AcceptTime.Local(),
×
UNCOV
1832
                Expiry:       uint32(row.ExpiryHeight),
×
UNCOV
1833
                State:        HtlcState(row.State),
×
UNCOV
1834
        }
×
UNCOV
1835

×
UNCOV
1836
        if row.TotalMppMsat.Valid {
×
UNCOV
1837
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
×
UNCOV
1838
        }
×
1839

UNCOV
1840
        if row.ResolveTime.Valid {
×
UNCOV
1841
                htlc.ResolveTime = row.ResolveTime.Time.Local()
×
UNCOV
1842
        }
×
1843

UNCOV
1844
        return circuitKey, htlc, nil
×
1845
}
1846

1847
// queryWithLimit is a helper that pages through a query using a limit and
// offset. The passed query function is invoked with the current offset
// (starting at 0 and advancing by limit after every full page) and must return
// the number of rows it fetched along with an error, if any. Paging stops
// without error as soon as a page holds fewer than limit rows, and any error
// returned by query is passed through unchanged.
//
// A non-positive limit is rejected up front with an error and query is never
// called: with such a limit the "short page" condition can never be met for a
// non-negative row count, so the loop would never terminate.
func queryWithLimit(query func(int) (int, error), limit int) error {
	if limit <= 0 {
		return fmt.Errorf("invalid query limit: %d", limit)
	}

	for offset := 0; ; offset += limit {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// A short page means there is nothing left to fetch.
		if rows < limit {
			return nil
		}
	}
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc