• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15561477203

10 Jun 2025 01:54PM UTC coverage: 58.351% (-10.1%) from 68.487%
15561477203

Pull #9356

github

web-flow
Merge 6440b25db into c6d6d4c0b
Pull Request #9356: lnrpc: add incoming/outgoing channel ids filter to forwarding history request

33 of 36 new or added lines in 2 files covered. (91.67%)

28366 existing lines in 455 files now uncovered.

97715 of 167461 relevant lines covered (58.35%)

1.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
// defaultQueryPaginationLimit is the number of rows requested per page (the
// value placed in the LIMIT clause) by SQL queries that paginate results.
const defaultQueryPaginationLimit = 100

// invoiceProgressLogInterval is the interval used to limit how often progress
// is logged while invoices are being processed.
const invoiceProgressLogInterval = 30 * time.Second
32

33
// SQLInvoiceQueries is an interface that defines the set of operations that can
34
// be executed against the invoice SQL database.
35
type SQLInvoiceQueries interface { //nolint:interfacebloat
36
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
37
                error)
38

39
        // TODO(bhandras): remove this once migrations have been separated out.
40
        InsertMigratedInvoice(ctx context.Context,
41
                arg sqlc.InsertMigratedInvoiceParams) (int64, error)
42

43
        InsertInvoiceFeature(ctx context.Context,
44
                arg sqlc.InsertInvoiceFeatureParams) error
45

46
        InsertInvoiceHTLC(ctx context.Context,
47
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
48

49
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
50
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
51

52
        FilterInvoices(ctx context.Context,
53
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
54

55
        GetInvoice(ctx context.Context,
56
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
57

58
        GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
59
                error)
60

61
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
62
                error)
63

64
        GetInvoiceFeatures(ctx context.Context,
65
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
66

67
        GetInvoiceHTLCCustomRecords(ctx context.Context,
68
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
69

70
        GetInvoiceHTLCs(ctx context.Context,
71
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
72

73
        UpdateInvoiceState(ctx context.Context,
74
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
75

76
        UpdateInvoiceAmountPaid(ctx context.Context,
77
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
78

79
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
80

81
        UpdateInvoiceHTLC(ctx context.Context,
82
                arg sqlc.UpdateInvoiceHTLCParams) error
83

84
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
85
                sql.Result, error)
86

87
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
88

89
        // AMP sub invoice specific methods.
90
        UpsertAMPSubInvoice(ctx context.Context,
91
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
92

93
        // TODO(bhandras): remove this once migrations have been separated out.
94
        InsertAMPSubInvoice(ctx context.Context,
95
                arg sqlc.InsertAMPSubInvoiceParams) error
96

97
        UpdateAMPSubInvoiceState(ctx context.Context,
98
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
99

100
        InsertAMPSubInvoiceHTLC(ctx context.Context,
101
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
102

103
        FetchAMPSubInvoices(ctx context.Context,
104
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
105
                error)
106

107
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
108
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
109
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
110

111
        FetchSettledAMPSubInvoices(ctx context.Context,
112
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
113
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
114

115
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
116
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
117
                error)
118

119
        // Invoice events specific methods.
120
        OnInvoiceCreated(ctx context.Context,
121
                arg sqlc.OnInvoiceCreatedParams) error
122

123
        OnInvoiceCanceled(ctx context.Context,
124
                arg sqlc.OnInvoiceCanceledParams) error
125

126
        OnInvoiceSettled(ctx context.Context,
127
                arg sqlc.OnInvoiceSettledParams) error
128

129
        OnAMPSubInvoiceCreated(ctx context.Context,
130
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
131

132
        OnAMPSubInvoiceCanceled(ctx context.Context,
133
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
134

135
        OnAMPSubInvoiceSettled(ctx context.Context,
136
                arg sqlc.OnAMPSubInvoiceSettledParams) error
137

138
        // Migration specific methods.
139
        // TODO(bhandras): remove this once migrations have been separated out.
140
        InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
141
                arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
142

143
        SetKVInvoicePaymentHash(ctx context.Context,
144
                arg sqlc.SetKVInvoicePaymentHashParams) error
145

146
        GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
147
                []byte, error)
148

149
        ClearKVInvoiceHashIndex(ctx context.Context) error
150
}
151

152
var _ InvoiceDB = (*SQLStore)(nil)
153

154
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
155
// of batched database operations.
156
type BatchedSQLInvoiceQueries interface {
157
        SQLInvoiceQueries
158

159
        sqldb.BatchedTx[SQLInvoiceQueries]
160
}
161

162
// SQLStore represents a storage backend.
163
type SQLStore struct {
164
        db    BatchedSQLInvoiceQueries
165
        clock clock.Clock
166
        opts  SQLStoreOptions
167
}
168

169
// SQLStoreOptions holds the options for the SQL store.
170
type SQLStoreOptions struct {
171
        paginationLimit int
172
}
173

174
// defaultSQLStoreOptions returns the default options for the SQL store.
UNCOV
175
func defaultSQLStoreOptions() SQLStoreOptions {
×
UNCOV
176
        return SQLStoreOptions{
×
UNCOV
177
                paginationLimit: defaultQueryPaginationLimit,
×
UNCOV
178
        }
×
UNCOV
179
}
×
180

181
// SQLStoreOption is a functional option that can be used to optionally modify
182
// the behavior of the SQL store.
183
type SQLStoreOption func(*SQLStoreOptions)
184

185
// WithPaginationLimit sets the pagination limit for the SQL store queries that
186
// paginate results.
UNCOV
187
func WithPaginationLimit(limit int) SQLStoreOption {
×
UNCOV
188
        return func(o *SQLStoreOptions) {
×
UNCOV
189
                o.paginationLimit = limit
×
UNCOV
190
        }
×
191
}
192

193
// NewSQLStore creates a new SQLStore instance given a open
194
// BatchedSQLInvoiceQueries storage backend.
195
func NewSQLStore(db BatchedSQLInvoiceQueries,
UNCOV
196
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
×
UNCOV
197

×
UNCOV
198
        opts := defaultSQLStoreOptions()
×
UNCOV
199
        for _, applyOption := range options {
×
UNCOV
200
                applyOption(&opts)
×
UNCOV
201
        }
×
202

UNCOV
203
        return &SQLStore{
×
UNCOV
204
                db:    db,
×
UNCOV
205
                clock: clock,
×
UNCOV
206
                opts:  opts,
×
UNCOV
207
        }
×
208
}
209

210
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
UNCOV
211
        sqlc.InsertInvoiceParams, error) {
×
UNCOV
212

×
UNCOV
213
        // Precompute the payment request hash so we can use it in the query.
×
UNCOV
214
        var paymentRequestHash []byte
×
UNCOV
215
        if len(invoice.PaymentRequest) > 0 {
×
UNCOV
216
                h := sha256.New()
×
UNCOV
217
                h.Write(invoice.PaymentRequest)
×
UNCOV
218
                paymentRequestHash = h.Sum(nil)
×
UNCOV
219
        }
×
220

UNCOV
221
        params := sqlc.InsertInvoiceParams{
×
UNCOV
222
                Hash:       paymentHash[:],
×
UNCOV
223
                AmountMsat: int64(invoice.Terms.Value),
×
UNCOV
224
                CltvDelta: sqldb.SQLInt32(
×
UNCOV
225
                        invoice.Terms.FinalCltvDelta,
×
UNCOV
226
                ),
×
UNCOV
227
                Expiry: int32(invoice.Terms.Expiry.Seconds()),
×
UNCOV
228
                // Note: keysend invoices don't have a payment request.
×
UNCOV
229
                PaymentRequest: sqldb.SQLStr(string(
×
UNCOV
230
                        invoice.PaymentRequest),
×
UNCOV
231
                ),
×
UNCOV
232
                PaymentRequestHash: paymentRequestHash,
×
UNCOV
233
                State:              int16(invoice.State),
×
UNCOV
234
                AmountPaidMsat:     int64(invoice.AmtPaid),
×
UNCOV
235
                IsAmp:              invoice.IsAMP(),
×
UNCOV
236
                IsHodl:             invoice.HodlInvoice,
×
UNCOV
237
                IsKeysend:          invoice.IsKeysend(),
×
UNCOV
238
                CreatedAt:          invoice.CreationDate.UTC(),
×
UNCOV
239
        }
×
UNCOV
240

×
UNCOV
241
        if invoice.Memo != nil {
×
UNCOV
242
                // Store the memo as a nullable string in the database. Note
×
UNCOV
243
                // that for compatibility reasons, we store the value as a valid
×
UNCOV
244
                // string even if it's empty.
×
UNCOV
245
                params.Memo = sql.NullString{
×
UNCOV
246
                        String: string(invoice.Memo),
×
UNCOV
247
                        Valid:  true,
×
UNCOV
248
                }
×
UNCOV
249
        }
×
250

251
        // Some invoices may not have a preimage, like in the case of HODL
252
        // invoices.
UNCOV
253
        if invoice.Terms.PaymentPreimage != nil {
×
UNCOV
254
                preimage := *invoice.Terms.PaymentPreimage
×
UNCOV
255
                if preimage == UnknownPreimage {
×
256
                        return sqlc.InsertInvoiceParams{},
×
257
                                errors.New("cannot use all-zeroes preimage")
×
258
                }
×
UNCOV
259
                params.Preimage = preimage[:]
×
260
        }
261

262
        // Some non MPP payments may have the default (invalid) value.
UNCOV
263
        if invoice.Terms.PaymentAddr != BlankPayAddr {
×
UNCOV
264
                params.PaymentAddr = invoice.Terms.PaymentAddr[:]
×
UNCOV
265
        }
×
266

UNCOV
267
        return params, nil
×
268
}
269

270
// AddInvoice inserts the targeted invoice into the database. If the invoice has
271
// *any* payment hashes which already exists within the database, then the
272
// insertion will be aborted and rejected due to the strict policy banning any
273
// duplicate payment hashes.
274
//
275
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
276
func (i *SQLStore) AddInvoice(ctx context.Context,
UNCOV
277
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
×
UNCOV
278

×
UNCOV
279
        // Make sure this is a valid invoice before trying to store it in our
×
UNCOV
280
        // DB.
×
UNCOV
281
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
×
UNCOV
282
                return 0, err
×
UNCOV
283
        }
×
284

UNCOV
285
        var (
×
UNCOV
286
                writeTxOpts = sqldb.WriteTxOpt()
×
UNCOV
287
                invoiceID   int64
×
UNCOV
288
        )
×
UNCOV
289

×
UNCOV
290
        insertInvoiceParams, err := makeInsertInvoiceParams(
×
UNCOV
291
                newInvoice, paymentHash,
×
UNCOV
292
        )
×
UNCOV
293
        if err != nil {
×
294
                return 0, err
×
295
        }
×
296

UNCOV
297
        err = i.db.ExecTx(ctx, writeTxOpts, func(db SQLInvoiceQueries) error {
×
UNCOV
298
                var err error
×
UNCOV
299
                invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
×
UNCOV
300
                if err != nil {
×
UNCOV
301
                        return fmt.Errorf("unable to insert invoice: %w", err)
×
UNCOV
302
                }
×
303

304
                // TODO(positiveblue): if invocies do not have custom features
305
                // maybe just store the "invoice type" and populate the features
306
                // based on that.
UNCOV
307
                for feature := range newInvoice.Terms.Features.Features() {
×
UNCOV
308
                        params := sqlc.InsertInvoiceFeatureParams{
×
UNCOV
309
                                InvoiceID: invoiceID,
×
UNCOV
310
                                Feature:   int32(feature),
×
UNCOV
311
                        }
×
UNCOV
312

×
UNCOV
313
                        err := db.InsertInvoiceFeature(ctx, params)
×
UNCOV
314
                        if err != nil {
×
315
                                return fmt.Errorf("unable to insert invoice "+
×
316
                                        "feature(%v): %w", feature, err)
×
317
                        }
×
318
                }
319

320
                // Finally add a new event for this invoice.
UNCOV
321
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
×
UNCOV
322
                        AddedAt:   newInvoice.CreationDate.UTC(),
×
UNCOV
323
                        InvoiceID: invoiceID,
×
UNCOV
324
                })
×
325
        }, sqldb.NoOpReset)
UNCOV
326
        if err != nil {
×
UNCOV
327
                mappedSQLErr := sqldb.MapSQLError(err)
×
UNCOV
328
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
×
UNCOV
329
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
UNCOV
330
                        // Add context to unique constraint errors.
×
UNCOV
331
                        return 0, ErrDuplicateInvoice
×
UNCOV
332
                }
×
333

334
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
×
335
                        paymentHash, err)
×
336
        }
337

UNCOV
338
        newInvoice.AddIndex = uint64(invoiceID)
×
UNCOV
339

×
UNCOV
340
        return newInvoice.AddIndex, nil
×
341
}
342

343
// getInvoiceByRef fetches the invoice with the given reference. The reference
344
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
345
func getInvoiceByRef(ctx context.Context,
UNCOV
346
        db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
×
UNCOV
347

×
UNCOV
348
        // If the reference is empty, we can't look up the invoice.
×
UNCOV
349
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
×
UNCOV
350
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
UNCOV
351
        }
×
352

353
        // If the reference is a hash only, we can look up the invoice directly
354
        // by the payment hash which is faster.
UNCOV
355
        if ref.IsHashOnly() {
×
UNCOV
356
                invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
×
UNCOV
357
                if errors.Is(err, sql.ErrNoRows) {
×
UNCOV
358
                        return sqlc.Invoice{}, ErrInvoiceNotFound
×
UNCOV
359
                }
×
360

UNCOV
361
                return invoice, err
×
362
        }
363

364
        // Otherwise the reference may include more fields, so we'll need to
365
        // assemble the query parameters based on the fields that are set.
UNCOV
366
        var params sqlc.GetInvoiceParams
×
UNCOV
367

×
UNCOV
368
        if ref.PayHash() != nil {
×
UNCOV
369
                params.Hash = ref.PayHash()[:]
×
UNCOV
370
        }
×
371

372
        // Newer invoices (0.11 and up) are indexed by payment address in
373
        // addition to payment hash, but pre 0.8 invoices do not have one at
374
        // all. Only allow lookups for payment address if it is not a blank
375
        // payment address, which is a special-cased value for legacy keysend
376
        // invoices.
UNCOV
377
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
×
UNCOV
378
                params.PaymentAddr = ref.PayAddr()[:]
×
UNCOV
379
        }
×
380

381
        // If the reference has a set ID we'll fetch the invoice which has the
382
        // corresponding AMP sub invoice.
UNCOV
383
        if ref.SetID() != nil {
×
UNCOV
384
                params.SetID = ref.SetID()[:]
×
UNCOV
385
        }
×
386

UNCOV
387
        var (
×
UNCOV
388
                rows []sqlc.Invoice
×
UNCOV
389
                err  error
×
UNCOV
390
        )
×
UNCOV
391

×
UNCOV
392
        // We need to split the query based on how we intend to look up the
×
UNCOV
393
        // invoice. If only the set ID is given then we want to have an exact
×
UNCOV
394
        // match on the set ID. If other fields are given, we want to match on
×
UNCOV
395
        // those fields and the set ID but with a less strict join condition.
×
UNCOV
396
        if params.Hash == nil && params.PaymentAddr == nil &&
×
UNCOV
397
                params.SetID != nil {
×
UNCOV
398

×
UNCOV
399
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
×
UNCOV
400
        } else {
×
UNCOV
401
                rows, err = db.GetInvoice(ctx, params)
×
UNCOV
402
        }
×
403

UNCOV
404
        switch {
×
UNCOV
405
        case len(rows) == 0:
×
UNCOV
406
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
407

408
        case len(rows) > 1:
×
409
                // In case the reference is ambiguous, meaning it matches more
×
410
                // than        one invoice, we'll return an error.
×
411
                return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
×
412
                        "%s: %s", ref.String(), spew.Sdump(rows))
×
413

414
        case err != nil:
×
415
                return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
×
416
                        err)
×
417
        }
418

UNCOV
419
        return rows[0], nil
×
420
}
421

422
// fetchInvoice fetches the common invoice data and the AMP state for the
423
// invoice with the given reference.
424
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
UNCOV
425
        *Invoice, error) {
×
UNCOV
426

×
UNCOV
427
        // Fetch the invoice from the database.
×
UNCOV
428
        sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
×
UNCOV
429
        if err != nil {
×
UNCOV
430
                return nil, err
×
UNCOV
431
        }
×
432

UNCOV
433
        var (
×
UNCOV
434
                setID         *[32]byte
×
UNCOV
435
                fetchAmpHtlcs bool
×
UNCOV
436
        )
×
UNCOV
437

×
UNCOV
438
        // Now that we got the invoice itself, fetch the HTLCs as requested by
×
UNCOV
439
        // the modifier.
×
UNCOV
440
        switch ref.Modifier() {
×
UNCOV
441
        case DefaultModifier:
×
UNCOV
442
                // By default we'll fetch all AMP HTLCs.
×
UNCOV
443
                setID = nil
×
UNCOV
444
                fetchAmpHtlcs = true
×
445

UNCOV
446
        case HtlcSetOnlyModifier:
×
UNCOV
447
                // In this case we'll fetch all AMP HTLCs for the specified set
×
UNCOV
448
                // id.
×
UNCOV
449
                if ref.SetID() == nil {
×
450
                        return nil, fmt.Errorf("set ID is required to use " +
×
451
                                "the HTLC set only modifier")
×
452
                }
×
453

UNCOV
454
                setID = ref.SetID()
×
UNCOV
455
                fetchAmpHtlcs = true
×
456

UNCOV
457
        case HtlcSetBlankModifier:
×
UNCOV
458
                // No need to fetch any HTLCs.
×
UNCOV
459
                setID = nil
×
UNCOV
460
                fetchAmpHtlcs = false
×
461

462
        default:
×
463
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
×
464
                        ref.Modifier())
×
465
        }
466

467
        // Fetch the rest of the invoice data and fill the invoice struct.
UNCOV
468
        _, invoice, err := fetchInvoiceData(
×
UNCOV
469
                ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
×
UNCOV
470
        )
×
UNCOV
471
        if err != nil {
×
472
                return nil, err
×
473
        }
×
474

UNCOV
475
        return invoice, nil
×
476
}
477

478
// fetchAmpState fetches the AMP state for the invoice with the given ID.
479
// Optional setID can be provided to fetch the state for a specific AMP HTLC
480
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
481
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
482
// well.
483
//
484
//nolint:funlen
485
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
486
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
UNCOV
487
        HTLCSet, error) {
×
UNCOV
488

×
UNCOV
489
        var paramSetID []byte
×
UNCOV
490
        if setID != nil {
×
UNCOV
491
                paramSetID = setID[:]
×
UNCOV
492
        }
×
493

494
        // First fetch all the AMP sub invoices for this invoice or the one
495
        // matching the provided set ID.
UNCOV
496
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
×
UNCOV
497
                ctx, sqlc.FetchAMPSubInvoicesParams{
×
UNCOV
498
                        InvoiceID: invoiceID,
×
UNCOV
499
                        SetID:     paramSetID,
×
UNCOV
500
                },
×
UNCOV
501
        )
×
UNCOV
502
        if err != nil {
×
503
                return nil, nil, err
×
504
        }
×
505

UNCOV
506
        ampState := make(map[SetID]InvoiceStateAMP)
×
UNCOV
507
        for _, row := range ampInvoiceRows {
×
UNCOV
508
                var rowSetID [32]byte
×
UNCOV
509

×
UNCOV
510
                if len(row.SetID) != 32 {
×
511
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
×
512
                                len(row.SetID))
×
513
                }
×
514

UNCOV
515
                var settleDate time.Time
×
UNCOV
516
                if row.SettledAt.Valid {
×
UNCOV
517
                        settleDate = row.SettledAt.Time.Local()
×
UNCOV
518
                }
×
519

UNCOV
520
                copy(rowSetID[:], row.SetID)
×
UNCOV
521
                ampState[rowSetID] = InvoiceStateAMP{
×
UNCOV
522
                        State:       HtlcState(row.State),
×
UNCOV
523
                        SettleIndex: uint64(row.SettleIndex.Int64),
×
UNCOV
524
                        SettleDate:  settleDate,
×
UNCOV
525
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
×
UNCOV
526
                }
×
527
        }
528

UNCOV
529
        if !fetchHtlcs {
×
UNCOV
530
                return ampState, nil, nil
×
UNCOV
531
        }
×
532

UNCOV
533
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
UNCOV
534
        if err != nil {
×
535
                return nil, nil, fmt.Errorf("unable to get custom records for "+
×
536
                        "invoice HTLCs: %w", err)
×
537
        }
×
538

UNCOV
539
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
×
UNCOV
540
        for _, row := range customRecordRows {
×
UNCOV
541
                if _, ok := customRecords[row.HtlcID]; !ok {
×
UNCOV
542
                        customRecords[row.HtlcID] = make(record.CustomSet)
×
UNCOV
543
                }
×
544

UNCOV
545
                value := row.Value
×
UNCOV
546
                if value == nil {
×
547
                        value = []byte{}
×
548
                }
×
549

UNCOV
550
                customRecords[row.HtlcID][uint64(row.Key)] = value
×
551
        }
552

553
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
554
        // provided set ID.
UNCOV
555
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
×
UNCOV
556
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
×
UNCOV
557
                        InvoiceID: invoiceID,
×
UNCOV
558
                        SetID:     paramSetID,
×
UNCOV
559
                },
×
UNCOV
560
        )
×
UNCOV
561
        if err != nil {
×
562
                return nil, nil, err
×
563
        }
×
564

UNCOV
565
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
×
UNCOV
566
        for _, row := range ampHtlcRows {
×
UNCOV
567
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
UNCOV
568
                if err != nil {
×
569
                        return nil, nil, err
×
570
                }
×
571

UNCOV
572
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
UNCOV
573

×
UNCOV
574
                if row.HtlcID < 0 {
×
575
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
×
576
                                "value: %v", row.HtlcID)
×
577
                }
×
578

UNCOV
579
                htlcID := uint64(row.HtlcID)
×
UNCOV
580

×
UNCOV
581
                circuitKey := CircuitKey{
×
UNCOV
582
                        ChanID: chanID,
×
UNCOV
583
                        HtlcID: htlcID,
×
UNCOV
584
                }
×
UNCOV
585

×
UNCOV
586
                htlc := &InvoiceHTLC{
×
UNCOV
587
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
588
                        AcceptHeight: uint32(row.AcceptHeight),
×
UNCOV
589
                        AcceptTime:   row.AcceptTime.Local(),
×
UNCOV
590
                        Expiry:       uint32(row.ExpiryHeight),
×
UNCOV
591
                        State:        HtlcState(row.State),
×
UNCOV
592
                }
×
UNCOV
593

×
UNCOV
594
                if row.TotalMppMsat.Valid {
×
UNCOV
595
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
×
UNCOV
596
                                row.TotalMppMsat.Int64,
×
UNCOV
597
                        )
×
UNCOV
598
                }
×
599

UNCOV
600
                if row.ResolveTime.Valid {
×
UNCOV
601
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
×
UNCOV
602
                }
×
603

UNCOV
604
                var (
×
UNCOV
605
                        rootShare [32]byte
×
UNCOV
606
                        setID     [32]byte
×
UNCOV
607
                )
×
UNCOV
608

×
UNCOV
609
                if len(row.RootShare) != 32 {
×
610
                        return nil, nil, fmt.Errorf("invalid root share "+
×
611
                                "length: %d", len(row.RootShare))
×
612
                }
×
UNCOV
613
                copy(rootShare[:], row.RootShare)
×
UNCOV
614

×
UNCOV
615
                if len(row.SetID) != 32 {
×
616
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
617
                                len(row.SetID))
×
618
                }
×
UNCOV
619
                copy(setID[:], row.SetID)
×
UNCOV
620

×
UNCOV
621
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
×
622
                        return nil, nil, fmt.Errorf("invalid child index "+
×
623
                                "value: %v", row.ChildIndex)
×
624
                }
×
625

UNCOV
626
                ampRecord := record.NewAMP(
×
UNCOV
627
                        rootShare, setID, uint32(row.ChildIndex),
×
UNCOV
628
                )
×
UNCOV
629

×
UNCOV
630
                htlc.AMP = &InvoiceHtlcAMPData{
×
UNCOV
631
                        Record: *ampRecord,
×
UNCOV
632
                }
×
UNCOV
633

×
UNCOV
634
                if len(row.Hash) != 32 {
×
635
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
×
636
                                len(row.Hash))
×
637
                }
×
UNCOV
638
                copy(htlc.AMP.Hash[:], row.Hash)
×
UNCOV
639

×
UNCOV
640
                if row.Preimage != nil {
×
UNCOV
641
                        preimage, err := lntypes.MakePreimage(row.Preimage)
×
UNCOV
642
                        if err != nil {
×
643
                                return nil, nil, err
×
644
                        }
×
645

UNCOV
646
                        htlc.AMP.Preimage = &preimage
×
647
                }
648

UNCOV
649
                if _, ok := customRecords[row.ID]; ok {
×
UNCOV
650
                        htlc.CustomRecords = customRecords[row.ID]
×
UNCOV
651
                } else {
×
UNCOV
652
                        htlc.CustomRecords = make(record.CustomSet)
×
UNCOV
653
                }
×
654

UNCOV
655
                ampHtlcs[circuitKey] = htlc
×
656
        }
657

UNCOV
658
        if len(ampHtlcs) > 0 {
×
UNCOV
659
                for setID := range ampState {
×
UNCOV
660
                        var amtPaid lnwire.MilliSatoshi
×
UNCOV
661
                        invoiceKeys := make(
×
UNCOV
662
                                map[models.CircuitKey]struct{},
×
UNCOV
663
                        )
×
UNCOV
664

×
UNCOV
665
                        for key, htlc := range ampHtlcs {
×
UNCOV
666
                                if htlc.AMP.Record.SetID() != setID {
×
UNCOV
667
                                        continue
×
668
                                }
669

UNCOV
670
                                invoiceKeys[key] = struct{}{}
×
UNCOV
671

×
UNCOV
672
                                if htlc.State != HtlcStateCanceled {
×
UNCOV
673
                                        amtPaid += htlc.Amt
×
UNCOV
674
                                }
×
675
                        }
676

UNCOV
677
                        setState := ampState[setID]
×
UNCOV
678
                        setState.InvoiceKeys = invoiceKeys
×
UNCOV
679
                        setState.AmtPaid = amtPaid
×
UNCOV
680
                        ampState[setID] = setState
×
681
                }
682
        }
683

UNCOV
684
        return ampState, ampHtlcs, nil
×
685
}
686

687
// LookupInvoice attempts to look up an invoice corresponding the passed in
688
// reference. The reference may be a payment hash, a payment address, or a set
689
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
690
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
691
// error.
692
func (i *SQLStore) LookupInvoice(ctx context.Context,
UNCOV
693
        ref InvoiceRef) (Invoice, error) {
×
UNCOV
694

×
UNCOV
695
        var (
×
UNCOV
696
                invoice *Invoice
×
UNCOV
697
                err     error
×
UNCOV
698
        )
×
UNCOV
699

×
UNCOV
700
        readTxOpt := sqldb.ReadTxOpt()
×
UNCOV
701
        txErr := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
702
                invoice, err = fetchInvoice(ctx, db, ref)
×
UNCOV
703

×
UNCOV
704
                return err
×
UNCOV
705
        }, sqldb.NoOpReset)
×
UNCOV
706
        if txErr != nil {
×
UNCOV
707
                return Invoice{}, txErr
×
UNCOV
708
        }
×
709

UNCOV
710
        return *invoice, nil
×
711
}
712

713
// FetchPendingInvoices returns all the invoices that are currently in a
714
// "pending" state. An invoice is pending if it has been created but not yet
715
// settled or canceled.
716
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
UNCOV
717
        map[lntypes.Hash]Invoice, error) {
×
UNCOV
718

×
UNCOV
719
        var invoices map[lntypes.Hash]Invoice
×
UNCOV
720

×
UNCOV
721
        readTxOpt := sqldb.ReadTxOpt()
×
UNCOV
722
        err := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
723
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
724
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
725
                                PendingOnly: true,
×
UNCOV
726
                                NumOffset:   int32(offset),
×
UNCOV
727
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
728
                                Reverse:     false,
×
UNCOV
729
                        }
×
UNCOV
730

×
UNCOV
731
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
732
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
733
                                return 0, fmt.Errorf("unable to get invoices "+
×
734
                                        "from db: %w", err)
×
735
                        }
×
736

737
                        // Load all the information for the invoices.
UNCOV
738
                        for _, row := range rows {
×
UNCOV
739
                                hash, invoice, err := fetchInvoiceData(
×
UNCOV
740
                                        ctx, db, row, nil, true,
×
UNCOV
741
                                )
×
UNCOV
742
                                if err != nil {
×
743
                                        return 0, err
×
744
                                }
×
745

UNCOV
746
                                invoices[*hash] = *invoice
×
747
                        }
748

UNCOV
749
                        return len(rows), nil
×
750
                }, i.opts.paginationLimit)
UNCOV
751
        }, func() {
×
UNCOV
752
                invoices = make(map[lntypes.Hash]Invoice)
×
UNCOV
753
        })
×
UNCOV
754
        if err != nil {
×
755
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
×
756
                        err)
×
757
        }
×
758

UNCOV
759
        return invoices, nil
×
760
}
761

762
// InvoicesSettledSince can be used by callers to catch up any settled invoices
763
// they missed within the settled invoice time series. We'll return all known
764
// settled invoice that have a settle index higher than the passed idx.
765
//
766
// NOTE: The index starts from 1. As a result we enforce that specifying a value
767
// below the starting index value is a noop.
768
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
UNCOV
769
        []Invoice, error) {
×
UNCOV
770

×
UNCOV
771
        var (
×
UNCOV
772
                invoices       []Invoice
×
UNCOV
773
                start          = time.Now()
×
UNCOV
774
                lastLogTime    = time.Now()
×
UNCOV
775
                processedCount int
×
UNCOV
776
        )
×
UNCOV
777

×
UNCOV
778
        if idx == 0 {
×
UNCOV
779
                return invoices, nil
×
UNCOV
780
        }
×
781

UNCOV
782
        readTxOpt := sqldb.ReadTxOpt()
×
UNCOV
783
        err := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
784
                err := queryWithLimit(func(offset int) (int, error) {
×
UNCOV
785
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
786
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
787
                                NumOffset:      int32(offset),
×
UNCOV
788
                                NumLimit:       int32(i.opts.paginationLimit),
×
UNCOV
789
                                Reverse:        false,
×
UNCOV
790
                        }
×
UNCOV
791

×
UNCOV
792
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
793
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
794
                                return 0, fmt.Errorf("unable to get invoices "+
×
795
                                        "from db: %w", err)
×
796
                        }
×
797

798
                        // Load all the information for the invoices.
UNCOV
799
                        for _, row := range rows {
×
UNCOV
800
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
801
                                        ctx, db, row, nil, true,
×
UNCOV
802
                                )
×
UNCOV
803
                                if err != nil {
×
804
                                        return 0, fmt.Errorf("unable to fetch "+
×
805
                                                "invoice(id=%d) from db: %w",
×
806
                                                row.ID, err)
×
807
                                }
×
808

UNCOV
809
                                invoices = append(invoices, *invoice)
×
UNCOV
810

×
UNCOV
811
                                processedCount++
×
UNCOV
812
                                if time.Since(lastLogTime) >=
×
UNCOV
813
                                        invoiceProgressLogInterval {
×
814

×
815
                                        log.Debugf("Processed %d settled "+
×
816
                                                "invoices which have a settle "+
×
817
                                                "index greater than %v",
×
818
                                                processedCount, idx)
×
819

×
820
                                        lastLogTime = time.Now()
×
821
                                }
×
822
                        }
823

UNCOV
824
                        return len(rows), nil
×
825
                }, i.opts.paginationLimit)
UNCOV
826
                if err != nil {
×
827
                        return err
×
828
                }
×
829

830
                // Now fetch all the AMP sub invoices that were settled since
831
                // the provided index.
UNCOV
832
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
×
UNCOV
833
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
×
UNCOV
834
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
835
                        },
×
UNCOV
836
                )
×
UNCOV
837
                if err != nil {
×
838
                        return err
×
839
                }
×
840

UNCOV
841
                for _, ampInvoice := range ampInvoices {
×
UNCOV
842
                        // Convert the row to a sqlc.Invoice so we can use the
×
UNCOV
843
                        // existing fetchInvoiceData function.
×
UNCOV
844
                        sqlInvoice := sqlc.Invoice{
×
UNCOV
845
                                ID:             ampInvoice.ID,
×
UNCOV
846
                                Hash:           ampInvoice.Hash,
×
UNCOV
847
                                Preimage:       ampInvoice.Preimage,
×
UNCOV
848
                                SettleIndex:    ampInvoice.AmpSettleIndex,
×
UNCOV
849
                                SettledAt:      ampInvoice.AmpSettledAt,
×
UNCOV
850
                                Memo:           ampInvoice.Memo,
×
UNCOV
851
                                AmountMsat:     ampInvoice.AmountMsat,
×
UNCOV
852
                                CltvDelta:      ampInvoice.CltvDelta,
×
UNCOV
853
                                Expiry:         ampInvoice.Expiry,
×
UNCOV
854
                                PaymentAddr:    ampInvoice.PaymentAddr,
×
UNCOV
855
                                PaymentRequest: ampInvoice.PaymentRequest,
×
UNCOV
856
                                State:          ampInvoice.State,
×
UNCOV
857
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
×
UNCOV
858
                                IsAmp:          ampInvoice.IsAmp,
×
UNCOV
859
                                IsHodl:         ampInvoice.IsHodl,
×
UNCOV
860
                                IsKeysend:      ampInvoice.IsKeysend,
×
UNCOV
861
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
×
UNCOV
862
                        }
×
UNCOV
863

×
UNCOV
864
                        // Fetch the state and HTLCs for this AMP sub invoice.
×
UNCOV
865
                        _, invoice, err := fetchInvoiceData(
×
UNCOV
866
                                ctx, db, sqlInvoice,
×
UNCOV
867
                                (*[32]byte)(ampInvoice.SetID), true,
×
UNCOV
868
                        )
×
UNCOV
869
                        if err != nil {
×
870
                                return fmt.Errorf("unable to fetch "+
×
871
                                        "AMP invoice(id=%d) from db: %w",
×
872
                                        ampInvoice.ID, err)
×
873
                        }
×
874

UNCOV
875
                        invoices = append(invoices, *invoice)
×
UNCOV
876

×
UNCOV
877
                        processedCount++
×
UNCOV
878
                        if time.Since(lastLogTime) >=
×
UNCOV
879
                                invoiceProgressLogInterval {
×
880

×
881
                                log.Debugf("Processed %d settled invoices "+
×
882
                                        "including AMP sub invoices which "+
×
883
                                        "have a settle index greater than %v",
×
884
                                        processedCount, idx)
×
885

×
886
                                lastLogTime = time.Now()
×
887
                        }
×
888
                }
889

UNCOV
890
                return nil
×
UNCOV
891
        }, func() {
×
UNCOV
892
                invoices = nil
×
UNCOV
893
        })
×
UNCOV
894
        if err != nil {
×
895
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
896
                        "index (excluding) %d: %w", idx, err)
×
897
        }
×
898

UNCOV
899
        elapsed := time.Since(start)
×
UNCOV
900
        log.Debugf("Completed scanning for settled invoices starting at "+
×
UNCOV
901
                "index %v: total_processed=%d, found_invoices=%d, elapsed=%v",
×
UNCOV
902
                idx, processedCount, len(invoices),
×
UNCOV
903
                elapsed.Round(time.Millisecond))
×
UNCOV
904

×
UNCOV
905
        return invoices, nil
×
906
}
907

908
// InvoicesAddedSince can be used by callers to seek into the event time series
909
// of all the invoices added in the database. This method will return all
910
// invoices with an add index greater than the specified idx.
911
//
912
// NOTE: The index starts from 1. As a result we enforce that specifying a value
913
// below the starting index value is a noop.
914
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
UNCOV
915
        []Invoice, error) {
×
UNCOV
916

×
UNCOV
917
        var (
×
UNCOV
918
                result         []Invoice
×
UNCOV
919
                start          = time.Now()
×
UNCOV
920
                lastLogTime    = time.Now()
×
UNCOV
921
                processedCount int
×
UNCOV
922
        )
×
UNCOV
923

×
UNCOV
924
        if idx == 0 {
×
UNCOV
925
                return result, nil
×
UNCOV
926
        }
×
927

UNCOV
928
        readTxOpt := sqldb.ReadTxOpt()
×
UNCOV
929
        err := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
930
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
931
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
932
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
933
                                NumOffset:   int32(offset),
×
UNCOV
934
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
935
                                Reverse:     false,
×
UNCOV
936
                        }
×
UNCOV
937

×
UNCOV
938
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
939
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
940
                                return 0, fmt.Errorf("unable to get invoices "+
×
941
                                        "from db: %w", err)
×
942
                        }
×
943

944
                        // Load all the information for the invoices.
UNCOV
945
                        for _, row := range rows {
×
UNCOV
946
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
947
                                        ctx, db, row, nil, true,
×
UNCOV
948
                                )
×
UNCOV
949
                                if err != nil {
×
950
                                        return 0, err
×
951
                                }
×
952

UNCOV
953
                                result = append(result, *invoice)
×
UNCOV
954

×
UNCOV
955
                                processedCount++
×
UNCOV
956
                                if time.Since(lastLogTime) >=
×
UNCOV
957
                                        invoiceProgressLogInterval {
×
958

×
959
                                        log.Debugf("Processed %d invoices "+
×
960
                                                "which were added since add "+
×
961
                                                "index %v", processedCount, idx)
×
962

×
963
                                        lastLogTime = time.Now()
×
964
                                }
×
965
                        }
966

UNCOV
967
                        return len(rows), nil
×
968
                }, i.opts.paginationLimit)
UNCOV
969
        }, func() {
×
UNCOV
970
                result = nil
×
UNCOV
971
        })
×
972

UNCOV
973
        if err != nil {
×
974
                return nil, fmt.Errorf("unable to get invoices added since "+
×
975
                        "index %d: %w", idx, err)
×
976
        }
×
977

UNCOV
978
        elapsed := time.Since(start)
×
UNCOV
979
        log.Debugf("Completed scanning for invoices added since index %v: "+
×
UNCOV
980
                "total_processed=%d, found_invoices=%d, elapsed=%v",
×
UNCOV
981
                idx, processedCount, len(result),
×
UNCOV
982
                elapsed.Round(time.Millisecond))
×
UNCOV
983

×
UNCOV
984
        return result, nil
×
985
}
986

987
// QueryInvoices allows a caller to query the invoice database for invoices
988
// within the specified add index range.
989
func (i *SQLStore) QueryInvoices(ctx context.Context,
UNCOV
990
        q InvoiceQuery) (InvoiceSlice, error) {
×
UNCOV
991

×
UNCOV
992
        var invoices []Invoice
×
UNCOV
993

×
UNCOV
994
        if q.NumMaxInvoices == 0 {
×
995
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
996
                        "be non-zero")
×
997
        }
×
998

UNCOV
999
        readTxOpt := sqldb.ReadTxOpt()
×
UNCOV
1000
        err := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1001
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
1002
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
1003
                                NumOffset:   int32(offset),
×
UNCOV
1004
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
1005
                                PendingOnly: q.PendingOnly,
×
UNCOV
1006
                                Reverse:     q.Reversed,
×
UNCOV
1007
                        }
×
UNCOV
1008

×
UNCOV
1009
                        if q.Reversed {
×
UNCOV
1010
                                // If the index offset was not set, we want to
×
UNCOV
1011
                                // fetch from the lastest invoice.
×
UNCOV
1012
                                if q.IndexOffset == 0 {
×
UNCOV
1013
                                        params.AddIndexLet = sqldb.SQLInt64(
×
UNCOV
1014
                                                int64(math.MaxInt64),
×
UNCOV
1015
                                        )
×
UNCOV
1016
                                } else {
×
UNCOV
1017
                                        // The invoice with index offset id must
×
UNCOV
1018
                                        // not be included in the results.
×
UNCOV
1019
                                        params.AddIndexLet = sqldb.SQLInt64(
×
UNCOV
1020
                                                q.IndexOffset - 1,
×
UNCOV
1021
                                        )
×
UNCOV
1022
                                }
×
UNCOV
1023
                        } else {
×
UNCOV
1024
                                // The invoice with index offset id must not be
×
UNCOV
1025
                                // included in the results.
×
UNCOV
1026
                                params.AddIndexGet = sqldb.SQLInt64(
×
UNCOV
1027
                                        q.IndexOffset + 1,
×
UNCOV
1028
                                )
×
UNCOV
1029
                        }
×
1030

UNCOV
1031
                        if q.CreationDateStart != 0 {
×
UNCOV
1032
                                params.CreatedAfter = sqldb.SQLTime(
×
UNCOV
1033
                                        time.Unix(q.CreationDateStart, 0).UTC(),
×
UNCOV
1034
                                )
×
UNCOV
1035
                        }
×
1036

UNCOV
1037
                        if q.CreationDateEnd != 0 {
×
UNCOV
1038
                                // We need to add 1 to the end date as we're
×
UNCOV
1039
                                // checking less than the end date in SQL.
×
UNCOV
1040
                                params.CreatedBefore = sqldb.SQLTime(
×
UNCOV
1041
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
×
UNCOV
1042
                                )
×
UNCOV
1043
                        }
×
1044

UNCOV
1045
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
1046
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1047
                                return 0, fmt.Errorf("unable to get invoices "+
×
1048
                                        "from db: %w", err)
×
1049
                        }
×
1050

1051
                        // Load all the information for the invoices.
UNCOV
1052
                        for _, row := range rows {
×
UNCOV
1053
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
1054
                                        ctx, db, row, nil, true,
×
UNCOV
1055
                                )
×
UNCOV
1056
                                if err != nil {
×
1057
                                        return 0, err
×
1058
                                }
×
1059

UNCOV
1060
                                invoices = append(invoices, *invoice)
×
UNCOV
1061

×
UNCOV
1062
                                if len(invoices) == int(q.NumMaxInvoices) {
×
UNCOV
1063
                                        return 0, nil
×
UNCOV
1064
                                }
×
1065
                        }
1066

UNCOV
1067
                        return len(rows), nil
×
1068
                }, i.opts.paginationLimit)
UNCOV
1069
        }, func() {
×
UNCOV
1070
                invoices = nil
×
UNCOV
1071
        })
×
UNCOV
1072
        if err != nil {
×
1073
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
1074
                        "invoices: %w", err)
×
1075
        }
×
1076

UNCOV
1077
        if len(invoices) == 0 {
×
UNCOV
1078
                return InvoiceSlice{
×
UNCOV
1079
                        InvoiceQuery: q,
×
UNCOV
1080
                }, nil
×
UNCOV
1081
        }
×
1082

1083
        // If we iterated through the add index in reverse order, then
1084
        // we'll need to reverse the slice of invoices to return them in
1085
        // forward order.
UNCOV
1086
        if q.Reversed {
×
UNCOV
1087
                numInvoices := len(invoices)
×
UNCOV
1088
                for i := 0; i < numInvoices/2; i++ {
×
UNCOV
1089
                        reverse := numInvoices - i - 1
×
UNCOV
1090
                        invoices[i], invoices[reverse] =
×
UNCOV
1091
                                invoices[reverse], invoices[i]
×
UNCOV
1092
                }
×
1093
        }
1094

UNCOV
1095
        res := InvoiceSlice{
×
UNCOV
1096
                InvoiceQuery:     q,
×
UNCOV
1097
                Invoices:         invoices,
×
UNCOV
1098
                FirstIndexOffset: invoices[0].AddIndex,
×
UNCOV
1099
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
×
UNCOV
1100
        }
×
UNCOV
1101

×
UNCOV
1102
        return res, nil
×
1103
}
1104

1105
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
// a SQL database as the backend.
type sqlInvoiceUpdater struct {
	// db is the query handle that every update made by the updater is
	// issued through.
	db         SQLInvoiceQueries
	// ctx is the context passed to each of those queries. It is stored on
	// the struct (hence the containedctx lint exemption) because the
	// updater methods here take no context parameter of their own.
	ctx        context.Context //nolint:containedctx
	// invoice is the in-memory invoice being updated. Its AddIndex is used
	// as the invoice's database ID in the update queries, and
	// UpdateInvoiceState writes the new settle index/date back into it.
	invoice    *Invoice
	// updateTime is the timestamp (converted to UTC where written) used
	// for the invoice events, AMP sub invoice creation time and settle
	// times recorded by the updater.
	updateTime time.Time
}
1113

1114
// AddHtlc adds a new htlc to the invoice.
1115
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
UNCOV
1116
        newHtlc *InvoiceHTLC) error {
×
UNCOV
1117

×
UNCOV
1118
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
×
UNCOV
1119
                s.ctx, sqlc.InsertInvoiceHTLCParams{
×
UNCOV
1120
                        HtlcID: int64(circuitKey.HtlcID),
×
UNCOV
1121
                        ChanID: strconv.FormatUint(
×
UNCOV
1122
                                circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1123
                        ),
×
UNCOV
1124
                        AmountMsat: int64(newHtlc.Amt),
×
UNCOV
1125
                        TotalMppMsat: sql.NullInt64{
×
UNCOV
1126
                                Int64: int64(newHtlc.MppTotalAmt),
×
UNCOV
1127
                                Valid: newHtlc.MppTotalAmt != 0,
×
UNCOV
1128
                        },
×
UNCOV
1129
                        AcceptHeight: int32(newHtlc.AcceptHeight),
×
UNCOV
1130
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
×
UNCOV
1131
                        ExpiryHeight: int32(newHtlc.Expiry),
×
UNCOV
1132
                        State:        int16(newHtlc.State),
×
UNCOV
1133
                        InvoiceID:    int64(s.invoice.AddIndex),
×
UNCOV
1134
                },
×
UNCOV
1135
        )
×
UNCOV
1136
        if err != nil {
×
1137
                return err
×
1138
        }
×
1139

UNCOV
1140
        for key, value := range newHtlc.CustomRecords {
×
UNCOV
1141
                err = s.db.InsertInvoiceHTLCCustomRecord(
×
UNCOV
1142
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
×
UNCOV
1143
                                // TODO(bhandras): schema might be wrong here
×
UNCOV
1144
                                // as the custom record key is an uint64.
×
UNCOV
1145
                                Key:    int64(key),
×
UNCOV
1146
                                Value:  value,
×
UNCOV
1147
                                HtlcID: htlcPrimaryKeyID,
×
UNCOV
1148
                        },
×
UNCOV
1149
                )
×
UNCOV
1150
                if err != nil {
×
1151
                        return err
×
1152
                }
×
1153
        }
1154

UNCOV
1155
        if newHtlc.AMP != nil {
×
UNCOV
1156
                setID := newHtlc.AMP.Record.SetID()
×
UNCOV
1157

×
UNCOV
1158
                upsertResult, err := s.db.UpsertAMPSubInvoice(
×
UNCOV
1159
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
×
UNCOV
1160
                                SetID:     setID[:],
×
UNCOV
1161
                                CreatedAt: s.updateTime.UTC(),
×
UNCOV
1162
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1163
                        },
×
UNCOV
1164
                )
×
UNCOV
1165
                if err != nil {
×
UNCOV
1166
                        mappedSQLErr := sqldb.MapSQLError(err)
×
UNCOV
1167
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
×
UNCOV
1168
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
UNCOV
1169
                                return ErrDuplicateSetID{
×
UNCOV
1170
                                        SetID: setID,
×
UNCOV
1171
                                }
×
UNCOV
1172
                        }
×
1173

1174
                        return err
×
1175
                }
1176

1177
                // If we're just inserting the AMP invoice, we'll get a non
1178
                // zero rows affected count.
UNCOV
1179
                rowsAffected, err := upsertResult.RowsAffected()
×
UNCOV
1180
                if err != nil {
×
1181
                        return err
×
1182
                }
×
UNCOV
1183
                if rowsAffected != 0 {
×
UNCOV
1184
                        // If we're inserting a new AMP invoice, we'll also
×
UNCOV
1185
                        // insert a new invoice event.
×
UNCOV
1186
                        err = s.db.OnAMPSubInvoiceCreated(
×
UNCOV
1187
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
×
UNCOV
1188
                                        AddedAt:   s.updateTime.UTC(),
×
UNCOV
1189
                                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1190
                                        SetID:     setID[:],
×
UNCOV
1191
                                },
×
UNCOV
1192
                        )
×
UNCOV
1193
                        if err != nil {
×
1194
                                return err
×
1195
                        }
×
1196
                }
1197

UNCOV
1198
                rootShare := newHtlc.AMP.Record.RootShare()
×
UNCOV
1199

×
UNCOV
1200
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
×
UNCOV
1201
                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1202
                        SetID:     setID[:],
×
UNCOV
1203
                        HtlcID:    htlcPrimaryKeyID,
×
UNCOV
1204
                        RootShare: rootShare[:],
×
UNCOV
1205
                        ChildIndex: int64(
×
UNCOV
1206
                                newHtlc.AMP.Record.ChildIndex(),
×
UNCOV
1207
                        ),
×
UNCOV
1208
                        Hash: newHtlc.AMP.Hash[:],
×
UNCOV
1209
                }
×
UNCOV
1210

×
UNCOV
1211
                if newHtlc.AMP.Preimage != nil {
×
UNCOV
1212
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
×
UNCOV
1213
                }
×
1214

UNCOV
1215
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
×
UNCOV
1216
                if err != nil {
×
1217
                        return err
×
1218
                }
×
1219
        }
1220

UNCOV
1221
        return nil
×
1222
}
1223

1224
// ResolveHtlc marks an htlc as resolved with the given state.
1225
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
UNCOV
1226
        state HtlcState, resolveTime time.Time) error {
×
UNCOV
1227

×
UNCOV
1228
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
×
UNCOV
1229
                HtlcID: int64(circuitKey.HtlcID),
×
UNCOV
1230
                ChanID: strconv.FormatUint(
×
UNCOV
1231
                        circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1232
                ),
×
UNCOV
1233
                InvoiceID:   int64(s.invoice.AddIndex),
×
UNCOV
1234
                State:       int16(state),
×
UNCOV
1235
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
×
UNCOV
1236
        })
×
UNCOV
1237
}
×
1238

1239
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1240
// identified by the setID.
1241
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
UNCOV
1242
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
×
UNCOV
1243

×
UNCOV
1244
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
×
UNCOV
1245
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
×
UNCOV
1246
                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1247
                        SetID:     setID[:],
×
UNCOV
1248
                        HtlcID:    int64(circuitKey.HtlcID),
×
UNCOV
1249
                        Preimage:  preimage[:],
×
UNCOV
1250
                        ChanID: strconv.FormatUint(
×
UNCOV
1251
                                circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1252
                        ),
×
UNCOV
1253
                },
×
UNCOV
1254
        )
×
UNCOV
1255
        if err != nil {
×
1256
                return err
×
1257
        }
×
1258

UNCOV
1259
        rowsAffected, err := result.RowsAffected()
×
UNCOV
1260
        if err != nil {
×
1261
                return err
×
1262
        }
×
UNCOV
1263
        if rowsAffected == 0 {
×
1264
                return ErrInvoiceNotFound
×
1265
        }
×
1266

UNCOV
1267
        return nil
×
1268
}
1269

1270
// UpdateInvoiceState updates the invoice state to the new state.
1271
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
UNCOV
1272
        newState ContractState, preimage *lntypes.Preimage) error {
×
UNCOV
1273

×
UNCOV
1274
        var (
×
UNCOV
1275
                settleIndex sql.NullInt64
×
UNCOV
1276
                settledAt   sql.NullTime
×
UNCOV
1277
        )
×
UNCOV
1278

×
UNCOV
1279
        switch newState {
×
UNCOV
1280
        case ContractSettled:
×
UNCOV
1281
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
UNCOV
1282
                if err != nil {
×
1283
                        return err
×
1284
                }
×
1285

UNCOV
1286
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
UNCOV
1287

×
UNCOV
1288
                // If the invoice is settled, we'll also update the settle time.
×
UNCOV
1289
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
UNCOV
1290

×
UNCOV
1291
                err = s.db.OnInvoiceSettled(
×
UNCOV
1292
                        s.ctx, sqlc.OnInvoiceSettledParams{
×
UNCOV
1293
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1294
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1295
                        },
×
UNCOV
1296
                )
×
UNCOV
1297
                if err != nil {
×
1298
                        return err
×
1299
                }
×
1300

UNCOV
1301
        case ContractCanceled:
×
UNCOV
1302
                err := s.db.OnInvoiceCanceled(
×
UNCOV
1303
                        s.ctx, sqlc.OnInvoiceCanceledParams{
×
UNCOV
1304
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1305
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1306
                        },
×
UNCOV
1307
                )
×
UNCOV
1308
                if err != nil {
×
1309
                        return err
×
1310
                }
×
1311
        }
1312

UNCOV
1313
        params := sqlc.UpdateInvoiceStateParams{
×
UNCOV
1314
                ID:          int64(s.invoice.AddIndex),
×
UNCOV
1315
                State:       int16(newState),
×
UNCOV
1316
                SettleIndex: settleIndex,
×
UNCOV
1317
                SettledAt:   settledAt,
×
UNCOV
1318
        }
×
UNCOV
1319

×
UNCOV
1320
        if preimage != nil {
×
UNCOV
1321
                params.Preimage = preimage[:]
×
UNCOV
1322
        }
×
1323

UNCOV
1324
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
×
UNCOV
1325
        if err != nil {
×
UNCOV
1326
                return err
×
UNCOV
1327
        }
×
UNCOV
1328
        rowsAffected, err := result.RowsAffected()
×
UNCOV
1329
        if err != nil {
×
1330
                return err
×
1331
        }
×
1332

UNCOV
1333
        if rowsAffected == 0 {
×
1334
                return ErrInvoiceNotFound
×
1335
        }
×
1336

UNCOV
1337
        if settleIndex.Valid {
×
UNCOV
1338
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
×
UNCOV
1339
                s.invoice.SettleDate = s.updateTime
×
UNCOV
1340
        }
×
1341

UNCOV
1342
        return nil
×
1343
}
1344

1345
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
1346
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
UNCOV
1347
        amtPaid lnwire.MilliSatoshi) error {
×
UNCOV
1348

×
UNCOV
1349
        _, err := s.db.UpdateInvoiceAmountPaid(
×
UNCOV
1350
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
×
UNCOV
1351
                        ID:             int64(s.invoice.AddIndex),
×
UNCOV
1352
                        AmountPaidMsat: int64(amtPaid),
×
UNCOV
1353
                },
×
UNCOV
1354
        )
×
UNCOV
1355

×
UNCOV
1356
        return err
×
UNCOV
1357
}
×
1358

1359
// UpdateAmpState updates the state of the AMP sub invoice identified by the
1360
// setID.
1361
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
UNCOV
1362
        newState InvoiceStateAMP, _ models.CircuitKey) error {
×
UNCOV
1363

×
UNCOV
1364
        var (
×
UNCOV
1365
                settleIndex sql.NullInt64
×
UNCOV
1366
                settledAt   sql.NullTime
×
UNCOV
1367
        )
×
UNCOV
1368

×
UNCOV
1369
        switch newState.State {
×
UNCOV
1370
        case HtlcStateSettled:
×
UNCOV
1371
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
UNCOV
1372
                if err != nil {
×
1373
                        return err
×
1374
                }
×
1375

UNCOV
1376
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
UNCOV
1377

×
UNCOV
1378
                // If the invoice is settled, we'll also update the settle time.
×
UNCOV
1379
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
UNCOV
1380

×
UNCOV
1381
                err = s.db.OnAMPSubInvoiceSettled(
×
UNCOV
1382
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
×
UNCOV
1383
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1384
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1385
                                SetID:     setID[:],
×
UNCOV
1386
                        },
×
UNCOV
1387
                )
×
UNCOV
1388
                if err != nil {
×
1389
                        return err
×
1390
                }
×
1391

UNCOV
1392
        case HtlcStateCanceled:
×
UNCOV
1393
                err := s.db.OnAMPSubInvoiceCanceled(
×
UNCOV
1394
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
×
UNCOV
1395
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1396
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1397
                                SetID:     setID[:],
×
UNCOV
1398
                        },
×
UNCOV
1399
                )
×
UNCOV
1400
                if err != nil {
×
1401
                        return err
×
1402
                }
×
1403
        }
1404

UNCOV
1405
        err := s.db.UpdateAMPSubInvoiceState(
×
UNCOV
1406
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
×
UNCOV
1407
                        SetID:       setID[:],
×
UNCOV
1408
                        State:       int16(newState.State),
×
UNCOV
1409
                        SettleIndex: settleIndex,
×
UNCOV
1410
                        SettledAt:   settledAt,
×
UNCOV
1411
                },
×
UNCOV
1412
        )
×
UNCOV
1413
        if err != nil {
×
1414
                return err
×
1415
        }
×
1416

UNCOV
1417
        if settleIndex.Valid {
×
UNCOV
1418
                updatedState := s.invoice.AMPState[setID]
×
UNCOV
1419
                updatedState.SettleIndex = uint64(settleIndex.Int64)
×
UNCOV
1420
                updatedState.SettleDate = s.updateTime.UTC()
×
UNCOV
1421
                s.invoice.AMPState[setID] = updatedState
×
UNCOV
1422
        }
×
1423

UNCOV
1424
        return nil
×
1425
}
1426

1427
// Finalize finalizes the update before it is written to the database. Note that
1428
// we don't use this directly in the SQL implementation, so the function is just
1429
// a stub.
UNCOV
1430
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
×
UNCOV
1431
        return nil
×
UNCOV
1432
}
×
1433

1434
// UpdateInvoice attempts to update an invoice corresponding to the passed
1435
// reference. If an invoice matching the passed reference doesn't exist within
1436
// the database, then the action will fail with  ErrInvoiceNotFound error.
1437
//
1438
// The update is performed inside the same database transaction that fetches the
1439
// invoice and is therefore atomic. The fields to update are controlled by the
1440
// supplied callback.
1441
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
1442
        setID *SetID, callback InvoiceUpdateCallback) (
UNCOV
1443
        *Invoice, error) {
×
UNCOV
1444

×
UNCOV
1445
        var updatedInvoice *Invoice
×
UNCOV
1446

×
UNCOV
1447
        txOpt := sqldb.WriteTxOpt()
×
UNCOV
1448
        txErr := i.db.ExecTx(ctx, txOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1449
                switch {
×
1450
                // For the default case we fetch all HTLCs.
UNCOV
1451
                case setID == nil:
×
UNCOV
1452
                        ref.refModifier = DefaultModifier
×
1453

1454
                // If the setID is the blank but NOT nil, we set the
1455
                // refModifier to HtlcSetBlankModifier to fetch no HTLC for the
1456
                // AMP invoice.
1457
                case *setID == BlankPayAddr:
×
1458
                        ref.refModifier = HtlcSetBlankModifier
×
1459

1460
                // A setID is provided, we use the refModifier to fetch only
1461
                // the HTLCs for the given setID and also make sure we add the
1462
                // setID to the ref.
UNCOV
1463
                default:
×
UNCOV
1464
                        var setIDBytes [32]byte
×
UNCOV
1465
                        copy(setIDBytes[:], setID[:])
×
UNCOV
1466
                        ref.setID = &setIDBytes
×
UNCOV
1467

×
UNCOV
1468
                        // We only fetch the HTLCs for the given setID.
×
UNCOV
1469
                        ref.refModifier = HtlcSetOnlyModifier
×
1470
                }
1471

UNCOV
1472
                invoice, err := fetchInvoice(ctx, db, ref)
×
UNCOV
1473
                if err != nil {
×
UNCOV
1474
                        return err
×
UNCOV
1475
                }
×
1476

UNCOV
1477
                updateTime := i.clock.Now()
×
UNCOV
1478
                updater := &sqlInvoiceUpdater{
×
UNCOV
1479
                        db:         db,
×
UNCOV
1480
                        ctx:        ctx,
×
UNCOV
1481
                        invoice:    invoice,
×
UNCOV
1482
                        updateTime: updateTime,
×
UNCOV
1483
                }
×
UNCOV
1484

×
UNCOV
1485
                payHash := ref.PayHash()
×
UNCOV
1486
                updatedInvoice, err = UpdateInvoice(
×
UNCOV
1487
                        payHash, invoice, updateTime, callback, updater,
×
UNCOV
1488
                )
×
UNCOV
1489

×
UNCOV
1490
                return err
×
1491
        }, sqldb.NoOpReset)
UNCOV
1492
        if txErr != nil {
×
UNCOV
1493
                // If the invoice is already settled, we'll return the
×
UNCOV
1494
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
×
UNCOV
1495
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
×
UNCOV
1496
                        return updatedInvoice, txErr
×
UNCOV
1497
                }
×
1498

UNCOV
1499
                return nil, txErr
×
1500
        }
1501

UNCOV
1502
        return updatedInvoice, nil
×
1503
}
1504

1505
// DeleteInvoice attempts to delete the passed invoices and all their related
1506
// data from the database in one transaction.
1507
func (i *SQLStore) DeleteInvoice(ctx context.Context,
UNCOV
1508
        invoicesToDelete []InvoiceDeleteRef) error {
×
UNCOV
1509

×
UNCOV
1510
        // All the InvoiceDeleteRef instances include the add index of the
×
UNCOV
1511
        // invoice. The rest was added to ensure that the invoices were deleted
×
UNCOV
1512
        // properly in the kv database. When we have fully migrated we can
×
UNCOV
1513
        // remove the rest of the fields.
×
UNCOV
1514
        for _, ref := range invoicesToDelete {
×
UNCOV
1515
                if ref.AddIndex == 0 {
×
1516
                        return fmt.Errorf("unable to delete invoice using a "+
×
1517
                                "ref without AddIndex set: %v", ref)
×
1518
                }
×
1519
        }
1520

UNCOV
1521
        writeTxOpt := sqldb.WriteTxOpt()
×
UNCOV
1522
        err := i.db.ExecTx(ctx, writeTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1523
                for _, ref := range invoicesToDelete {
×
UNCOV
1524
                        params := sqlc.DeleteInvoiceParams{
×
UNCOV
1525
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
×
UNCOV
1526
                        }
×
UNCOV
1527

×
UNCOV
1528
                        if ref.SettleIndex != 0 {
×
UNCOV
1529
                                params.SettleIndex = sqldb.SQLInt64(
×
UNCOV
1530
                                        ref.SettleIndex,
×
UNCOV
1531
                                )
×
UNCOV
1532
                        }
×
1533

UNCOV
1534
                        if ref.PayHash != lntypes.ZeroHash {
×
UNCOV
1535
                                params.Hash = ref.PayHash[:]
×
UNCOV
1536
                        }
×
1537

UNCOV
1538
                        result, err := db.DeleteInvoice(ctx, params)
×
UNCOV
1539
                        if err != nil {
×
1540
                                return fmt.Errorf("unable to delete "+
×
1541
                                        "invoice(%v): %w", ref.AddIndex, err)
×
1542
                        }
×
UNCOV
1543
                        rowsAffected, err := result.RowsAffected()
×
UNCOV
1544
                        if err != nil {
×
1545
                                return fmt.Errorf("unable to get rows "+
×
1546
                                        "affected: %w", err)
×
1547
                        }
×
UNCOV
1548
                        if rowsAffected == 0 {
×
UNCOV
1549
                                return fmt.Errorf("%w: %v",
×
UNCOV
1550
                                        ErrInvoiceNotFound, ref.AddIndex)
×
UNCOV
1551
                        }
×
1552
                }
1553

UNCOV
1554
                return nil
×
1555
        }, sqldb.NoOpReset)
1556

UNCOV
1557
        if err != nil {
×
UNCOV
1558
                return fmt.Errorf("unable to delete invoices: %w", err)
×
UNCOV
1559
        }
×
1560

UNCOV
1561
        return nil
×
1562
}
1563

1564
// DeleteCanceledInvoices removes all canceled invoices from the database.
UNCOV
1565
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
×
UNCOV
1566
        writeTxOpt := sqldb.WriteTxOpt()
×
UNCOV
1567
        err := i.db.ExecTx(ctx, writeTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1568
                _, err := db.DeleteCanceledInvoices(ctx)
×
UNCOV
1569
                if err != nil {
×
1570
                        return fmt.Errorf("unable to delete canceled "+
×
1571
                                "invoices: %w", err)
×
1572
                }
×
1573

UNCOV
1574
                return nil
×
1575
        }, sqldb.NoOpReset)
UNCOV
1576
        if err != nil {
×
1577
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1578
        }
×
1579

UNCOV
1580
        return nil
×
1581
}
1582

1583
// fetchInvoiceData fetches additional data for the given invoice. If the
1584
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
1585
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
1586
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
1587
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
1588
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
UNCOV
1589
        *Invoice, error) {
×
UNCOV
1590

×
UNCOV
1591
        // Unmarshal the common data.
×
UNCOV
1592
        hash, invoice, err := unmarshalInvoice(row)
×
UNCOV
1593
        if err != nil {
×
1594
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1595
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1596
        }
×
1597

1598
        // Fetch the invoice features.
UNCOV
1599
        features, err := getInvoiceFeatures(ctx, db, row.ID)
×
UNCOV
1600
        if err != nil {
×
1601
                return nil, nil, err
×
1602
        }
×
1603

UNCOV
1604
        invoice.Terms.Features = features
×
UNCOV
1605

×
UNCOV
1606
        // If this is an AMP invoice, we'll need fetch the AMP state along
×
UNCOV
1607
        // with the HTLCs (if requested).
×
UNCOV
1608
        if invoice.IsAMP() {
×
UNCOV
1609
                invoiceID := int64(invoice.AddIndex)
×
UNCOV
1610
                ampState, ampHtlcs, err := fetchAmpState(
×
UNCOV
1611
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
×
UNCOV
1612
                )
×
UNCOV
1613
                if err != nil {
×
1614
                        return nil, nil, err
×
1615
                }
×
1616

UNCOV
1617
                invoice.AMPState = ampState
×
UNCOV
1618
                invoice.Htlcs = ampHtlcs
×
UNCOV
1619

×
UNCOV
1620
                return hash, invoice, nil
×
1621
        }
1622

1623
        // Otherwise simply fetch the invoice HTLCs.
UNCOV
1624
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
×
UNCOV
1625
        if err != nil {
×
1626
                return nil, nil, err
×
1627
        }
×
1628

UNCOV
1629
        if len(htlcs) > 0 {
×
UNCOV
1630
                invoice.Htlcs = htlcs
×
UNCOV
1631
        }
×
1632

UNCOV
1633
        return hash, invoice, nil
×
1634
}
1635

1636
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1637
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
UNCOV
1638
        invoiceID int64) (*lnwire.FeatureVector, error) {
×
UNCOV
1639

×
UNCOV
1640
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
×
UNCOV
1641
        if err != nil {
×
1642
                return nil, fmt.Errorf("unable to get invoice features: %w",
×
1643
                        err)
×
1644
        }
×
1645

UNCOV
1646
        features := lnwire.EmptyFeatureVector()
×
UNCOV
1647
        for _, feature := range rows {
×
UNCOV
1648
                features.Set(lnwire.FeatureBit(feature.Feature))
×
UNCOV
1649
        }
×
1650

UNCOV
1651
        return features, nil
×
1652
}
1653

1654
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
1655
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
UNCOV
1656
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
×
UNCOV
1657

×
UNCOV
1658
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
×
UNCOV
1659
        if err != nil {
×
1660
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
×
1661
        }
×
1662

1663
        // We have no htlcs to unmarshal.
UNCOV
1664
        if len(htlcRows) == 0 {
×
UNCOV
1665
                return nil, nil
×
UNCOV
1666
        }
×
1667

UNCOV
1668
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
UNCOV
1669
        if err != nil {
×
1670
                return nil, fmt.Errorf("unable to get custom records for "+
×
1671
                        "invoice htlcs: %w", err)
×
1672
        }
×
1673

UNCOV
1674
        cr := make(map[int64]record.CustomSet, len(crRows))
×
UNCOV
1675
        for _, row := range crRows {
×
UNCOV
1676
                if _, ok := cr[row.HtlcID]; !ok {
×
UNCOV
1677
                        cr[row.HtlcID] = make(record.CustomSet)
×
UNCOV
1678
                }
×
1679

UNCOV
1680
                value := row.Value
×
UNCOV
1681
                if value == nil {
×
UNCOV
1682
                        value = []byte{}
×
UNCOV
1683
                }
×
UNCOV
1684
                cr[row.HtlcID][uint64(row.Key)] = value
×
1685
        }
1686

UNCOV
1687
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
×
UNCOV
1688

×
UNCOV
1689
        for _, row := range htlcRows {
×
UNCOV
1690
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
×
UNCOV
1691
                if err != nil {
×
1692
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1693
                                "htlc(%d): %w", row.ID, err)
×
1694
                }
×
1695

UNCOV
1696
                if customRecords, ok := cr[row.ID]; ok {
×
UNCOV
1697
                        htlc.CustomRecords = customRecords
×
UNCOV
1698
                } else {
×
UNCOV
1699
                        htlc.CustomRecords = make(record.CustomSet)
×
UNCOV
1700
                }
×
1701

UNCOV
1702
                htlcs[circuiteKey] = htlc
×
1703
        }
1704

UNCOV
1705
        return htlcs, nil
×
1706
}
1707

1708
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1709
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
UNCOV
1710
        error) {
×
UNCOV
1711

×
UNCOV
1712
        var (
×
UNCOV
1713
                settleIndex    int64
×
UNCOV
1714
                settledAt      time.Time
×
UNCOV
1715
                memo           []byte
×
UNCOV
1716
                paymentRequest []byte
×
UNCOV
1717
                preimage       *lntypes.Preimage
×
UNCOV
1718
                paymentAddr    [32]byte
×
UNCOV
1719
        )
×
UNCOV
1720

×
UNCOV
1721
        hash, err := lntypes.MakeHash(row.Hash)
×
UNCOV
1722
        if err != nil {
×
1723
                return nil, nil, err
×
1724
        }
×
1725

UNCOV
1726
        if row.SettleIndex.Valid {
×
UNCOV
1727
                settleIndex = row.SettleIndex.Int64
×
UNCOV
1728
        }
×
1729

UNCOV
1730
        if row.SettledAt.Valid {
×
UNCOV
1731
                settledAt = row.SettledAt.Time.Local()
×
UNCOV
1732
        }
×
1733

UNCOV
1734
        if row.Memo.Valid {
×
UNCOV
1735
                memo = []byte(row.Memo.String)
×
UNCOV
1736
        }
×
1737

1738
        // Keysend payments will have this field empty.
UNCOV
1739
        if row.PaymentRequest.Valid {
×
UNCOV
1740
                paymentRequest = []byte(row.PaymentRequest.String)
×
UNCOV
1741
        } else {
×
UNCOV
1742
                paymentRequest = []byte{}
×
UNCOV
1743
        }
×
1744

1745
        // We may not have the preimage if this a hodl invoice.
UNCOV
1746
        if row.Preimage != nil {
×
UNCOV
1747
                preimage = &lntypes.Preimage{}
×
UNCOV
1748
                copy(preimage[:], row.Preimage)
×
UNCOV
1749
        }
×
1750

UNCOV
1751
        copy(paymentAddr[:], row.PaymentAddr)
×
UNCOV
1752

×
UNCOV
1753
        var cltvDelta int32
×
UNCOV
1754
        if row.CltvDelta.Valid {
×
UNCOV
1755
                cltvDelta = row.CltvDelta.Int32
×
UNCOV
1756
        }
×
1757

UNCOV
1758
        expiry := time.Duration(row.Expiry) * time.Second
×
UNCOV
1759

×
UNCOV
1760
        invoice := &Invoice{
×
UNCOV
1761
                SettleIndex:    uint64(settleIndex),
×
UNCOV
1762
                SettleDate:     settledAt,
×
UNCOV
1763
                Memo:           memo,
×
UNCOV
1764
                PaymentRequest: paymentRequest,
×
UNCOV
1765
                CreationDate:   row.CreatedAt.Local(),
×
UNCOV
1766
                Terms: ContractTerm{
×
UNCOV
1767
                        FinalCltvDelta:  cltvDelta,
×
UNCOV
1768
                        Expiry:          expiry,
×
UNCOV
1769
                        PaymentPreimage: preimage,
×
UNCOV
1770
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
1771
                        PaymentAddr:     paymentAddr,
×
UNCOV
1772
                },
×
UNCOV
1773
                AddIndex:    uint64(row.ID),
×
UNCOV
1774
                State:       ContractState(row.State),
×
UNCOV
1775
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
×
UNCOV
1776
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
×
UNCOV
1777
                AMPState:    AMPInvoiceState{},
×
UNCOV
1778
                HodlInvoice: row.IsHodl,
×
UNCOV
1779
        }
×
UNCOV
1780

×
UNCOV
1781
        return &hash, invoice, nil
×
1782
}
1783

1784
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
1785
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
UNCOV
1786
        *InvoiceHTLC, error) {
×
UNCOV
1787

×
UNCOV
1788
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
UNCOV
1789
        if err != nil {
×
1790
                return CircuitKey{}, nil, err
×
1791
        }
×
1792

UNCOV
1793
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
UNCOV
1794

×
UNCOV
1795
        if row.HtlcID < 0 {
×
1796
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
×
1797
                        "value: %v", row.HtlcID)
×
1798
        }
×
1799

UNCOV
1800
        htlcID := uint64(row.HtlcID)
×
UNCOV
1801

×
UNCOV
1802
        circuitKey := CircuitKey{
×
UNCOV
1803
                ChanID: chanID,
×
UNCOV
1804
                HtlcID: htlcID,
×
UNCOV
1805
        }
×
UNCOV
1806

×
UNCOV
1807
        htlc := &InvoiceHTLC{
×
UNCOV
1808
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
1809
                AcceptHeight: uint32(row.AcceptHeight),
×
UNCOV
1810
                AcceptTime:   row.AcceptTime.Local(),
×
UNCOV
1811
                Expiry:       uint32(row.ExpiryHeight),
×
UNCOV
1812
                State:        HtlcState(row.State),
×
UNCOV
1813
        }
×
UNCOV
1814

×
UNCOV
1815
        if row.TotalMppMsat.Valid {
×
UNCOV
1816
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
×
UNCOV
1817
        }
×
1818

UNCOV
1819
        if row.ResolveTime.Valid {
×
UNCOV
1820
                htlc.ResolveTime = row.ResolveTime.Time.Local()
×
UNCOV
1821
        }
×
1822

UNCOV
1823
        return circuitKey, htlc, nil
×
1824
}
1825

1826
// queryWithLimit pages through a query using limit/offset pagination.
//
// The query function is called with the current offset, starting at 0 and
// advancing by limit each time. It must return the number of rows it handled
// for that page. Paging stops in two cases:
//   - a page returns fewer than limit rows, meaning no rows are left; or
//   - query returns an error, which is passed back unchanged.
//
// limit must be positive. With a zero limit the offset would never advance and
// the loop would never end. With a negative limit the offset would move
// backwards. Both are rejected with an error before query is called.
func queryWithLimit(query func(int) (int, error), limit int) error {
	if limit <= 0 {
		return fmt.Errorf("invalid query limit: %d", limit)
	}

	for offset := 0; ; offset += limit {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// A short page means there is nothing left to fetch.
		if rows < limit {
			return nil
		}
	}
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc