• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 14513053602

17 Apr 2025 09:56AM UTC coverage: 56.754% (-12.3%) from 69.035%
14513053602

Pull #9727

github

web-flow
Merge 5fb0f4317 into 24fdae7df
Pull Request #9727: Aux bandwidth manager: also pass HTLC blob to `ShouldHandleTraffic`

3 of 8 new or added lines in 2 files covered. (37.5%)

24357 existing lines in 290 files now uncovered.

107518 of 189445 relevant lines covered (56.75%)

22634.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.98
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
const (
24
        // defaultQueryPaginationLimit is used in the LIMIT clause of the SQL
25
        // queries to limit the number of rows returned.
26
        defaultQueryPaginationLimit = 100
27
)
28

29
// SQLInvoiceQueries is an interface that defines the set of operations that can
30
// be executed against the invoice SQL database.
31
type SQLInvoiceQueries interface { //nolint:interfacebloat
32
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
33
                error)
34

35
        // TODO(bhandras): remove this once migrations have been separated out.
36
        InsertMigratedInvoice(ctx context.Context,
37
                arg sqlc.InsertMigratedInvoiceParams) (int64, error)
38

39
        InsertInvoiceFeature(ctx context.Context,
40
                arg sqlc.InsertInvoiceFeatureParams) error
41

42
        InsertInvoiceHTLC(ctx context.Context,
43
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
44

45
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
46
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
47

48
        FilterInvoices(ctx context.Context,
49
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
50

51
        GetInvoice(ctx context.Context,
52
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
53

54
        GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
55
                error)
56

57
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
58
                error)
59

60
        GetInvoiceFeatures(ctx context.Context,
61
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
62

63
        GetInvoiceHTLCCustomRecords(ctx context.Context,
64
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
65

66
        GetInvoiceHTLCs(ctx context.Context,
67
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
68

69
        UpdateInvoiceState(ctx context.Context,
70
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
71

72
        UpdateInvoiceAmountPaid(ctx context.Context,
73
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
74

75
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
76

77
        UpdateInvoiceHTLC(ctx context.Context,
78
                arg sqlc.UpdateInvoiceHTLCParams) error
79

80
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
81
                sql.Result, error)
82

83
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
84

85
        // AMP sub invoice specific methods.
86
        UpsertAMPSubInvoice(ctx context.Context,
87
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
88

89
        // TODO(bhandras): remove this once migrations have been separated out.
90
        InsertAMPSubInvoice(ctx context.Context,
91
                arg sqlc.InsertAMPSubInvoiceParams) error
92

93
        UpdateAMPSubInvoiceState(ctx context.Context,
94
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
95

96
        InsertAMPSubInvoiceHTLC(ctx context.Context,
97
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
98

99
        FetchAMPSubInvoices(ctx context.Context,
100
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
101
                error)
102

103
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
104
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
105
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
106

107
        FetchSettledAMPSubInvoices(ctx context.Context,
108
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
109
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
110

111
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
112
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
113
                error)
114

115
        // Invoice events specific methods.
116
        OnInvoiceCreated(ctx context.Context,
117
                arg sqlc.OnInvoiceCreatedParams) error
118

119
        OnInvoiceCanceled(ctx context.Context,
120
                arg sqlc.OnInvoiceCanceledParams) error
121

122
        OnInvoiceSettled(ctx context.Context,
123
                arg sqlc.OnInvoiceSettledParams) error
124

125
        OnAMPSubInvoiceCreated(ctx context.Context,
126
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
127

128
        OnAMPSubInvoiceCanceled(ctx context.Context,
129
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
130

131
        OnAMPSubInvoiceSettled(ctx context.Context,
132
                arg sqlc.OnAMPSubInvoiceSettledParams) error
133

134
        // Migration specific methods.
135
        // TODO(bhandras): remove this once migrations have been separated out.
136
        InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
137
                arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
138

139
        SetKVInvoicePaymentHash(ctx context.Context,
140
                arg sqlc.SetKVInvoicePaymentHashParams) error
141

142
        GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
143
                []byte, error)
144

145
        ClearKVInvoiceHashIndex(ctx context.Context) error
146
}
147

148
var _ InvoiceDB = (*SQLStore)(nil)
149

150
// SQLInvoiceQueriesTxOptions carries the database transaction options that
// SQLInvoiceQueries understands. The zero value requests a transaction that
// is not read only.
type SQLInvoiceQueriesTxOptions struct {
	// readOnly is true when a read only transaction is sufficient.
	readOnly bool
}

// ReadOnly reports whether the transaction should be opened read only.
//
// NOTE: This implements the TxOptions.
func (a *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
	return a.readOnly
}

// NewSQLInvoiceQueryReadTx returns the option set describing a read only
// transaction.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
	var opts SQLInvoiceQueriesTxOptions
	opts.readOnly = true

	return opts
}
21,216✔
170

21,216✔
171
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
21,216✔
172
// of batched database operations.
21,216✔
173
type BatchedSQLInvoiceQueries interface {
21,216✔
174
        SQLInvoiceQueries
175

176
        sqldb.BatchedTx[SQLInvoiceQueries]
177
}
178

179
// SQLStore represents a storage backend.
180
type SQLStore struct {
181
        db    BatchedSQLInvoiceQueries
182
        clock clock.Clock
183
        opts  SQLStoreOptions
184
}
185

186
// SQLStoreOptions holds the options for the SQL store.
187
type SQLStoreOptions struct {
188
        paginationLimit int
189
}
190

191
// defaultSQLStoreOptions returns the default options for the SQL store.
192
func defaultSQLStoreOptions() SQLStoreOptions {
193
        return SQLStoreOptions{
194
                paginationLimit: defaultQueryPaginationLimit,
195
        }
196
}
512✔
197

512✔
198
// SQLStoreOption is a functional option that can be used to optionally modify
512✔
199
// the behavior of the SQL store.
512✔
200
type SQLStoreOption func(*SQLStoreOptions)
512✔
201

202
// WithPaginationLimit sets the pagination limit for the SQL store queries that
203
// paginate results.
204
func WithPaginationLimit(limit int) SQLStoreOption {
205
        return func(o *SQLStoreOptions) {
206
                o.paginationLimit = limit
207
        }
208
}
50✔
209

100✔
210
// NewSQLStore creates a new SQLStore instance given a open
50✔
211
// BatchedSQLInvoiceQueries storage backend.
50✔
212
func NewSQLStore(db BatchedSQLInvoiceQueries,
213
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
214

215
        opts := defaultSQLStoreOptions()
216
        for _, applyOption := range options {
217
                applyOption(&opts)
512✔
218
        }
512✔
219

512✔
220
        return &SQLStore{
562✔
221
                db:    db,
50✔
222
                clock: clock,
50✔
223
                opts:  opts,
224
        }
512✔
225
}
512✔
226

512✔
227
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
512✔
228
        sqlc.InsertInvoiceParams, error) {
512✔
229

230
        // Precompute the payment request hash so we can use it in the query.
231
        var paymentRequestHash []byte
232
        if len(invoice.PaymentRequest) > 0 {
20,596✔
233
                h := sha256.New()
20,596✔
234
                h.Write(invoice.PaymentRequest)
20,596✔
235
                paymentRequestHash = h.Sum(nil)
20,596✔
236
        }
40,793✔
237

20,197✔
238
        params := sqlc.InsertInvoiceParams{
20,197✔
239
                Hash:       paymentHash[:],
20,197✔
240
                AmountMsat: int64(invoice.Terms.Value),
20,197✔
241
                CltvDelta: sqldb.SQLInt32(
242
                        invoice.Terms.FinalCltvDelta,
20,596✔
243
                ),
20,596✔
244
                Expiry: int32(invoice.Terms.Expiry.Seconds()),
20,596✔
245
                // Note: keysend invoices don't have a payment request.
20,596✔
246
                PaymentRequest: sqldb.SQLStr(string(
20,596✔
247
                        invoice.PaymentRequest),
20,596✔
248
                ),
20,596✔
249
                PaymentRequestHash: paymentRequestHash,
20,596✔
250
                State:              int16(invoice.State),
20,596✔
251
                AmountPaidMsat:     int64(invoice.AmtPaid),
20,596✔
252
                IsAmp:              invoice.IsAMP(),
20,596✔
253
                IsHodl:             invoice.HodlInvoice,
20,596✔
254
                IsKeysend:          invoice.IsKeysend(),
20,596✔
255
                CreatedAt:          invoice.CreationDate.UTC(),
20,596✔
256
        }
20,596✔
257

20,596✔
258
        if invoice.Memo != nil {
20,596✔
259
                // Store the memo as a nullable string in the database. Note
20,596✔
260
                // that for compatibility reasons, we store the value as a valid
20,596✔
261
                // string even if it's empty.
20,596✔
262
                params.Memo = sql.NullString{
40,890✔
263
                        String: string(invoice.Memo),
20,294✔
264
                        Valid:  true,
20,294✔
265
                }
20,294✔
266
        }
20,294✔
267

20,294✔
268
        // Some invoices may not have a preimage, like in the case of HODL
20,294✔
269
        // invoices.
20,294✔
270
        if invoice.Terms.PaymentPreimage != nil {
20,294✔
271
                preimage := *invoice.Terms.PaymentPreimage
272
                if preimage == UnknownPreimage {
273
                        return sqlc.InsertInvoiceParams{},
274
                                errors.New("cannot use all-zeroes preimage")
41,162✔
275
                }
20,566✔
276
                params.Preimage = preimage[:]
20,566✔
UNCOV
277
        }
×
UNCOV
278

×
UNCOV
279
        // Some non MPP payments may have the default (invalid) value.
×
280
        if invoice.Terms.PaymentAddr != BlankPayAddr {
20,566✔
281
                params.PaymentAddr = invoice.Terms.PaymentAddr[:]
282
        }
283

284
        return params, nil
20,946✔
285
}
350✔
286

350✔
287
// AddInvoice inserts the targeted invoice into the database. If the invoice has
288
// *any* payment hashes which already exists within the database, then the
20,596✔
289
// insertion will be aborted and rejected due to the strict policy banning any
290
// duplicate payment hashes.
291
//
292
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
293
func (i *SQLStore) AddInvoice(ctx context.Context,
294
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
295

296
        // Make sure this is a valid invoice before trying to store it in our
297
        // DB.
298
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
596✔
299
                return 0, err
596✔
300
        }
596✔
301

596✔
302
        var (
600✔
303
                writeTxOpts SQLInvoiceQueriesTxOptions
4✔
304
                invoiceID   int64
4✔
305
        )
306

592✔
307
        insertInvoiceParams, err := makeInsertInvoiceParams(
592✔
308
                newInvoice, paymentHash,
592✔
309
        )
592✔
310
        if err != nil {
592✔
311
                return 0, err
592✔
312
        }
592✔
313

592✔
314
        err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
592✔
UNCOV
315
                var err error
×
UNCOV
316
                invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
×
317
                if err != nil {
318
                        return fmt.Errorf("unable to insert invoice: %w", err)
1,184✔
319
                }
592✔
320

592✔
321
                // TODO(positiveblue): if invocies do not have custom features
608✔
322
                // maybe just store the "invoice type" and populate the features
16✔
323
                // based on that.
16✔
324
                for feature := range newInvoice.Terms.Features.Features() {
325
                        params := sqlc.InsertInvoiceFeatureParams{
326
                                InvoiceID: invoiceID,
327
                                Feature:   int32(feature),
328
                        }
670✔
329

94✔
330
                        err := db.InsertInvoiceFeature(ctx, params)
94✔
331
                        if err != nil {
94✔
332
                                return fmt.Errorf("unable to insert invoice "+
94✔
333
                                        "feature(%v): %w", feature, err)
94✔
334
                        }
94✔
335
                }
94✔
UNCOV
336

×
UNCOV
337
                // Finally add a new event for this invoice.
×
UNCOV
338
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
×
339
                        AddedAt:   newInvoice.CreationDate.UTC(),
340
                        InvoiceID: invoiceID,
341
                })
342
        }, func() {})
576✔
343
        if err != nil {
576✔
344
                mappedSQLErr := sqldb.MapSQLError(err)
576✔
345
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
576✔
346
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
592✔
347
                        // Add context to unique constraint errors.
608✔
348
                        return 0, ErrDuplicateInvoice
16✔
349
                }
16✔
350

32✔
351
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
16✔
352
                        paymentHash, err)
16✔
353
        }
16✔
354

UNCOV
355
        newInvoice.AddIndex = uint64(invoiceID)
×
UNCOV
356

×
357
        return newInvoice.AddIndex, nil
358
}
359

576✔
360
// getInvoiceByRef fetches the invoice with the given reference. The reference
576✔
361
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
576✔
362
func getInvoiceByRef(ctx context.Context,
363
        db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
364

365
        // If the reference is empty, we can't look up the invoice.
366
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
367
                return sqlc.Invoice{}, ErrInvoiceNotFound
21,625✔
368
        }
21,625✔
369

21,625✔
370
        // If the reference is a hash only, we can look up the invoice directly
21,627✔
371
        // by the payment hash which is faster.
2✔
372
        if ref.IsHashOnly() {
2✔
373
                invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
374
                if errors.Is(err, sql.ErrNoRows) {
375
                        return sqlc.Invoice{}, ErrInvoiceNotFound
376
                }
42,224✔
377

20,601✔
378
                return invoice, err
20,633✔
379
        }
32✔
380

32✔
381
        // Otherwise the reference may include more fields, so we'll need to
382
        // assemble the query parameters based on the fields that are set.
20,569✔
383
        var params sqlc.GetInvoiceParams
384

385
        if ref.PayHash() != nil {
386
                params.Hash = ref.PayHash()[:]
387
        }
1,022✔
388

1,022✔
389
        // Newer invoices (0.11 and up) are indexed by payment address in
1,948✔
390
        // addition to payment hash, but pre 0.8 invoices do not have one at
926✔
391
        // all. Only allow lookups for payment address if it is not a blank
926✔
392
        // payment address, which is a special-cased value for legacy keysend
393
        // invoices.
394
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
395
                params.PaymentAddr = ref.PayAddr()[:]
396
        }
397

398
        // If the reference has a set ID we'll fetch the invoice which has the
1,158✔
399
        // corresponding AMP sub invoice.
136✔
400
        if ref.SetID() != nil {
136✔
401
                params.SetID = ref.SetID()[:]
402
        }
403

404
        var (
1,118✔
405
                rows []sqlc.Invoice
96✔
406
                err  error
96✔
407
        )
408

1,022✔
409
        // We need to split the query based on how we intend to look up the
1,022✔
410
        // invoice. If only the set ID is given then we want to have an exact
1,022✔
411
        // match on the set ID. If other fields are given, we want to match on
1,022✔
412
        // those fields and the set ID but with a less strict join condition.
1,022✔
413
        if params.Hash == nil && params.PaymentAddr == nil &&
1,022✔
414
                params.SetID != nil {
1,022✔
415

1,022✔
416
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
1,022✔
417
        } else {
1,022✔
418
                rows, err = db.GetInvoice(ctx, params)
1,046✔
419
        }
24✔
420

24✔
421
        switch {
1,022✔
422
        case len(rows) == 0:
998✔
423
                return sqlc.Invoice{}, ErrInvoiceNotFound
998✔
424

425
        case len(rows) > 1:
1,022✔
426
                // In case the reference is ambiguous, meaning it matches more
14✔
427
                // than        one invoice, we'll return an error.
14✔
428
                return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
429
                        "%s: %s", ref.String(), spew.Sdump(rows))
×
UNCOV
430

×
431
        case err != nil:
×
432
                return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
×
433
                        err)
×
434
        }
UNCOV
435

×
UNCOV
436
        return rows[0], nil
×
UNCOV
437
}
×
438

439
// fetchInvoice fetches the common invoice data and the AMP state for the
440
// invoice with the given reference.
1,008✔
441
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
442
        *Invoice, error) {
443

444
        // Fetch the invoice from the database.
445
        sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
446
        if err != nil {
21,625✔
447
                return nil, err
21,625✔
448
        }
21,625✔
449

21,625✔
450
        var (
21,673✔
451
                setID         *[32]byte
48✔
452
                fetchAmpHtlcs bool
48✔
453
        )
454

21,577✔
455
        // Now that we got the invoice itself, fetch the HTLCs as requested by
21,577✔
456
        // the modifier.
21,577✔
457
        switch ref.Modifier() {
21,577✔
458
        case DefaultModifier:
21,577✔
459
                // By default we'll fetch all AMP HTLCs.
21,577✔
460
                setID = nil
21,577✔
461
                fetchAmpHtlcs = true
21,577✔
462

21,485✔
463
        case HtlcSetOnlyModifier:
21,485✔
464
                // In this case we'll fetch all AMP HTLCs for the specified set
21,485✔
465
                // id.
21,485✔
466
                if ref.SetID() == nil {
467
                        return nil, fmt.Errorf("set ID is required to use " +
90✔
468
                                "the HTLC set only modifier")
90✔
469
                }
90✔
470

90✔
UNCOV
471
                setID = ref.SetID()
×
UNCOV
472
                fetchAmpHtlcs = true
×
UNCOV
473

×
474
        case HtlcSetBlankModifier:
475
                // No need to fetch any HTLCs.
90✔
476
                setID = nil
90✔
477
                fetchAmpHtlcs = false
478

2✔
479
        default:
2✔
480
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
2✔
481
                        ref.Modifier())
2✔
482
        }
UNCOV
483

×
UNCOV
484
        // Fetch the rest of the invoice data and fill the invoice struct.
×
UNCOV
485
        _, invoice, err := fetchInvoiceData(
×
486
                ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
487
        )
488
        if err != nil {
489
                return nil, err
21,577✔
490
        }
21,577✔
491

21,577✔
492
        return invoice, nil
21,577✔
UNCOV
493
}
×
UNCOV
494

×
495
// fetchAmpState fetches the AMP state for the invoice with the given ID.
496
// Optional setID can be provided to fetch the state for a specific AMP HTLC
21,577✔
497
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
498
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
499
// well.
500
//
501
//nolint:funlen
502
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
503
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
504
        HTLCSet, error) {
505

506
        var paramSetID []byte
507
        if setID != nil {
508
                paramSetID = setID[:]
9,973✔
509
        }
9,973✔
510

9,973✔
511
        // First fetch all the AMP sub invoices for this invoice or the one
10,067✔
512
        // matching the provided set ID.
94✔
513
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
94✔
514
                ctx, sqlc.FetchAMPSubInvoicesParams{
515
                        InvoiceID: invoiceID,
516
                        SetID:     paramSetID,
517
                },
9,973✔
518
        )
9,973✔
519
        if err != nil {
9,973✔
520
                return nil, nil, err
9,973✔
521
        }
9,973✔
522

9,973✔
523
        ampState := make(map[SetID]InvoiceStateAMP)
9,973✔
UNCOV
524
        for _, row := range ampInvoiceRows {
×
UNCOV
525
                var rowSetID [32]byte
×
526

527
                if len(row.SetID) != 32 {
9,973✔
528
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
37,948✔
529
                                len(row.SetID))
27,975✔
530
                }
27,975✔
531

27,975✔
UNCOV
532
                var settleDate time.Time
×
UNCOV
533
                if row.SettledAt.Valid {
×
UNCOV
534
                        settleDate = row.SettledAt.Time.Local()
×
535
                }
536

27,975✔
537
                copy(rowSetID[:], row.SetID)
35,642✔
538
                ampState[rowSetID] = InvoiceStateAMP{
7,667✔
539
                        State:       HtlcState(row.State),
7,667✔
540
                        SettleIndex: uint64(row.SettleIndex.Int64),
541
                        SettleDate:  settleDate,
27,975✔
542
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
27,975✔
543
                }
27,975✔
544
        }
27,975✔
545

27,975✔
546
        if !fetchHtlcs {
27,975✔
547
                return ampState, nil, nil
27,975✔
548
        }
549

550
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
9,975✔
551
        if err != nil {
2✔
552
                return nil, nil, fmt.Errorf("unable to get custom records for "+
2✔
553
                        "invoice HTLCs: %w", err)
554
        }
9,971✔
555

9,971✔
UNCOV
556
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
×
UNCOV
557
        for _, row := range customRecordRows {
×
UNCOV
558
                if _, ok := customRecords[row.HtlcID]; !ok {
×
559
                        customRecords[row.HtlcID] = make(record.CustomSet)
560
                }
9,971✔
561

182,153✔
562
                value := row.Value
234,971✔
563
                if value == nil {
62,789✔
564
                        value = []byte{}
62,789✔
565
                }
566

172,182✔
567
                customRecords[row.HtlcID][uint64(row.Key)] = value
172,182✔
UNCOV
568
        }
×
UNCOV
569

×
570
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
571
        // provided set ID.
172,182✔
572
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
573
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
574
                        InvoiceID: invoiceID,
575
                        SetID:     paramSetID,
576
                },
9,971✔
577
        )
9,971✔
578
        if err != nil {
9,971✔
579
                return nil, nil, err
9,971✔
580
        }
9,971✔
581

9,971✔
582
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
9,971✔
UNCOV
583
        for _, row := range ampHtlcRows {
×
UNCOV
584
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
585
                if err != nil {
586
                        return nil, nil, err
9,971✔
587
                }
89,348✔
588

79,377✔
589
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
79,377✔
UNCOV
590

×
UNCOV
591
                if row.HtlcID < 0 {
×
592
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
593
                                "value: %v", row.HtlcID)
79,377✔
594
                }
79,377✔
595

79,377✔
UNCOV
596
                htlcID := uint64(row.HtlcID)
×
UNCOV
597

×
UNCOV
598
                circuitKey := CircuitKey{
×
599
                        ChanID: chanID,
600
                        HtlcID: htlcID,
79,377✔
601
                }
79,377✔
602

79,377✔
603
                htlc := &InvoiceHTLC{
79,377✔
604
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
79,377✔
605
                        AcceptHeight: uint32(row.AcceptHeight),
79,377✔
606
                        AcceptTime:   row.AcceptTime.Local(),
79,377✔
607
                        Expiry:       uint32(row.ExpiryHeight),
79,377✔
608
                        State:        HtlcState(row.State),
79,377✔
609
                }
79,377✔
610

79,377✔
611
                if row.TotalMppMsat.Valid {
79,377✔
612
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
79,377✔
613
                                row.TotalMppMsat.Int64,
79,377✔
614
                        )
79,377✔
615
                }
79,491✔
616

114✔
617
                if row.ResolveTime.Valid {
114✔
618
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
114✔
619
                }
114✔
620

621
                var (
118,260✔
622
                        rootShare [32]byte
38,883✔
623
                        setID     [32]byte
38,883✔
624
                )
625

79,377✔
626
                if len(row.RootShare) != 32 {
79,377✔
627
                        return nil, nil, fmt.Errorf("invalid root share "+
79,377✔
628
                                "length: %d", len(row.RootShare))
79,377✔
629
                }
79,377✔
630
                copy(rootShare[:], row.RootShare)
79,377✔
UNCOV
631

×
UNCOV
632
                if len(row.SetID) != 32 {
×
633
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
634
                                len(row.SetID))
79,377✔
635
                }
79,377✔
636
                copy(setID[:], row.SetID)
79,377✔
UNCOV
637

×
UNCOV
638
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
×
639
                        return nil, nil, fmt.Errorf("invalid child index "+
×
640
                                "value: %v", row.ChildIndex)
79,377✔
641
                }
79,377✔
642

79,377✔
UNCOV
643
                ampRecord := record.NewAMP(
×
UNCOV
644
                        rootShare, setID, uint32(row.ChildIndex),
×
UNCOV
645
                )
×
646

647
                htlc.AMP = &InvoiceHtlcAMPData{
79,377✔
648
                        Record: *ampRecord,
79,377✔
649
                }
79,377✔
650

79,377✔
651
                if len(row.Hash) != 32 {
79,377✔
652
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
79,377✔
653
                                len(row.Hash))
79,377✔
654
                }
79,377✔
655
                copy(htlc.AMP.Hash[:], row.Hash)
79,377✔
UNCOV
656

×
UNCOV
657
                if row.Preimage != nil {
×
UNCOV
658
                        preimage, err := lntypes.MakePreimage(row.Preimage)
×
659
                        if err != nil {
79,377✔
660
                                return nil, nil, err
79,377✔
661
                        }
158,630✔
662

79,253✔
663
                        htlc.AMP.Preimage = &preimage
79,253✔
UNCOV
664
                }
×
UNCOV
665

×
666
                if _, ok := customRecords[row.ID]; ok {
667
                        htlc.CustomRecords = customRecords[row.ID]
79,253✔
668
                } else {
669
                        htlc.CustomRecords = make(record.CustomSet)
670
                }
142,166✔
671

62,789✔
672
                ampHtlcs[circuitKey] = htlc
79,377✔
673
        }
16,588✔
674

16,588✔
675
        if len(ampHtlcs) > 0 {
676
                for setID := range ampState {
79,377✔
677
                        var amtPaid lnwire.MilliSatoshi
678
                        invoiceKeys := make(
679
                                map[models.CircuitKey]struct{},
19,890✔
680
                        )
37,888✔
681

27,969✔
682
                        for key, htlc := range ampHtlcs {
27,969✔
683
                                if htlc.AMP.Record.SetID() != setID {
27,969✔
684
                                        continue
27,969✔
685
                                }
27,969✔
686

309,520✔
687
                                invoiceKeys[key] = struct{}{}
483,725✔
688

202,174✔
689
                                if htlc.State != HtlcStateCanceled {
690
                                        amtPaid += htlc.Amt
691
                                }
79,377✔
692
                        }
79,377✔
693

141,731✔
694
                        setState := ampState[setID]
62,354✔
695
                        setState.InvoiceKeys = invoiceKeys
62,354✔
696
                        setState.AmtPaid = amtPaid
697
                        ampState[setID] = setState
698
                }
27,969✔
699
        }
27,969✔
700

27,969✔
701
        return ampState, ampHtlcs, nil
27,969✔
702
}
703

704
// LookupInvoice attempts to look up an invoice corresponding the passed in
705
// reference. The reference may be a payment hash, a payment address, or a set
9,971✔
706
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
707
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
708
// error.
709
func (i *SQLStore) LookupInvoice(ctx context.Context,
710
        ref InvoiceRef) (Invoice, error) {
711

712
        var (
713
                invoice *Invoice
714
                err     error
20,852✔
715
        )
20,852✔
716

20,852✔
717
        readTxOpt := NewSQLInvoiceQueryReadTx()
20,852✔
718
        txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
20,852✔
719
                invoice, err = fetchInvoice(ctx, db, ref)
20,852✔
720

20,852✔
721
                return err
20,852✔
722
        }, func() {})
41,704✔
723
        if txErr != nil {
20,852✔
724
                return Invoice{}, txErr
20,852✔
725
        }
20,852✔
726

41,704✔
727
        return *invoice, nil
20,892✔
728
}
40✔
729

40✔
730
// FetchPendingInvoices returns all the invoices that are currently in a
731
// "pending" state. An invoice is pending if it has been created but not yet
20,812✔
732
// settled or canceled.
733
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
734
        map[lntypes.Hash]Invoice, error) {
735

736
        var invoices map[lntypes.Hash]Invoice
737

738
        readTxOpt := NewSQLInvoiceQueryReadTx()
262✔
739
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
262✔
740
                return queryWithLimit(func(offset int) (int, error) {
262✔
741
                        params := sqlc.FilterInvoicesParams{
262✔
742
                                PendingOnly: true,
262✔
743
                                NumOffset:   int32(offset),
524✔
744
                                NumLimit:    int32(i.opts.paginationLimit),
530✔
745
                                Reverse:     false,
268✔
746
                        }
268✔
747

268✔
748
                        rows, err := db.FilterInvoices(ctx, params)
268✔
749
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
268✔
750
                                return 0, fmt.Errorf("unable to get invoices "+
268✔
751
                                        "from db: %w", err)
268✔
752
                        }
268✔
753

268✔
UNCOV
754
                        // Load all the information for the invoices.
×
UNCOV
755
                        for _, row := range rows {
×
UNCOV
756
                                hash, invoice, err := fetchInvoiceData(
×
757
                                        ctx, db, row, nil, true,
758
                                )
759
                                if err != nil {
308✔
760
                                        return 0, err
40✔
761
                                }
40✔
762

40✔
763
                                invoices[*hash] = *invoice
40✔
UNCOV
764
                        }
×
UNCOV
765

×
766
                        return len(rows), nil
767
                }, i.opts.paginationLimit)
40✔
768
        }, func() {
769
                invoices = make(map[lntypes.Hash]Invoice)
770
        })
268✔
771
        if err != nil {
772
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
262✔
773
                        err)
262✔
774
        }
262✔
775

262✔
UNCOV
776
        return invoices, nil
×
UNCOV
777
}
×
UNCOV
778

×
779
// InvoicesSettledSince can be used by callers to catch up any settled invoices
780
// they missed within the settled invoice time series. We'll return all known
262✔
781
// settled invoice that have a settle index higher than the passed idx.
782
//
783
// NOTE: The index starts from 1. As a result we enforce that specifying a value
784
// below the starting index value is a noop.
785
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
786
        []Invoice, error) {
787

788
        var invoices []Invoice
789

790
        if idx == 0 {
42✔
791
                return invoices, nil
42✔
792
        }
42✔
793

42✔
794
        readTxOpt := NewSQLInvoiceQueryReadTx()
42✔
795
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
42✔
796
                err := queryWithLimit(func(offset int) (int, error) {
42✔
797
                        params := sqlc.FilterInvoicesParams{
42✔
798
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
42✔
799
                                NumOffset:      int32(offset),
78✔
800
                                NumLimit:       int32(i.opts.paginationLimit),
36✔
801
                                Reverse:        false,
36✔
802
                        }
803

6✔
804
                        rows, err := db.FilterInvoices(ctx, params)
12✔
805
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
18✔
806
                                return 0, fmt.Errorf("unable to get invoices "+
12✔
807
                                        "from db: %w", err)
12✔
808
                        }
12✔
809

12✔
810
                        // Load all the information for the invoices.
12✔
811
                        for _, row := range rows {
12✔
812
                                _, invoice, err := fetchInvoiceData(
12✔
813
                                        ctx, db, row, nil, true,
12✔
814
                                )
12✔
UNCOV
815
                                if err != nil {
×
816
                                        return 0, fmt.Errorf("unable to fetch "+
×
817
                                                "invoice(id=%d) from db: %w",
×
818
                                                row.ID, err)
819
                                }
820

30✔
821
                                invoices = append(invoices, *invoice)
18✔
822
                        }
18✔
823

18✔
824
                        return len(rows), nil
18✔
UNCOV
825
                }, i.opts.paginationLimit)
×
UNCOV
826
                if err != nil {
×
827
                        return err
×
828
                }
×
829

830
                // Now fetch all the AMP sub invoices that were settled since
18✔
831
                // the provided index.
18✔
832
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
18✔
833
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
18✔
834
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
18✔
UNCOV
835
                        },
×
UNCOV
836
                )
×
UNCOV
837
                if err != nil {
×
838
                        return err
×
839
                }
×
UNCOV
840

×
UNCOV
841
                for _, ampInvoice := range ampInvoices {
×
UNCOV
842
                        // Convert the row to a sqlc.Invoice so we can use the
×
843
                        // existing fetchInvoiceData function.
844
                        sqlInvoice := sqlc.Invoice{
845
                                ID:             ampInvoice.ID,
12✔
846
                                Hash:           ampInvoice.Hash,
847
                                Preimage:       ampInvoice.Preimage,
6✔
UNCOV
848
                                SettleIndex:    ampInvoice.AmpSettleIndex,
×
UNCOV
849
                                SettledAt:      ampInvoice.AmpSettledAt,
×
850
                                Memo:           ampInvoice.Memo,
851
                                AmountMsat:     ampInvoice.AmountMsat,
852
                                CltvDelta:      ampInvoice.CltvDelta,
853
                                Expiry:         ampInvoice.Expiry,
6✔
854
                                PaymentAddr:    ampInvoice.PaymentAddr,
6✔
855
                                PaymentRequest: ampInvoice.PaymentRequest,
6✔
856
                                State:          ampInvoice.State,
6✔
857
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
6✔
858
                                IsAmp:          ampInvoice.IsAmp,
6✔
UNCOV
859
                                IsHodl:         ampInvoice.IsHodl,
×
UNCOV
860
                                IsKeysend:      ampInvoice.IsKeysend,
×
861
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
862
                        }
10✔
863

4✔
864
                        // Fetch the state and HTLCs for this AMP sub invoice.
4✔
865
                        _, invoice, err := fetchInvoiceData(
4✔
866
                                ctx, db, sqlInvoice,
4✔
867
                                (*[32]byte)(ampInvoice.SetID), true,
4✔
868
                        )
4✔
869
                        if err != nil {
4✔
870
                                return fmt.Errorf("unable to fetch "+
4✔
871
                                        "AMP invoice(id=%d) from db: %w",
4✔
872
                                        ampInvoice.ID, err)
4✔
873
                        }
4✔
874

4✔
875
                        invoices = append(invoices, *invoice)
4✔
876
                }
4✔
877

4✔
878
                return nil
4✔
879
        }, func() {
4✔
880
                invoices = nil
4✔
881
        })
4✔
882
        if err != nil {
4✔
883
                return nil, fmt.Errorf("unable to get invoices settled since "+
4✔
884
                        "index (excluding) %d: %w", idx, err)
4✔
885
        }
4✔
886

4✔
887
        return invoices, nil
4✔
888
}
4✔
889

4✔
890
// InvoicesAddedSince can be used by callers to seek into the event time series
4✔
UNCOV
891
// of all the invoices added in the database. This method will return all
×
UNCOV
892
// invoices with an add index greater than the specified idx.
×
UNCOV
893
//
×
UNCOV
894
// NOTE: The index starts from 1. As a result we enforce that specifying a value
×
895
// below the starting index value is a noop.
896
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
4✔
897
        []Invoice, error) {
4✔
898

4✔
899
        var result []Invoice
4✔
900

4✔
UNCOV
901
        if idx == 0 {
×
UNCOV
902
                return result, nil
×
UNCOV
903
        }
×
UNCOV
904

×
UNCOV
905
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
906
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
907
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
908
                        params := sqlc.FilterInvoicesParams{
×
909
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
910
                                NumOffset:   int32(offset),
911
                                NumLimit:    int32(i.opts.paginationLimit),
6✔
912
                                Reverse:     false,
6✔
913
                        }
6✔
914

6✔
915
                        rows, err := db.FilterInvoices(ctx, params)
6✔
UNCOV
916
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
917
                                return 0, fmt.Errorf("unable to get invoices "+
×
918
                                        "from db: %w", err)
×
919
                        }
920

6✔
921
                        // Load all the information for the invoices.
6✔
922
                        for _, row := range rows {
6✔
923
                                _, invoice, err := fetchInvoiceData(
6✔
924
                                        ctx, db, row, nil, true,
6✔
925
                                )
6✔
926
                                if err != nil {
6✔
927
                                        return 0, err
928
                                }
929

930
                                result = append(result, *invoice)
931
                        }
932

933
                        return len(rows), nil
934
                }, i.opts.paginationLimit)
935
        }, func() {
936
                result = nil
40✔
937
        })
40✔
938

40✔
939
        if err != nil {
40✔
940
                return nil, fmt.Errorf("unable to get invoices added since "+
40✔
941
                        "index %d: %w", idx, err)
40✔
942
        }
40✔
943

40✔
944
        return result, nil
40✔
945
}
74✔
946

34✔
947
// QueryInvoices allows a caller to query the invoice database for invoices
34✔
948
// within the specified add index range.
949
func (i *SQLStore) QueryInvoices(ctx context.Context,
6✔
950
        q InvoiceQuery) (InvoiceSlice, error) {
12✔
951

30✔
952
        var invoices []Invoice
24✔
953

24✔
954
        if q.NumMaxInvoices == 0 {
24✔
955
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
24✔
956
                        "be non-zero")
24✔
957
        }
24✔
958

24✔
959
        readTxOpt := NewSQLInvoiceQueryReadTx()
24✔
960
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
24✔
UNCOV
961
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
962
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
963
                                NumOffset:   int32(offset),
×
964
                                NumLimit:    int32(i.opts.paginationLimit),
965
                                PendingOnly: q.PendingOnly,
966
                                Reverse:     q.Reversed,
82✔
967
                        }
58✔
968

58✔
969
                        if q.Reversed {
58✔
970
                                // If the index offset was not set, we want to
58✔
UNCOV
971
                                // fetch from the lastest invoice.
×
UNCOV
972
                                if q.IndexOffset == 0 {
×
973
                                        params.AddIndexLet = sqldb.SQLInt64(
974
                                                int64(math.MaxInt64),
58✔
975
                                        )
58✔
976
                                } else {
58✔
977
                                        // The invoice with index offset id must
58✔
978
                                        // not be included in the results.
58✔
UNCOV
979
                                        params.AddIndexLet = sqldb.SQLInt64(
×
UNCOV
980
                                                q.IndexOffset - 1,
×
UNCOV
981
                                        )
×
UNCOV
982
                                }
×
UNCOV
983
                        } else {
×
UNCOV
984
                                // The invoice with index offset id must not be
×
UNCOV
985
                                // included in the results.
×
986
                                params.AddIndexGet = sqldb.SQLInt64(
987
                                        q.IndexOffset + 1,
988
                                )
24✔
989
                        }
990

6✔
991
                        if q.CreationDateStart != 0 {
6✔
992
                                params.CreatedAfter = sqldb.SQLTime(
6✔
993
                                        time.Unix(q.CreationDateStart, 0).UTC(),
994
                                )
6✔
UNCOV
995
                        }
×
UNCOV
996

×
UNCOV
997
                        if q.CreationDateEnd != 0 {
×
998
                                // We need to add 1 to the end date as we're
999
                                // checking less than the end date in SQL.
6✔
1000
                                params.CreatedBefore = sqldb.SQLTime(
6✔
1001
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
6✔
1002
                                )
6✔
1003
                        }
6✔
1004

6✔
1005
                        rows, err := db.FilterInvoices(ctx, params)
6✔
1006
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
1007
                                return 0, fmt.Errorf("unable to get invoices "+
1008
                                        "from db: %w", err)
1009
                        }
1010

1011
                        // Load all the information for the invoices.
90✔
1012
                        for _, row := range rows {
90✔
1013
                                _, invoice, err := fetchInvoiceData(
90✔
1014
                                        ctx, db, row, nil, true,
90✔
1015
                                )
90✔
UNCOV
1016
                                if err != nil {
×
1017
                                        return 0, err
×
1018
                                }
×
1019

1020
                                invoices = append(invoices, *invoice)
90✔
1021

180✔
1022
                                if len(invoices) == int(q.NumMaxInvoices) {
486✔
1023
                                        return 0, nil
396✔
1024
                                }
396✔
1025
                        }
396✔
1026

396✔
1027
                        return len(rows), nil
396✔
1028
                }, i.opts.paginationLimit)
396✔
1029
        }, func() {
396✔
1030
                invoices = nil
510✔
1031
        })
114✔
1032
        if err != nil {
114✔
1033
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
162✔
1034
                        "invoices: %w", err)
48✔
1035
        }
48✔
1036

48✔
1037
        if len(invoices) == 0 {
114✔
1038
                return InvoiceSlice{
66✔
1039
                        InvoiceQuery: q,
66✔
1040
                }, nil
66✔
1041
        }
66✔
1042

66✔
1043
        // If we iterated through the add index in reverse order, then
66✔
1044
        // we'll need to reverse the slice of invoices to return them in
282✔
1045
        // forward order.
282✔
1046
        if q.Reversed {
282✔
1047
                numInvoices := len(invoices)
282✔
1048
                for i := 0; i < numInvoices/2; i++ {
282✔
1049
                        reverse := numInvoices - i - 1
282✔
1050
                        invoices[i], invoices[reverse] =
282✔
1051
                                invoices[reverse], invoices[i]
1052
                }
458✔
1053
        }
62✔
1054

62✔
1055
        res := InvoiceSlice{
62✔
1056
                InvoiceQuery:     q,
62✔
1057
                Invoices:         invoices,
1058
                FirstIndexOffset: invoices[0].AddIndex,
458✔
1059
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
62✔
1060
        }
62✔
1061

62✔
1062
        return res, nil
62✔
1063
}
62✔
1064

62✔
1065
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
1066
// a SQL database as the backend.
396✔
1067
type sqlInvoiceUpdater struct {
396✔
UNCOV
1068
        db         SQLInvoiceQueries
×
UNCOV
1069
        ctx        context.Context //nolint:containedctx
×
UNCOV
1070
        invoice    *Invoice
×
1071
        updateTime time.Time
1072
}
1073

1,418✔
1074
// AddHtlc adds a new htlc to the invoice.
1,022✔
1075
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
1,022✔
1076
        newHtlc *InvoiceHTLC) error {
1,022✔
1077

1,022✔
UNCOV
1078
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
×
UNCOV
1079
                s.ctx, sqlc.InsertInvoiceHTLCParams{
×
1080
                        HtlcID: int64(circuitKey.HtlcID),
1081
                        ChanID: strconv.FormatUint(
1,022✔
1082
                                circuitKey.ChanID.ToUint64(), 10,
1,022✔
1083
                        ),
1,050✔
1084
                        AmountMsat: int64(newHtlc.Amt),
28✔
1085
                        TotalMppMsat: sql.NullInt64{
28✔
1086
                                Int64: int64(newHtlc.MppTotalAmt),
1087
                                Valid: newHtlc.MppTotalAmt != 0,
1088
                        },
368✔
1089
                        AcceptHeight: int32(newHtlc.AcceptHeight),
1090
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
90✔
1091
                        ExpiryHeight: int32(newHtlc.Expiry),
90✔
1092
                        State:        int16(newHtlc.State),
90✔
1093
                        InvoiceID:    int64(s.invoice.AddIndex),
90✔
UNCOV
1094
                },
×
UNCOV
1095
        )
×
UNCOV
1096
        if err != nil {
×
1097
                return err
1098
        }
102✔
1099

12✔
1100
        for key, value := range newHtlc.CustomRecords {
12✔
1101
                err = s.db.InsertInvoiceHTLCCustomRecord(
12✔
1102
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
12✔
1103
                                // TODO(bhandras): schema might be wrong here
1104
                                // as the custom record key is an uint64.
1105
                                Key:    int64(key),
1106
                                Value:  value,
1107
                                HtlcID: htlcPrimaryKeyID,
102✔
1108
                        },
24✔
1109
                )
162✔
1110
                if err != nil {
138✔
1111
                        return err
138✔
1112
                }
138✔
1113
        }
138✔
1114

1115
        if newHtlc.AMP != nil {
1116
                setID := newHtlc.AMP.Record.SetID()
78✔
1117

78✔
1118
                upsertResult, err := s.db.UpsertAMPSubInvoice(
78✔
1119
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
78✔
1120
                                SetID:     setID[:],
78✔
1121
                                CreatedAt: s.updateTime.UTC(),
78✔
1122
                                InvoiceID: int64(s.invoice.AddIndex),
78✔
1123
                        },
78✔
1124
                )
1125
                if err != nil {
1126
                        mappedSQLErr := sqldb.MapSQLError(err)
1127
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
1128
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
1129
                                return ErrDuplicateSetID{
1130
                                        SetID: setID,
1131
                                }
1132
                        }
1133

1134
                        return err
1135
                }
1136

1137
                // If we're just inserting the AMP invoice, we'll get a non
598✔
1138
                // zero rows affected count.
598✔
1139
                rowsAffected, err := upsertResult.RowsAffected()
598✔
1140
                if err != nil {
598✔
1141
                        return err
598✔
1142
                }
598✔
1143
                if rowsAffected != 0 {
598✔
1144
                        // If we're inserting a new AMP invoice, we'll also
598✔
1145
                        // insert a new invoice event.
598✔
1146
                        err = s.db.OnAMPSubInvoiceCreated(
598✔
1147
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
598✔
1148
                                        AddedAt:   s.updateTime.UTC(),
598✔
1149
                                        InvoiceID: int64(s.invoice.AddIndex),
598✔
1150
                                        SetID:     setID[:],
598✔
1151
                                },
598✔
1152
                        )
598✔
1153
                        if err != nil {
598✔
1154
                                return err
598✔
1155
                        }
598✔
1156
                }
598✔
1157

598✔
UNCOV
1158
                rootShare := newHtlc.AMP.Record.RootShare()
×
UNCOV
1159

×
1160
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
1161
                        InvoiceID: int64(s.invoice.AddIndex),
610✔
1162
                        SetID:     setID[:],
12✔
1163
                        HtlcID:    htlcPrimaryKeyID,
12✔
1164
                        RootShare: rootShare[:],
12✔
1165
                        ChildIndex: int64(
12✔
1166
                                newHtlc.AMP.Record.ChildIndex(),
12✔
1167
                        ),
12✔
1168
                        Hash: newHtlc.AMP.Hash[:],
12✔
1169
                }
12✔
1170

12✔
1171
                if newHtlc.AMP.Preimage != nil {
12✔
UNCOV
1172
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
×
UNCOV
1173
                }
×
1174

1175
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
1176
                if err != nil {
646✔
1177
                        return err
48✔
1178
                }
48✔
1179
        }
48✔
1180

48✔
1181
        return nil
48✔
1182
}
48✔
1183

48✔
1184
// ResolveHtlc marks an htlc as resolved with the given state.
48✔
1185
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
48✔
1186
        state HtlcState, resolveTime time.Time) error {
50✔
1187

2✔
1188
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
2✔
1189
                HtlcID: int64(circuitKey.HtlcID),
4✔
1190
                ChanID: strconv.FormatUint(
2✔
1191
                        circuitKey.ChanID.ToUint64(), 10,
2✔
1192
                ),
2✔
1193
                InvoiceID:   int64(s.invoice.AddIndex),
2✔
1194
                State:       int16(state),
UNCOV
1195
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
×
1196
        })
1197
}
1198

1199
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1200
// identified by the setID.
46✔
1201
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
46✔
UNCOV
1202
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
×
UNCOV
1203

×
1204
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
78✔
1205
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
32✔
1206
                        InvoiceID: int64(s.invoice.AddIndex),
32✔
1207
                        SetID:     setID[:],
32✔
1208
                        HtlcID:    int64(circuitKey.HtlcID),
32✔
1209
                        Preimage:  preimage[:],
32✔
1210
                        ChanID: strconv.FormatUint(
32✔
1211
                                circuitKey.ChanID.ToUint64(), 10,
32✔
1212
                        ),
32✔
1213
                },
32✔
1214
        )
32✔
UNCOV
1215
        if err != nil {
×
1216
                return err
×
1217
        }
1218

1219
        rowsAffected, err := result.RowsAffected()
46✔
1220
        if err != nil {
46✔
1221
                return err
46✔
1222
        }
46✔
1223
        if rowsAffected == 0 {
46✔
1224
                return ErrInvoiceNotFound
46✔
1225
        }
46✔
1226

46✔
1227
        return nil
46✔
1228
}
46✔
1229

46✔
1230
// UpdateInvoiceState updates the invoice state to the new state.
46✔
1231
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
46✔
1232
        newState ContractState, preimage *lntypes.Preimage) error {
66✔
1233

20✔
1234
        var (
20✔
1235
                settleIndex sql.NullInt64
1236
                settledAt   sql.NullTime
46✔
1237
        )
46✔
UNCOV
1238

×
UNCOV
1239
        switch newState {
×
1240
        case ContractSettled:
1241
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
1242
                if err != nil {
596✔
1243
                        return err
1244
                }
1245

1246
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
1247

586✔
1248
                // If the invoice is settled, we'll also update the settle time.
586✔
1249
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
586✔
1250

586✔
1251
                err = s.db.OnInvoiceSettled(
586✔
1252
                        s.ctx, sqlc.OnInvoiceSettledParams{
586✔
1253
                                AddedAt:   s.updateTime.UTC(),
586✔
1254
                                InvoiceID: int64(s.invoice.AddIndex),
586✔
1255
                        },
586✔
1256
                )
586✔
1257
                if err != nil {
586✔
1258
                        return err
586✔
1259
                }
1260

1261
        case ContractCanceled:
1262
                err := s.db.OnInvoiceCanceled(
1263
                        s.ctx, sqlc.OnInvoiceCanceledParams{
12✔
1264
                                AddedAt:   s.updateTime.UTC(),
12✔
1265
                                InvoiceID: int64(s.invoice.AddIndex),
12✔
1266
                        },
12✔
1267
                )
12✔
1268
                if err != nil {
12✔
1269
                        return err
12✔
1270
                }
12✔
1271
        }
12✔
1272

12✔
1273
        params := sqlc.UpdateInvoiceStateParams{
12✔
1274
                ID:          int64(s.invoice.AddIndex),
12✔
1275
                State:       int16(newState),
12✔
1276
                SettleIndex: settleIndex,
12✔
UNCOV
1277
                SettledAt:   settledAt,
×
UNCOV
1278
        }
×
1279

1280
        if preimage != nil {
12✔
1281
                params.Preimage = preimage[:]
12✔
UNCOV
1282
        }
×
UNCOV
1283

×
1284
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
12✔
UNCOV
1285
        if err != nil {
×
UNCOV
1286
                return err
×
1287
        }
1288
        rowsAffected, err := result.RowsAffected()
12✔
1289
        if err != nil {
1290
                return err
1291
        }
1292

1293
        if rowsAffected == 0 {
399✔
1294
                return ErrInvoiceNotFound
399✔
1295
        }
399✔
1296

399✔
1297
        if settleIndex.Valid {
399✔
1298
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
399✔
1299
                s.invoice.SettleDate = s.updateTime
399✔
1300
        }
399✔
1301

316✔
1302
        return nil
316✔
1303
}
316✔
UNCOV
1304

×
UNCOV
1305
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
×
1306
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
1307
        amtPaid lnwire.MilliSatoshi) error {
316✔
1308

316✔
1309
        _, err := s.db.UpdateInvoiceAmountPaid(
316✔
1310
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
316✔
1311
                        ID:             int64(s.invoice.AddIndex),
316✔
1312
                        AmountPaidMsat: int64(amtPaid),
316✔
1313
                },
316✔
1314
        )
316✔
1315

316✔
1316
        return err
316✔
1317
}
316✔
1318

316✔
UNCOV
1319
// UpdateAmpState updates the state of the AMP sub invoice identified by the
×
UNCOV
1320
// setID.
×
1321
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
1322
        newState InvoiceStateAMP, _ models.CircuitKey) error {
65✔
1323

65✔
1324
        var (
65✔
1325
                settleIndex sql.NullInt64
65✔
1326
                settledAt   sql.NullTime
65✔
1327
        )
65✔
1328

65✔
1329
        switch newState.State {
65✔
UNCOV
1330
        case HtlcStateSettled:
×
UNCOV
1331
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
1332
                if err != nil {
1333
                        return err
1334
                }
399✔
1335

399✔
1336
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
399✔
1337

399✔
1338
                // If the invoice is settled, we'll also update the settle time.
399✔
1339
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
399✔
1340

399✔
1341
                err = s.db.OnAMPSubInvoiceSettled(
407✔
1342
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
8✔
1343
                                AddedAt:   s.updateTime.UTC(),
8✔
1344
                                InvoiceID: int64(s.invoice.AddIndex),
1345
                                SetID:     setID[:],
399✔
1346
                        },
400✔
1347
                )
1✔
1348
                if err != nil {
1✔
1349
                        return err
398✔
1350
                }
398✔
UNCOV
1351

×
UNCOV
1352
        case HtlcStateCanceled:
×
1353
                err := s.db.OnAMPSubInvoiceCanceled(
1354
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
398✔
UNCOV
1355
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1356
                                InvoiceID: int64(s.invoice.AddIndex),
×
1357
                                SetID:     setID[:],
1358
                        },
714✔
1359
                )
316✔
1360
                if err != nil {
316✔
1361
                        return err
316✔
1362
                }
1363
        }
398✔
1364

1365
        err := s.db.UpdateAMPSubInvoiceState(
1366
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
1367
                        SetID:       setID[:],
1368
                        State:       int16(newState.State),
634✔
1369
                        SettleIndex: settleIndex,
634✔
1370
                        SettledAt:   settledAt,
634✔
1371
                },
634✔
1372
        )
634✔
1373
        if err != nil {
634✔
1374
                return err
634✔
1375
        }
634✔
1376

634✔
1377
        if settleIndex.Valid {
634✔
1378
                updatedState := s.invoice.AMPState[setID]
634✔
1379
                updatedState.SettleIndex = uint64(settleIndex.Int64)
1380
                updatedState.SettleDate = s.updateTime.UTC()
1381
                s.invoice.AMPState[setID] = updatedState
1382
        }
1383

88✔
1384
        return nil
88✔
1385
}
88✔
1386

88✔
1387
// Finalize finalizes the update before it is written to the database. Note that
88✔
1388
// we don't use this directly in the SQL implementation, so the function is just
88✔
1389
// a stub.
88✔
1390
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
88✔
1391
        return nil
22✔
1392
}
22✔
1393

22✔
UNCOV
1394
// UpdateInvoice attempts to update an invoice corresponding to the passed
×
UNCOV
1395
// reference. If an invoice matching the passed reference doesn't exist within
×
1396
// the database, then the action will fail with  ErrInvoiceNotFound error.
1397
//
22✔
1398
// The update is performed inside the same database transaction that fetches the
22✔
1399
// invoice and is therefore atomic. The fields to update are controlled by the
22✔
1400
// supplied callback.
22✔
1401
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
22✔
1402
        setID *SetID, callback InvoiceUpdateCallback) (
22✔
1403
        *Invoice, error) {
22✔
1404

22✔
1405
        var updatedInvoice *Invoice
22✔
1406

22✔
1407
        txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
22✔
1408
        txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
22✔
1409
                switch {
22✔
UNCOV
1410
                // For the default case we fetch all HTLCs.
×
UNCOV
1411
                case setID == nil:
×
1412
                        ref.refModifier = DefaultModifier
1413

22✔
1414
                // If the setID is the blank but NOT nil, we set the
22✔
1415
                // refModifier to HtlcSetBlankModifier to fetch no HTLC for the
22✔
1416
                // AMP invoice.
22✔
1417
                case *setID == BlankPayAddr:
22✔
1418
                        ref.refModifier = HtlcSetBlankModifier
22✔
1419

22✔
1420
                // A setID is provided, we use the refModifier to fetch only
22✔
1421
                // the HTLCs for the given setID and also make sure we add the
22✔
UNCOV
1422
                // setID to the ref.
×
UNCOV
1423
                default:
×
1424
                        var setIDBytes [32]byte
1425
                        copy(setIDBytes[:], setID[:])
1426
                        ref.setID = &setIDBytes
88✔
1427

88✔
1428
                        // We only fetch the HTLCs for the given setID.
88✔
1429
                        ref.refModifier = HtlcSetOnlyModifier
88✔
1430
                }
88✔
1431

88✔
1432
                invoice, err := fetchInvoice(ctx, db, ref)
88✔
1433
                if err != nil {
88✔
1434
                        return err
88✔
UNCOV
1435
                }
×
UNCOV
1436

×
1437
                updateTime := i.clock.Now()
1438
                updater := &sqlInvoiceUpdater{
110✔
1439
                        db:         db,
22✔
1440
                        ctx:        ctx,
22✔
1441
                        invoice:    invoice,
22✔
1442
                        updateTime: updateTime,
22✔
1443
                }
22✔
1444

1445
                payHash := ref.PayHash()
88✔
1446
                updatedInvoice, err = UpdateInvoice(
1447
                        payHash, invoice, updateTime, callback, updater,
1448
                )
1449

1450
                return err
1451
        }, func() {})
702✔
1452
        if txErr != nil {
702✔
1453
                // If the invoice is already settled, we'll return the
702✔
1454
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
1455
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
1456
                        return updatedInvoice, txErr
1457
                }
1458

1459
                return nil, txErr
1460
        }
1461

1462
        return updatedInvoice, nil
1463
}
1464

768✔
1465
// DeleteInvoice attempts to delete the passed invoices and all their related
768✔
1466
// data from the database in one transaction.
768✔
1467
func (i *SQLStore) DeleteInvoice(ctx context.Context,
768✔
1468
        invoicesToDelete []InvoiceDeleteRef) error {
768✔
1469

1,537✔
1470
        // All the InvoiceDeleteRef instances include the add index of the
769✔
1471
        // invoice. The rest was added to ensure that the invoices were deleted
1472
        // properly in the kv database. When we have fully migrated we can
687✔
1473
        // remove the rest of the fields.
687✔
1474
        for _, ref := range invoicesToDelete {
1475
                if ref.AddIndex == 0 {
1476
                        return fmt.Errorf("unable to delete invoice using a "+
1477
                                "ref without AddIndex set: %v", ref)
1478
                }
×
UNCOV
1479
        }
×
1480

1481
        var writeTxOpt SQLInvoiceQueriesTxOptions
1482
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
1483
                for _, ref := range invoicesToDelete {
1484
                        params := sqlc.DeleteInvoiceParams{
82✔
1485
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
82✔
1486
                        }
82✔
1487

82✔
1488
                        if ref.SettleIndex != 0 {
82✔
1489
                                params.SettleIndex = sqldb.SQLInt64(
82✔
1490
                                        ref.SettleIndex,
82✔
1491
                                )
1492
                        }
1493

769✔
1494
                        if ref.PayHash != lntypes.ZeroHash {
777✔
1495
                                params.Hash = ref.PayHash[:]
8✔
1496
                        }
8✔
1497

1498
                        result, err := db.DeleteInvoice(ctx, params)
761✔
1499
                        if err != nil {
761✔
1500
                                return fmt.Errorf("unable to delete "+
761✔
1501
                                        "invoice(%v): %w", ref.AddIndex, err)
761✔
1502
                        }
761✔
1503
                        rowsAffected, err := result.RowsAffected()
761✔
1504
                        if err != nil {
761✔
1505
                                return fmt.Errorf("unable to get rows "+
761✔
1506
                                        "affected: %w", err)
761✔
1507
                        }
761✔
1508
                        if rowsAffected == 0 {
761✔
1509
                                return fmt.Errorf("%w: %v",
761✔
1510
                                        ErrInvoiceNotFound, ref.AddIndex)
761✔
1511
                        }
761✔
1512
                }
769✔
1513

800✔
1514
                return nil
32✔
1515
        }, func() {})
32✔
1516

44✔
1517
        if err != nil {
12✔
1518
                return fmt.Errorf("unable to delete invoices: %w", err)
12✔
1519
        }
1520

20✔
1521
        return nil
1522
}
1523

736✔
1524
// DeleteCanceledInvoices removes all canceled invoices from the database.
1525
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
1526
        var writeTxOpt SQLInvoiceQueriesTxOptions
1527
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
1528
                _, err := db.DeleteCanceledInvoices(ctx)
1529
                if err != nil {
12✔
1530
                        return fmt.Errorf("unable to delete canceled "+
12✔
1531
                                "invoices: %w", err)
12✔
1532
                }
12✔
1533

12✔
1534
                return nil
12✔
1535
        }, func() {})
40✔
1536
        if err != nil {
28✔
1537
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1538
        }
×
UNCOV
1539

×
1540
        return nil
1541
}
1542

12✔
1543
// fetchInvoiceData fetches additional data for the given invoice. If the
24✔
1544
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
34✔
1545
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
22✔
1546
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
22✔
1547
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
22✔
1548
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
22✔
1549
        *Invoice, error) {
28✔
1550

6✔
1551
        // Unmarshal the common data.
6✔
1552
        hash, invoice, err := unmarshalInvoice(row)
6✔
1553
        if err != nil {
6✔
1554
                return nil, nil, fmt.Errorf("unable to unmarshal "+
1555
                        "invoice(id=%d) from db: %w", row.ID, err)
44✔
1556
        }
22✔
1557

22✔
1558
        // Fetch the invoice features.
1559
        features, err := getInvoiceFeatures(ctx, db, row.ID)
22✔
1560
        if err != nil {
22✔
1561
                return nil, nil, err
×
1562
        }
×
UNCOV
1563

×
1564
        invoice.Terms.Features = features
22✔
1565

22✔
UNCOV
1566
        // If this is an AMP invoice, we'll need fetch the AMP state along
×
UNCOV
1567
        // with the HTLCs (if requested).
×
UNCOV
1568
        if invoice.IsAMP() {
×
1569
                invoiceID := int64(invoice.AddIndex)
28✔
1570
                ampState, ampHtlcs, err := fetchAmpState(
6✔
1571
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
6✔
1572
                )
6✔
1573
                if err != nil {
1574
                        return nil, nil, err
1575
                }
6✔
1576

12✔
1577
                invoice.AMPState = ampState
1578
                invoice.Htlcs = ampHtlcs
18✔
1579

6✔
1580
                return hash, invoice, nil
6✔
1581
        }
1582

6✔
1583
        // Otherwise simply fetch the invoice HTLCs.
1584
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
1585
        if err != nil {
1586
                return nil, nil, err
6✔
1587
        }
6✔
1588

12✔
1589
        if len(htlcs) > 0 {
6✔
1590
                invoice.Htlcs = htlcs
6✔
UNCOV
1591
        }
×
UNCOV
1592

×
UNCOV
1593
        return hash, invoice, nil
×
1594
}
1595

6✔
1596
// getInvoiceFeatures fetches the invoice features for the given invoice id.
6✔
1597
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
6✔
UNCOV
1598
        invoiceID int64) (*lnwire.FeatureVector, error) {
×
UNCOV
1599

×
1600
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
1601
        if err != nil {
6✔
1602
                return nil, fmt.Errorf("unable to get invoice features: %w",
1603
                        err)
1604
        }
1605

1606
        features := lnwire.EmptyFeatureVector()
1607
        for _, feature := range rows {
1608
                features.Set(lnwire.FeatureBit(feature.Feature))
1609
        }
1610

22,719✔
1611
        return features, nil
22,719✔
1612
}
22,719✔
1613

22,719✔
1614
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
22,719✔
UNCOV
1615
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
×
UNCOV
1616
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
×
UNCOV
1617

×
1618
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
1619
        if err != nil {
1620
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
22,719✔
1621
        }
22,719✔
UNCOV
1622

×
UNCOV
1623
        // We have no htlcs to unmarshal.
×
1624
        if len(htlcRows) == 0 {
1625
                return nil, nil
22,719✔
1626
        }
22,719✔
1627

22,719✔
1628
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
22,719✔
1629
        if err != nil {
32,692✔
1630
                return nil, fmt.Errorf("unable to get custom records for "+
9,973✔
1631
                        "invoice htlcs: %w", err)
9,973✔
1632
        }
9,973✔
1633

9,973✔
1634
        cr := make(map[int64]record.CustomSet, len(crRows))
9,973✔
UNCOV
1635
        for _, row := range crRows {
×
UNCOV
1636
                if _, ok := cr[row.HtlcID]; !ok {
×
1637
                        cr[row.HtlcID] = make(record.CustomSet)
1638
                }
9,973✔
1639

9,973✔
1640
                value := row.Value
9,973✔
1641
                if value == nil {
9,973✔
1642
                        value = []byte{}
1643
                }
1644
                cr[row.HtlcID][uint64(row.Key)] = value
1645
        }
12,746✔
1646

12,746✔
UNCOV
1647
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
×
UNCOV
1648

×
1649
        for _, row := range htlcRows {
1650
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
21,329✔
1651
                if err != nil {
8,583✔
1652
                        return nil, fmt.Errorf("unable to unmarshal "+
8,583✔
1653
                                "htlc(%d): %w", row.ID, err)
1654
                }
12,746✔
1655

1656
                if customRecords, ok := cr[row.ID]; ok {
1657
                        htlc.CustomRecords = customRecords
1658
                } else {
1659
                        htlc.CustomRecords = make(record.CustomSet)
22,719✔
1660
                }
22,719✔
1661

22,719✔
1662
                htlcs[circuiteKey] = htlc
22,719✔
UNCOV
1663
        }
×
UNCOV
1664

×
UNCOV
1665
        return htlcs, nil
×
1666
}
1667

22,719✔
1668
// unmarshalInvoice converts an InvoiceRow to an Invoice.
38,878✔
1669
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
16,159✔
1670
        error) {
16,159✔
1671

1672
        var (
22,719✔
1673
                settleIndex    int64
1674
                settledAt      time.Time
1675
                memo           []byte
1676
                paymentRequest []byte
1677
                preimage       *lntypes.Preimage
12,746✔
1678
                paymentAddr    [32]byte
12,746✔
1679
        )
12,746✔
1680

12,746✔
UNCOV
1681
        hash, err := lntypes.MakeHash(row.Hash)
×
UNCOV
1682
        if err != nil {
×
1683
                return nil, nil, err
1684
        }
1685

16,909✔
1686
        if row.SettleIndex.Valid {
4,163✔
1687
                settleIndex = row.SettleIndex.Int64
4,163✔
1688
        }
1689

8,583✔
1690
        if row.SettledAt.Valid {
8,583✔
UNCOV
1691
                settledAt = row.SettledAt.Time.Local()
×
UNCOV
1692
        }
×
UNCOV
1693

×
1694
        if row.Memo.Valid {
1695
                memo = []byte(row.Memo.String)
8,583✔
1696
        }
58,599✔
1697

68,408✔
1698
        // Keysend payments will have this field empty.
18,392✔
1699
        if row.PaymentRequest.Valid {
18,392✔
1700
                paymentRequest = []byte(row.PaymentRequest.String)
1701
        } else {
50,016✔
1702
                paymentRequest = []byte{}
50,017✔
1703
        }
1✔
1704

1✔
1705
        // We may not have the preimage if this a hodl invoice.
50,016✔
1706
        if row.Preimage != nil {
1707
                preimage = &lntypes.Preimage{}
1708
                copy(preimage[:], row.Preimage)
8,583✔
1709
        }
8,583✔
1710

33,233✔
1711
        copy(paymentAddr[:], row.PaymentAddr)
24,650✔
1712

24,650✔
UNCOV
1713
        var cltvDelta int32
×
UNCOV
1714
        if row.CltvDelta.Valid {
×
UNCOV
1715
                cltvDelta = row.CltvDelta.Int32
×
1716
        }
1717

43,042✔
1718
        expiry := time.Duration(row.Expiry) * time.Second
18,392✔
1719

24,650✔
1720
        invoice := &Invoice{
6,258✔
1721
                SettleIndex:    uint64(settleIndex),
6,258✔
1722
                SettleDate:     settledAt,
1723
                Memo:           memo,
24,650✔
1724
                PaymentRequest: paymentRequest,
1725
                CreationDate:   row.CreatedAt.Local(),
1726
                Terms: ContractTerm{
8,583✔
1727
                        FinalCltvDelta:  cltvDelta,
1728
                        Expiry:          expiry,
1729
                        PaymentPreimage: preimage,
1730
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
1731
                        PaymentAddr:     paymentAddr,
22,719✔
1732
                },
22,719✔
1733
                AddIndex:    uint64(row.ID),
22,719✔
1734
                State:       ContractState(row.State),
22,719✔
1735
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
22,719✔
1736
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
22,719✔
1737
                AMPState:    AMPInvoiceState{},
22,719✔
1738
                HodlInvoice: row.IsHodl,
22,719✔
1739
        }
22,719✔
1740

22,719✔
1741
        return &hash, invoice, nil
22,719✔
1742
}
22,719✔
1743

22,719✔
UNCOV
1744
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
×
UNCOV
1745
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
×
1746
        *InvoiceHTLC, error) {
1747

28,761✔
1748
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
6,042✔
1749
        if err != nil {
6,042✔
1750
                return CircuitKey{}, nil, err
1751
        }
28,761✔
1752

6,042✔
1753
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
6,042✔
1754

1755
        if row.HtlcID < 0 {
44,037✔
1756
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
21,318✔
1757
                        "value: %v", row.HtlcID)
21,318✔
1758
        }
1759

1760
        htlcID := uint64(row.HtlcID)
43,566✔
1761

20,847✔
1762
        circuitKey := CircuitKey{
22,719✔
1763
                ChanID: chanID,
1,872✔
1764
                HtlcID: htlcID,
1,872✔
1765
        }
1766

1767
        htlc := &InvoiceHTLC{
45,301✔
1768
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
22,582✔
1769
                AcceptHeight: uint32(row.AcceptHeight),
22,582✔
1770
                AcceptTime:   row.AcceptTime.Local(),
22,582✔
1771
                Expiry:       uint32(row.ExpiryHeight),
1772
                State:        HtlcState(row.State),
22,719✔
1773
        }
22,719✔
1774

22,719✔
1775
        if row.TotalMppMsat.Valid {
45,438✔
1776
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
22,719✔
1777
        }
22,719✔
1778

1779
        if row.ResolveTime.Valid {
22,719✔
1780
                htlc.ResolveTime = row.ResolveTime.Time.Local()
22,719✔
1781
        }
22,719✔
1782

22,719✔
1783
        return circuitKey, htlc, nil
22,719✔
1784
}
22,719✔
1785

22,719✔
1786
// queryWithLimit is a helper that pages through a result set. It repeatedly
// invokes query with an increasing offset (0, limit, 2*limit, ...) until a
// call reports fewer than limit rows, which signals the last page.
//
// The passed query function receives the offset to start from and should
// return the number of rows it returned and an error if any. The first error
// from query stops the iteration and is returned unchanged.
//
// The limit must be positive. A non-positive limit is rejected with an error
// before query is called.
func queryWithLimit(query func(int) (int, error), limit int) error {
	// With a non-positive limit the row count could never drop below the
	// limit and the offset would never advance, so the loop below would
	// re-run the same query forever.
	if limit <= 0 {
		return fmt.Errorf("invalid query pagination limit: %d", limit)
	}

	for offset := 0; ; offset += limit {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// A page that is not full means there is nothing left to
		// fetch.
		if rows < limit {
			return nil
		}
	}
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc