• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13536249039

26 Feb 2025 03:42AM UTC coverage: 57.462% (-1.4%) from 58.835%
13536249039

Pull #8453

github

Roasbeef
peer: update chooseDeliveryScript to gen script if needed

In this commit, we update `chooseDeliveryScript` to generate a new
script if needed. This allows us to fold in a few other lines that
always followed this function into this expanded function.

The tests have been updated accordingly.
Pull Request #8453: [4/4] - multi: integrate new rbf coop close FSM into the existing peer flow

275 of 1318 new or added lines in 22 files covered. (20.86%)

19521 existing lines in 257 files now uncovered.

103858 of 180741 relevant lines covered (57.46%)

24750.23 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.96
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
// defaultQueryPaginationLimit is the page size used by default for the LIMIT
// clause of the SQL queries that paginate their results.
const defaultQueryPaginationLimit = 100
28

29
// SQLInvoiceQueries is an interface that defines the set of operations that can
30
// be executed against the invoice SQL database.
31
type SQLInvoiceQueries interface { //nolint:interfacebloat
32
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
33
                error)
34

35
        // TODO(bhandras): remove this once migrations have been separated out.
36
        InsertMigratedInvoice(ctx context.Context,
37
                arg sqlc.InsertMigratedInvoiceParams) (int64, error)
38

39
        InsertInvoiceFeature(ctx context.Context,
40
                arg sqlc.InsertInvoiceFeatureParams) error
41

42
        InsertInvoiceHTLC(ctx context.Context,
43
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
44

45
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
46
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
47

48
        FilterInvoices(ctx context.Context,
49
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
50

51
        GetInvoice(ctx context.Context,
52
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
53

54
        GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
55
                error)
56

57
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
58
                error)
59

60
        GetInvoiceFeatures(ctx context.Context,
61
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
62

63
        GetInvoiceHTLCCustomRecords(ctx context.Context,
64
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
65

66
        GetInvoiceHTLCs(ctx context.Context,
67
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
68

69
        UpdateInvoiceState(ctx context.Context,
70
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
71

72
        UpdateInvoiceAmountPaid(ctx context.Context,
73
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
74

75
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
76

77
        UpdateInvoiceHTLC(ctx context.Context,
78
                arg sqlc.UpdateInvoiceHTLCParams) error
79

80
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
81
                sql.Result, error)
82

83
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
84

85
        // AMP sub invoice specific methods.
86
        UpsertAMPSubInvoice(ctx context.Context,
87
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
88

89
        // TODO(bhandras): remove this once migrations have been separated out.
90
        InsertAMPSubInvoice(ctx context.Context,
91
                arg sqlc.InsertAMPSubInvoiceParams) error
92

93
        UpdateAMPSubInvoiceState(ctx context.Context,
94
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
95

96
        InsertAMPSubInvoiceHTLC(ctx context.Context,
97
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
98

99
        FetchAMPSubInvoices(ctx context.Context,
100
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
101
                error)
102

103
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
104
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
105
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
106

107
        FetchSettledAMPSubInvoices(ctx context.Context,
108
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
109
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
110

111
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
112
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
113
                error)
114

115
        // Invoice events specific methods.
116
        OnInvoiceCreated(ctx context.Context,
117
                arg sqlc.OnInvoiceCreatedParams) error
118

119
        OnInvoiceCanceled(ctx context.Context,
120
                arg sqlc.OnInvoiceCanceledParams) error
121

122
        OnInvoiceSettled(ctx context.Context,
123
                arg sqlc.OnInvoiceSettledParams) error
124

125
        OnAMPSubInvoiceCreated(ctx context.Context,
126
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
127

128
        OnAMPSubInvoiceCanceled(ctx context.Context,
129
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
130

131
        OnAMPSubInvoiceSettled(ctx context.Context,
132
                arg sqlc.OnAMPSubInvoiceSettledParams) error
133

134
        // Migration specific methods.
135
        // TODO(bhandras): remove this once migrations have been separated out.
136
        InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
137
                arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
138

139
        SetKVInvoicePaymentHash(ctx context.Context,
140
                arg sqlc.SetKVInvoicePaymentHashParams) error
141

142
        GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
143
                []byte, error)
144

145
        ClearKVInvoiceHashIndex(ctx context.Context) error
146
}
147

148
// Compile-time assertion that *SQLStore implements the InvoiceDB interface.
var _ InvoiceDB = (*SQLStore)(nil)
149

150
// SQLInvoiceQueriesTxOptions describes the transaction options that
// SQLInvoiceQueries understands.
type SQLInvoiceQueriesTxOptions struct {
	// readOnly is true when the transaction only needs read access. The
	// zero value therefore requests a read-write transaction.
	readOnly bool
}

// ReadOnly reports whether the transaction should be opened as read only.
//
// NOTE: This implements the TxOptions.
func (o *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
	isReadOnly := o.readOnly

	return isReadOnly
}

// NewSQLInvoiceQueryReadTx returns a transaction option set requesting a
// read-only transaction.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
	var opts SQLInvoiceQueriesTxOptions
	opts.readOnly = true

	return opts
}
21,216✔
170

171
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
172
// of batched database operations.
173
type BatchedSQLInvoiceQueries interface {
174
        SQLInvoiceQueries
175

176
        sqldb.BatchedTx[SQLInvoiceQueries]
177
}
178

179
// SQLStore represents a storage backend.
180
type SQLStore struct {
181
        db    BatchedSQLInvoiceQueries
182
        clock clock.Clock
183
        opts  SQLStoreOptions
184
}
185

186
// SQLStoreOptions holds the options for the SQL store.
187
type SQLStoreOptions struct {
188
        paginationLimit int
189
}
190

191
// defaultSQLStoreOptions returns the default options for the SQL store.
192
func defaultSQLStoreOptions() SQLStoreOptions {
512✔
193
        return SQLStoreOptions{
512✔
194
                paginationLimit: defaultQueryPaginationLimit,
512✔
195
        }
512✔
196
}
512✔
197

198
// SQLStoreOption is a functional option that can be used to optionally modify
199
// the behavior of the SQL store.
200
type SQLStoreOption func(*SQLStoreOptions)
201

202
// WithPaginationLimit sets the pagination limit for the SQL store queries that
203
// paginate results.
204
func WithPaginationLimit(limit int) SQLStoreOption {
50✔
205
        return func(o *SQLStoreOptions) {
100✔
206
                o.paginationLimit = limit
50✔
207
        }
50✔
208
}
209

210
// NewSQLStore creates a new SQLStore instance given a open
211
// BatchedSQLInvoiceQueries storage backend.
212
func NewSQLStore(db BatchedSQLInvoiceQueries,
213
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
512✔
214

512✔
215
        opts := defaultSQLStoreOptions()
512✔
216
        for _, applyOption := range options {
562✔
217
                applyOption(&opts)
50✔
218
        }
50✔
219

220
        return &SQLStore{
512✔
221
                db:    db,
512✔
222
                clock: clock,
512✔
223
                opts:  opts,
512✔
224
        }
512✔
225
}
226

227
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
228
        sqlc.InsertInvoiceParams, error) {
20,596✔
229

20,596✔
230
        // Precompute the payment request hash so we can use it in the query.
20,596✔
231
        var paymentRequestHash []byte
20,596✔
232
        if len(invoice.PaymentRequest) > 0 {
40,788✔
233
                h := sha256.New()
20,192✔
234
                h.Write(invoice.PaymentRequest)
20,192✔
235
                paymentRequestHash = h.Sum(nil)
20,192✔
236
        }
20,192✔
237

238
        params := sqlc.InsertInvoiceParams{
20,596✔
239
                Hash:       paymentHash[:],
20,596✔
240
                AmountMsat: int64(invoice.Terms.Value),
20,596✔
241
                CltvDelta: sqldb.SQLInt32(
20,596✔
242
                        invoice.Terms.FinalCltvDelta,
20,596✔
243
                ),
20,596✔
244
                Expiry: int32(invoice.Terms.Expiry.Seconds()),
20,596✔
245
                // Note: keysend invoices don't have a payment request.
20,596✔
246
                PaymentRequest: sqldb.SQLStr(string(
20,596✔
247
                        invoice.PaymentRequest),
20,596✔
248
                ),
20,596✔
249
                PaymentRequestHash: paymentRequestHash,
20,596✔
250
                State:              int16(invoice.State),
20,596✔
251
                AmountPaidMsat:     int64(invoice.AmtPaid),
20,596✔
252
                IsAmp:              invoice.IsAMP(),
20,596✔
253
                IsHodl:             invoice.HodlInvoice,
20,596✔
254
                IsKeysend:          invoice.IsKeysend(),
20,596✔
255
                CreatedAt:          invoice.CreationDate.UTC(),
20,596✔
256
        }
20,596✔
257

20,596✔
258
        if invoice.Memo != nil {
40,890✔
259
                // Store the memo as a nullable string in the database. Note
20,294✔
260
                // that for compatibility reasons, we store the value as a valid
20,294✔
261
                // string even if it's empty.
20,294✔
262
                params.Memo = sql.NullString{
20,294✔
263
                        String: string(invoice.Memo),
20,294✔
264
                        Valid:  true,
20,294✔
265
                }
20,294✔
266
        }
20,294✔
267

268
        // Some invoices may not have a preimage, like in the case of HODL
269
        // invoices.
270
        if invoice.Terms.PaymentPreimage != nil {
41,162✔
271
                preimage := *invoice.Terms.PaymentPreimage
20,566✔
272
                if preimage == UnknownPreimage {
20,566✔
273
                        return sqlc.InsertInvoiceParams{},
×
274
                                errors.New("cannot use all-zeroes preimage")
×
275
                }
×
276
                params.Preimage = preimage[:]
20,566✔
277
        }
278

279
        // Some non MPP payments may have the default (invalid) value.
280
        if invoice.Terms.PaymentAddr != BlankPayAddr {
20,946✔
281
                params.PaymentAddr = invoice.Terms.PaymentAddr[:]
350✔
282
        }
350✔
283

284
        return params, nil
20,596✔
285
}
286

287
// AddInvoice inserts the targeted invoice into the database. If the invoice has
288
// *any* payment hashes which already exists within the database, then the
289
// insertion will be aborted and rejected due to the strict policy banning any
290
// duplicate payment hashes.
291
//
292
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
293
func (i *SQLStore) AddInvoice(ctx context.Context,
294
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
596✔
295

596✔
296
        // Make sure this is a valid invoice before trying to store it in our
596✔
297
        // DB.
596✔
298
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
600✔
299
                return 0, err
4✔
300
        }
4✔
301

302
        var (
592✔
303
                writeTxOpts SQLInvoiceQueriesTxOptions
592✔
304
                invoiceID   int64
592✔
305
        )
592✔
306

592✔
307
        insertInvoiceParams, err := makeInsertInvoiceParams(
592✔
308
                newInvoice, paymentHash,
592✔
309
        )
592✔
310
        if err != nil {
592✔
311
                return 0, err
×
312
        }
×
313

314
        err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
1,184✔
315
                var err error
592✔
316
                invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
592✔
317
                if err != nil {
608✔
318
                        return fmt.Errorf("unable to insert invoice: %w", err)
16✔
319
                }
16✔
320

321
                // TODO(positiveblue): if invocies do not have custom features
322
                // maybe just store the "invoice type" and populate the features
323
                // based on that.
324
                for feature := range newInvoice.Terms.Features.Features() {
670✔
325
                        params := sqlc.InsertInvoiceFeatureParams{
94✔
326
                                InvoiceID: invoiceID,
94✔
327
                                Feature:   int32(feature),
94✔
328
                        }
94✔
329

94✔
330
                        err := db.InsertInvoiceFeature(ctx, params)
94✔
331
                        if err != nil {
94✔
332
                                return fmt.Errorf("unable to insert invoice "+
×
333
                                        "feature(%v): %w", feature, err)
×
334
                        }
×
335
                }
336

337
                // Finally add a new event for this invoice.
338
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
576✔
339
                        AddedAt:   newInvoice.CreationDate.UTC(),
576✔
340
                        InvoiceID: invoiceID,
576✔
341
                })
576✔
342
        }, func() {})
592✔
343
        if err != nil {
608✔
344
                mappedSQLErr := sqldb.MapSQLError(err)
16✔
345
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
16✔
346
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
32✔
347
                        // Add context to unique constraint errors.
16✔
348
                        return 0, ErrDuplicateInvoice
16✔
349
                }
16✔
350

351
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
×
352
                        paymentHash, err)
×
353
        }
354

355
        newInvoice.AddIndex = uint64(invoiceID)
576✔
356

576✔
357
        return newInvoice.AddIndex, nil
576✔
358
}
359

360
// getInvoiceByRef fetches the invoice with the given reference. The reference
361
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
362
func getInvoiceByRef(ctx context.Context,
363
        db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
21,625✔
364

21,625✔
365
        // If the reference is empty, we can't look up the invoice.
21,625✔
366
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
21,627✔
367
                return sqlc.Invoice{}, ErrInvoiceNotFound
2✔
368
        }
2✔
369

370
        // If the reference is a hash only, we can look up the invoice directly
371
        // by the payment hash which is faster.
372
        if ref.IsHashOnly() {
42,223✔
373
                invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
20,600✔
374
                if errors.Is(err, sql.ErrNoRows) {
20,632✔
375
                        return sqlc.Invoice{}, ErrInvoiceNotFound
32✔
376
                }
32✔
377

378
                return invoice, err
20,568✔
379
        }
380

381
        // Otherwise the reference may include more fields, so we'll need to
382
        // assemble the query parameters based on the fields that are set.
383
        var params sqlc.GetInvoiceParams
1,023✔
384

1,023✔
385
        if ref.PayHash() != nil {
1,949✔
386
                params.Hash = ref.PayHash()[:]
926✔
387
        }
926✔
388

389
        // Newer invoices (0.11 and up) are indexed by payment address in
390
        // addition to payment hash, but pre 0.8 invoices do not have one at
391
        // all. Only allow lookups for payment address if it is not a blank
392
        // payment address, which is a special-cased value for legacy keysend
393
        // invoices.
394
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
1,159✔
395
                params.PaymentAddr = ref.PayAddr()[:]
136✔
396
        }
136✔
397

398
        // If the reference has a set ID we'll fetch the invoice which has the
399
        // corresponding AMP sub invoice.
400
        if ref.SetID() != nil {
1,120✔
401
                params.SetID = ref.SetID()[:]
97✔
402
        }
97✔
403

404
        var (
1,023✔
405
                rows []sqlc.Invoice
1,023✔
406
                err  error
1,023✔
407
        )
1,023✔
408

1,023✔
409
        // We need to split the query based on how we intend to look up the
1,023✔
410
        // invoice. If only the set ID is given then we want to have an exact
1,023✔
411
        // match on the set ID. If other fields are given, we want to match on
1,023✔
412
        // those fields and the set ID but with a less strict join condition.
1,023✔
413
        if params.Hash == nil && params.PaymentAddr == nil &&
1,023✔
414
                params.SetID != nil {
1,048✔
415

25✔
416
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
25✔
417
        } else {
1,023✔
418
                rows, err = db.GetInvoice(ctx, params)
998✔
419
        }
998✔
420

421
        switch {
1,023✔
422
        case len(rows) == 0:
14✔
423
                return sqlc.Invoice{}, ErrInvoiceNotFound
14✔
424

425
        case len(rows) > 1:
×
426
                // In case the reference is ambiguous, meaning it matches more
×
427
                // than        one invoice, we'll return an error.
×
428
                return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
×
429
                        "%s: %s", ref.String(), spew.Sdump(rows))
×
430

431
        case err != nil:
×
432
                return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
×
433
                        err)
×
434
        }
435

436
        return rows[0], nil
1,009✔
437
}
438

439
// fetchInvoice fetches the common invoice data and the AMP state for the
440
// invoice with the given reference.
441
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
442
        *Invoice, error) {
21,625✔
443

21,625✔
444
        // Fetch the invoice from the database.
21,625✔
445
        sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
21,625✔
446
        if err != nil {
21,673✔
447
                return nil, err
48✔
448
        }
48✔
449

450
        var (
21,577✔
451
                setID         *[32]byte
21,577✔
452
                fetchAmpHtlcs bool
21,577✔
453
        )
21,577✔
454

21,577✔
455
        // Now that we got the invoice itself, fetch the HTLCs as requested by
21,577✔
456
        // the modifier.
21,577✔
457
        switch ref.Modifier() {
21,577✔
458
        case DefaultModifier:
21,484✔
459
                // By default we'll fetch all AMP HTLCs.
21,484✔
460
                setID = nil
21,484✔
461
                fetchAmpHtlcs = true
21,484✔
462

463
        case HtlcSetOnlyModifier:
91✔
464
                // In this case we'll fetch all AMP HTLCs for the specified set
91✔
465
                // id.
91✔
466
                if ref.SetID() == nil {
91✔
467
                        return nil, fmt.Errorf("set ID is required to use " +
×
468
                                "the HTLC set only modifier")
×
469
                }
×
470

471
                setID = ref.SetID()
91✔
472
                fetchAmpHtlcs = true
91✔
473

474
        case HtlcSetBlankModifier:
2✔
475
                // No need to fetch any HTLCs.
2✔
476
                setID = nil
2✔
477
                fetchAmpHtlcs = false
2✔
478

479
        default:
×
480
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
×
481
                        ref.Modifier())
×
482
        }
483

484
        // Fetch the rest of the invoice data and fill the invoice struct.
485
        _, invoice, err := fetchInvoiceData(
21,577✔
486
                ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
21,577✔
487
        )
21,577✔
488
        if err != nil {
21,577✔
489
                return nil, err
×
490
        }
×
491

492
        return invoice, nil
21,577✔
493
}
494

495
// fetchAmpState fetches the AMP state for the invoice with the given ID.
496
// Optional setID can be provided to fetch the state for a specific AMP HTLC
497
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
498
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
499
// well.
500
//
501
//nolint:funlen
502
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
503
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
504
        HTLCSet, error) {
11,173✔
505

11,173✔
506
        var paramSetID []byte
11,173✔
507
        if setID != nil {
11,268✔
508
                paramSetID = setID[:]
95✔
509
        }
95✔
510

511
        // First fetch all the AMP sub invoices for this invoice or the one
512
        // matching the provided set ID.
513
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
11,173✔
514
                ctx, sqlc.FetchAMPSubInvoicesParams{
11,173✔
515
                        InvoiceID: invoiceID,
11,173✔
516
                        SetID:     paramSetID,
11,173✔
517
                },
11,173✔
518
        )
11,173✔
519
        if err != nil {
11,173✔
520
                return nil, nil, err
×
521
        }
×
522

523
        ampState := make(map[SetID]InvoiceStateAMP)
11,173✔
524
        for _, row := range ampInvoiceRows {
42,409✔
525
                var rowSetID [32]byte
31,236✔
526

31,236✔
527
                if len(row.SetID) != 32 {
31,236✔
528
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
×
529
                                len(row.SetID))
×
530
                }
×
531

532
                var settleDate time.Time
31,236✔
533
                if row.SettledAt.Valid {
39,321✔
534
                        settleDate = row.SettledAt.Time.Local()
8,085✔
535
                }
8,085✔
536

537
                copy(rowSetID[:], row.SetID)
31,236✔
538
                ampState[rowSetID] = InvoiceStateAMP{
31,236✔
539
                        State:       HtlcState(row.State),
31,236✔
540
                        SettleIndex: uint64(row.SettleIndex.Int64),
31,236✔
541
                        SettleDate:  settleDate,
31,236✔
542
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
31,236✔
543
                }
31,236✔
544
        }
545

546
        if !fetchHtlcs {
11,175✔
547
                return ampState, nil, nil
2✔
548
        }
2✔
549

550
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
11,171✔
551
        if err != nil {
11,171✔
552
                return nil, nil, fmt.Errorf("unable to get custom records for "+
×
553
                        "invoice HTLCs: %w", err)
×
554
        }
×
555

556
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
11,171✔
557
        for _, row := range customRecordRows {
202,995✔
558
                if _, ok := customRecords[row.HtlcID]; !ok {
261,750✔
559
                        customRecords[row.HtlcID] = make(record.CustomSet)
69,926✔
560
                }
69,926✔
561

562
                value := row.Value
191,824✔
563
                if value == nil {
191,824✔
564
                        value = []byte{}
×
565
                }
×
566

567
                customRecords[row.HtlcID][uint64(row.Key)] = value
191,824✔
568
        }
569

570
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
571
        // provided set ID.
572
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
11,171✔
573
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
11,171✔
574
                        InvoiceID: invoiceID,
11,171✔
575
                        SetID:     paramSetID,
11,171✔
576
                },
11,171✔
577
        )
11,171✔
578
        if err != nil {
11,171✔
579
                return nil, nil, err
×
580
        }
×
581

582
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
11,171✔
583
        for _, row := range ampHtlcRows {
99,635✔
584
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
88,464✔
585
                if err != nil {
88,464✔
586
                        return nil, nil, err
×
587
                }
×
588

589
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
88,464✔
590

88,464✔
591
                if row.HtlcID < 0 {
88,464✔
592
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
×
593
                                "value: %v", row.HtlcID)
×
594
                }
×
595

596
                htlcID := uint64(row.HtlcID)
88,464✔
597

88,464✔
598
                circuitKey := CircuitKey{
88,464✔
599
                        ChanID: chanID,
88,464✔
600
                        HtlcID: htlcID,
88,464✔
601
                }
88,464✔
602

88,464✔
603
                htlc := &InvoiceHTLC{
88,464✔
604
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
88,464✔
605
                        AcceptHeight: uint32(row.AcceptHeight),
88,464✔
606
                        AcceptTime:   row.AcceptTime.Local(),
88,464✔
607
                        Expiry:       uint32(row.ExpiryHeight),
88,464✔
608
                        State:        HtlcState(row.State),
88,464✔
609
                }
88,464✔
610

88,464✔
611
                if row.TotalMppMsat.Valid {
88,577✔
612
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
113✔
613
                                row.TotalMppMsat.Int64,
113✔
614
                        )
113✔
615
                }
113✔
616

617
                if row.ResolveTime.Valid {
130,565✔
618
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
42,101✔
619
                }
42,101✔
620

621
                var (
88,464✔
622
                        rootShare [32]byte
88,464✔
623
                        setID     [32]byte
88,464✔
624
                )
88,464✔
625

88,464✔
626
                if len(row.RootShare) != 32 {
88,464✔
627
                        return nil, nil, fmt.Errorf("invalid root share "+
×
628
                                "length: %d", len(row.RootShare))
×
629
                }
×
630
                copy(rootShare[:], row.RootShare)
88,464✔
631

88,464✔
632
                if len(row.SetID) != 32 {
88,464✔
633
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
634
                                len(row.SetID))
×
635
                }
×
636
                copy(setID[:], row.SetID)
88,464✔
637

88,464✔
638
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
88,464✔
639
                        return nil, nil, fmt.Errorf("invalid child index "+
×
640
                                "value: %v", row.ChildIndex)
×
641
                }
×
642

643
                ampRecord := record.NewAMP(
88,464✔
644
                        rootShare, setID, uint32(row.ChildIndex),
88,464✔
645
                )
88,464✔
646

88,464✔
647
                htlc.AMP = &InvoiceHtlcAMPData{
88,464✔
648
                        Record: *ampRecord,
88,464✔
649
                }
88,464✔
650

88,464✔
651
                if len(row.Hash) != 32 {
88,464✔
652
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
×
653
                                len(row.Hash))
×
654
                }
×
655
                copy(htlc.AMP.Hash[:], row.Hash)
88,464✔
656

88,464✔
657
                if row.Preimage != nil {
176,805✔
658
                        preimage, err := lntypes.MakePreimage(row.Preimage)
88,341✔
659
                        if err != nil {
88,341✔
660
                                return nil, nil, err
×
661
                        }
×
662

663
                        htlc.AMP.Preimage = &preimage
88,341✔
664
                }
665

666
                if _, ok := customRecords[row.ID]; ok {
158,390✔
667
                        htlc.CustomRecords = customRecords[row.ID]
69,926✔
668
                } else {
88,464✔
669
                        htlc.CustomRecords = make(record.CustomSet)
18,538✔
670
                }
18,538✔
671

672
                ampHtlcs[circuitKey] = htlc
88,464✔
673
        }
674

675
        if len(ampHtlcs) > 0 {
22,290✔
676
                for setID := range ampState {
42,349✔
677
                        var amtPaid lnwire.MilliSatoshi
31,230✔
678
                        invoiceKeys := make(
31,230✔
679
                                map[models.CircuitKey]struct{},
31,230✔
680
                        )
31,230✔
681

31,230✔
682
                        for key, htlc := range ampHtlcs {
344,764✔
683
                                if htlc.AMP.Record.SetID() != setID {
538,604✔
684
                                        continue
225,070✔
685
                                }
686

687
                                invoiceKeys[key] = struct{}{}
88,464✔
688

88,464✔
689
                                if htlc.State != HtlcStateCanceled {
157,817✔
690
                                        amtPaid += htlc.Amt
69,353✔
691
                                }
69,353✔
692
                        }
693

694
                        setState := ampState[setID]
31,230✔
695
                        setState.InvoiceKeys = invoiceKeys
31,230✔
696
                        setState.AmtPaid = amtPaid
31,230✔
697
                        ampState[setID] = setState
31,230✔
698
                }
699
        }
700

701
        return ampState, ampHtlcs, nil
11,171✔
702
}
703

704
// LookupInvoice attempts to look up an invoice corresponding the passed in
705
// reference. The reference may be a payment hash, a payment address, or a set
706
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
707
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
708
// error.
709
func (i *SQLStore) LookupInvoice(ctx context.Context,
710
        ref InvoiceRef) (Invoice, error) {
20,852✔
711

20,852✔
712
        var (
20,852✔
713
                invoice *Invoice
20,852✔
714
                err     error
20,852✔
715
        )
20,852✔
716

20,852✔
717
        readTxOpt := NewSQLInvoiceQueryReadTx()
20,852✔
718
        txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
41,704✔
719
                invoice, err = fetchInvoice(ctx, db, ref)
20,852✔
720

20,852✔
721
                return err
20,852✔
722
        }, func() {})
41,704✔
723
        if txErr != nil {
20,892✔
724
                return Invoice{}, txErr
40✔
725
        }
40✔
726

727
        return *invoice, nil
20,812✔
728
}
729

730
// FetchPendingInvoices returns all the invoices that are currently in a
731
// "pending" state. An invoice is pending if it has been created but not yet
732
// settled or canceled.
733
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
734
        map[lntypes.Hash]Invoice, error) {
262✔
735

262✔
736
        var invoices map[lntypes.Hash]Invoice
262✔
737

262✔
738
        readTxOpt := NewSQLInvoiceQueryReadTx()
262✔
739
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
524✔
740
                return queryWithLimit(func(offset int) (int, error) {
530✔
741
                        params := sqlc.FilterInvoicesParams{
268✔
742
                                PendingOnly: true,
268✔
743
                                NumOffset:   int32(offset),
268✔
744
                                NumLimit:    int32(i.opts.paginationLimit),
268✔
745
                                Reverse:     false,
268✔
746
                        }
268✔
747

268✔
748
                        rows, err := db.FilterInvoices(ctx, params)
268✔
749
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
268✔
750
                                return 0, fmt.Errorf("unable to get invoices "+
×
751
                                        "from db: %w", err)
×
752
                        }
×
753

754
                        // Load all the information for the invoices.
755
                        for _, row := range rows {
308✔
756
                                hash, invoice, err := fetchInvoiceData(
40✔
757
                                        ctx, db, row, nil, true,
40✔
758
                                )
40✔
759
                                if err != nil {
40✔
760
                                        return 0, err
×
761
                                }
×
762

763
                                invoices[*hash] = *invoice
40✔
764
                        }
765

766
                        return len(rows), nil
268✔
767
                }, i.opts.paginationLimit)
768
        }, func() {
262✔
769
                invoices = make(map[lntypes.Hash]Invoice)
262✔
770
        })
262✔
771
        if err != nil {
262✔
772
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
×
773
                        err)
×
774
        }
×
775

776
        return invoices, nil
262✔
777
}
778

779
// InvoicesSettledSince can be used by callers to catch up any settled invoices
780
// they missed within the settled invoice time series. We'll return all known
781
// settled invoice that have a settle index higher than the passed idx.
782
//
783
// NOTE: The index starts from 1. As a result we enforce that specifying a value
784
// below the starting index value is a noop.
785
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
786
        []Invoice, error) {
42✔
787

42✔
788
        var invoices []Invoice
42✔
789

42✔
790
        if idx == 0 {
78✔
791
                return invoices, nil
36✔
792
        }
36✔
793

794
        readTxOpt := NewSQLInvoiceQueryReadTx()
6✔
795
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
12✔
796
                err := queryWithLimit(func(offset int) (int, error) {
18✔
797
                        params := sqlc.FilterInvoicesParams{
12✔
798
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
12✔
799
                                NumOffset:      int32(offset),
12✔
800
                                NumLimit:       int32(i.opts.paginationLimit),
12✔
801
                                Reverse:        false,
12✔
802
                        }
12✔
803

12✔
804
                        rows, err := db.FilterInvoices(ctx, params)
12✔
805
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
12✔
806
                                return 0, fmt.Errorf("unable to get invoices "+
×
807
                                        "from db: %w", err)
×
808
                        }
×
809

810
                        // Load all the information for the invoices.
811
                        for _, row := range rows {
30✔
812
                                _, invoice, err := fetchInvoiceData(
18✔
813
                                        ctx, db, row, nil, true,
18✔
814
                                )
18✔
815
                                if err != nil {
18✔
816
                                        return 0, fmt.Errorf("unable to fetch "+
×
817
                                                "invoice(id=%d) from db: %w",
×
818
                                                row.ID, err)
×
819
                                }
×
820

821
                                invoices = append(invoices, *invoice)
18✔
822
                        }
823

824
                        return len(rows), nil
12✔
825
                }, i.opts.paginationLimit)
826
                if err != nil {
6✔
827
                        return err
×
828
                }
×
829

830
                // Now fetch all the AMP sub invoices that were settled since
831
                // the provided index.
832
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
6✔
833
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
6✔
834
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
6✔
835
                        },
6✔
836
                )
6✔
837
                if err != nil {
6✔
838
                        return err
×
839
                }
×
840

841
                for _, ampInvoice := range ampInvoices {
10✔
842
                        // Convert the row to a sqlc.Invoice so we can use the
4✔
843
                        // existing fetchInvoiceData function.
4✔
844
                        sqlInvoice := sqlc.Invoice{
4✔
845
                                ID:             ampInvoice.ID,
4✔
846
                                Hash:           ampInvoice.Hash,
4✔
847
                                Preimage:       ampInvoice.Preimage,
4✔
848
                                SettleIndex:    ampInvoice.AmpSettleIndex,
4✔
849
                                SettledAt:      ampInvoice.AmpSettledAt,
4✔
850
                                Memo:           ampInvoice.Memo,
4✔
851
                                AmountMsat:     ampInvoice.AmountMsat,
4✔
852
                                CltvDelta:      ampInvoice.CltvDelta,
4✔
853
                                Expiry:         ampInvoice.Expiry,
4✔
854
                                PaymentAddr:    ampInvoice.PaymentAddr,
4✔
855
                                PaymentRequest: ampInvoice.PaymentRequest,
4✔
856
                                State:          ampInvoice.State,
4✔
857
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
4✔
858
                                IsAmp:          ampInvoice.IsAmp,
4✔
859
                                IsHodl:         ampInvoice.IsHodl,
4✔
860
                                IsKeysend:      ampInvoice.IsKeysend,
4✔
861
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
4✔
862
                        }
4✔
863

4✔
864
                        // Fetch the state and HTLCs for this AMP sub invoice.
4✔
865
                        _, invoice, err := fetchInvoiceData(
4✔
866
                                ctx, db, sqlInvoice,
4✔
867
                                (*[32]byte)(ampInvoice.SetID), true,
4✔
868
                        )
4✔
869
                        if err != nil {
4✔
870
                                return fmt.Errorf("unable to fetch "+
×
871
                                        "AMP invoice(id=%d) from db: %w",
×
872
                                        ampInvoice.ID, err)
×
873
                        }
×
874

875
                        invoices = append(invoices, *invoice)
4✔
876
                }
877

878
                return nil
6✔
879
        }, func() {
6✔
880
                invoices = nil
6✔
881
        })
6✔
882
        if err != nil {
6✔
883
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
884
                        "index (excluding) %d: %w", idx, err)
×
885
        }
×
886

887
        return invoices, nil
6✔
888
}
889

890
// InvoicesAddedSince can be used by callers to seek into the event time series
891
// of all the invoices added in the database. This method will return all
892
// invoices with an add index greater than the specified idx.
893
//
894
// NOTE: The index starts from 1. As a result we enforce that specifying a value
895
// below the starting index value is a noop.
896
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
897
        []Invoice, error) {
40✔
898

40✔
899
        var result []Invoice
40✔
900

40✔
901
        if idx == 0 {
74✔
902
                return result, nil
34✔
903
        }
34✔
904

905
        readTxOpt := NewSQLInvoiceQueryReadTx()
6✔
906
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
12✔
907
                return queryWithLimit(func(offset int) (int, error) {
30✔
908
                        params := sqlc.FilterInvoicesParams{
24✔
909
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
24✔
910
                                NumOffset:   int32(offset),
24✔
911
                                NumLimit:    int32(i.opts.paginationLimit),
24✔
912
                                Reverse:     false,
24✔
913
                        }
24✔
914

24✔
915
                        rows, err := db.FilterInvoices(ctx, params)
24✔
916
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
24✔
917
                                return 0, fmt.Errorf("unable to get invoices "+
×
918
                                        "from db: %w", err)
×
919
                        }
×
920

921
                        // Load all the information for the invoices.
922
                        for _, row := range rows {
82✔
923
                                _, invoice, err := fetchInvoiceData(
58✔
924
                                        ctx, db, row, nil, true,
58✔
925
                                )
58✔
926
                                if err != nil {
58✔
927
                                        return 0, err
×
928
                                }
×
929

930
                                result = append(result, *invoice)
58✔
931
                        }
932

933
                        return len(rows), nil
24✔
934
                }, i.opts.paginationLimit)
935
        }, func() {
6✔
936
                result = nil
6✔
937
        })
6✔
938

939
        if err != nil {
6✔
940
                return nil, fmt.Errorf("unable to get invoices added since "+
×
941
                        "index %d: %w", idx, err)
×
942
        }
×
943

944
        return result, nil
6✔
945
}
946

947
// QueryInvoices allows a caller to query the invoice database for invoices
948
// within the specified add index range.
949
func (i *SQLStore) QueryInvoices(ctx context.Context,
950
        q InvoiceQuery) (InvoiceSlice, error) {
90✔
951

90✔
952
        var invoices []Invoice
90✔
953

90✔
954
        if q.NumMaxInvoices == 0 {
90✔
955
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
956
                        "be non-zero")
×
957
        }
×
958

959
        readTxOpt := NewSQLInvoiceQueryReadTx()
90✔
960
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
180✔
961
                return queryWithLimit(func(offset int) (int, error) {
486✔
962
                        params := sqlc.FilterInvoicesParams{
396✔
963
                                NumOffset:   int32(offset),
396✔
964
                                NumLimit:    int32(i.opts.paginationLimit),
396✔
965
                                PendingOnly: q.PendingOnly,
396✔
966
                                Reverse:     q.Reversed,
396✔
967
                        }
396✔
968

396✔
969
                        if q.Reversed {
510✔
970
                                // If the index offset was not set, we want to
114✔
971
                                // fetch from the lastest invoice.
114✔
972
                                if q.IndexOffset == 0 {
162✔
973
                                        params.AddIndexLet = sqldb.SQLInt64(
48✔
974
                                                int64(math.MaxInt64),
48✔
975
                                        )
48✔
976
                                } else {
114✔
977
                                        // The invoice with index offset id must
66✔
978
                                        // not be included in the results.
66✔
979
                                        params.AddIndexLet = sqldb.SQLInt64(
66✔
980
                                                q.IndexOffset - 1,
66✔
981
                                        )
66✔
982
                                }
66✔
983
                        } else {
282✔
984
                                // The invoice with index offset id must not be
282✔
985
                                // included in the results.
282✔
986
                                params.AddIndexGet = sqldb.SQLInt64(
282✔
987
                                        q.IndexOffset + 1,
282✔
988
                                )
282✔
989
                        }
282✔
990

991
                        if q.CreationDateStart != 0 {
458✔
992
                                params.CreatedAfter = sqldb.SQLTime(
62✔
993
                                        time.Unix(q.CreationDateStart, 0).UTC(),
62✔
994
                                )
62✔
995
                        }
62✔
996

997
                        if q.CreationDateEnd != 0 {
458✔
998
                                // We need to add 1 to the end date as we're
62✔
999
                                // checking less than the end date in SQL.
62✔
1000
                                params.CreatedBefore = sqldb.SQLTime(
62✔
1001
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
62✔
1002
                                )
62✔
1003
                        }
62✔
1004

1005
                        rows, err := db.FilterInvoices(ctx, params)
396✔
1006
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
396✔
1007
                                return 0, fmt.Errorf("unable to get invoices "+
×
1008
                                        "from db: %w", err)
×
1009
                        }
×
1010

1011
                        // Load all the information for the invoices.
1012
                        for _, row := range rows {
1,418✔
1013
                                _, invoice, err := fetchInvoiceData(
1,022✔
1014
                                        ctx, db, row, nil, true,
1,022✔
1015
                                )
1,022✔
1016
                                if err != nil {
1,022✔
1017
                                        return 0, err
×
1018
                                }
×
1019

1020
                                invoices = append(invoices, *invoice)
1,022✔
1021

1,022✔
1022
                                if len(invoices) == int(q.NumMaxInvoices) {
1,050✔
1023
                                        return 0, nil
28✔
1024
                                }
28✔
1025
                        }
1026

1027
                        return len(rows), nil
368✔
1028
                }, i.opts.paginationLimit)
1029
        }, func() {
90✔
1030
                invoices = nil
90✔
1031
        })
90✔
1032
        if err != nil {
90✔
1033
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
1034
                        "invoices: %w", err)
×
1035
        }
×
1036

1037
        if len(invoices) == 0 {
102✔
1038
                return InvoiceSlice{
12✔
1039
                        InvoiceQuery: q,
12✔
1040
                }, nil
12✔
1041
        }
12✔
1042

1043
        // If we iterated through the add index in reverse order, then
1044
        // we'll need to reverse the slice of invoices to return them in
1045
        // forward order.
1046
        if q.Reversed {
102✔
1047
                numInvoices := len(invoices)
24✔
1048
                for i := 0; i < numInvoices/2; i++ {
162✔
1049
                        reverse := numInvoices - i - 1
138✔
1050
                        invoices[i], invoices[reverse] =
138✔
1051
                                invoices[reverse], invoices[i]
138✔
1052
                }
138✔
1053
        }
1054

1055
        res := InvoiceSlice{
78✔
1056
                InvoiceQuery:     q,
78✔
1057
                Invoices:         invoices,
78✔
1058
                FirstIndexOffset: invoices[0].AddIndex,
78✔
1059
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
78✔
1060
        }
78✔
1061

78✔
1062
        return res, nil
78✔
1063
}
1064

1065
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
// a SQL database as the backend. Its fields are:
//   - db: the query handle every updater method issues its queries through.
//   - ctx: the context passed along to each of those queries.
//   - invoice: the in-memory invoice being updated. Its AddIndex is used as the
//     invoice's database ID, and some methods write settle information back
//     into it.
//   - updateTime: the timestamp recorded for the events, creation times and
//     settle times produced by this update.
type sqlInvoiceUpdater struct {
	db         SQLInvoiceQueries
	ctx        context.Context //nolint:containedctx
	invoice    *Invoice
	updateTime time.Time
}
1073

1074
// AddHtlc adds a new htlc to the invoice.
1075
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
1076
        newHtlc *InvoiceHTLC) error {
598✔
1077

598✔
1078
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
598✔
1079
                s.ctx, sqlc.InsertInvoiceHTLCParams{
598✔
1080
                        HtlcID: int64(circuitKey.HtlcID),
598✔
1081
                        ChanID: strconv.FormatUint(
598✔
1082
                                circuitKey.ChanID.ToUint64(), 10,
598✔
1083
                        ),
598✔
1084
                        AmountMsat: int64(newHtlc.Amt),
598✔
1085
                        TotalMppMsat: sql.NullInt64{
598✔
1086
                                Int64: int64(newHtlc.MppTotalAmt),
598✔
1087
                                Valid: newHtlc.MppTotalAmt != 0,
598✔
1088
                        },
598✔
1089
                        AcceptHeight: int32(newHtlc.AcceptHeight),
598✔
1090
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
598✔
1091
                        ExpiryHeight: int32(newHtlc.Expiry),
598✔
1092
                        State:        int16(newHtlc.State),
598✔
1093
                        InvoiceID:    int64(s.invoice.AddIndex),
598✔
1094
                },
598✔
1095
        )
598✔
1096
        if err != nil {
598✔
1097
                return err
×
1098
        }
×
1099

1100
        for key, value := range newHtlc.CustomRecords {
610✔
1101
                err = s.db.InsertInvoiceHTLCCustomRecord(
12✔
1102
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
12✔
1103
                                // TODO(bhandras): schema might be wrong here
12✔
1104
                                // as the custom record key is an uint64.
12✔
1105
                                Key:    int64(key),
12✔
1106
                                Value:  value,
12✔
1107
                                HtlcID: htlcPrimaryKeyID,
12✔
1108
                        },
12✔
1109
                )
12✔
1110
                if err != nil {
12✔
1111
                        return err
×
1112
                }
×
1113
        }
1114

1115
        if newHtlc.AMP != nil {
646✔
1116
                setID := newHtlc.AMP.Record.SetID()
48✔
1117

48✔
1118
                upsertResult, err := s.db.UpsertAMPSubInvoice(
48✔
1119
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
48✔
1120
                                SetID:     setID[:],
48✔
1121
                                CreatedAt: s.updateTime.UTC(),
48✔
1122
                                InvoiceID: int64(s.invoice.AddIndex),
48✔
1123
                        },
48✔
1124
                )
48✔
1125
                if err != nil {
50✔
1126
                        mappedSQLErr := sqldb.MapSQLError(err)
2✔
1127
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
2✔
1128
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
4✔
1129
                                return ErrDuplicateSetID{
2✔
1130
                                        SetID: setID,
2✔
1131
                                }
2✔
1132
                        }
2✔
1133

1134
                        return err
×
1135
                }
1136

1137
                // If we're just inserting the AMP invoice, we'll get a non
1138
                // zero rows affected count.
1139
                rowsAffected, err := upsertResult.RowsAffected()
46✔
1140
                if err != nil {
46✔
1141
                        return err
×
1142
                }
×
1143
                if rowsAffected != 0 {
78✔
1144
                        // If we're inserting a new AMP invoice, we'll also
32✔
1145
                        // insert a new invoice event.
32✔
1146
                        err = s.db.OnAMPSubInvoiceCreated(
32✔
1147
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
32✔
1148
                                        AddedAt:   s.updateTime.UTC(),
32✔
1149
                                        InvoiceID: int64(s.invoice.AddIndex),
32✔
1150
                                        SetID:     setID[:],
32✔
1151
                                },
32✔
1152
                        )
32✔
1153
                        if err != nil {
32✔
1154
                                return err
×
1155
                        }
×
1156
                }
1157

1158
                rootShare := newHtlc.AMP.Record.RootShare()
46✔
1159

46✔
1160
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
46✔
1161
                        InvoiceID: int64(s.invoice.AddIndex),
46✔
1162
                        SetID:     setID[:],
46✔
1163
                        HtlcID:    htlcPrimaryKeyID,
46✔
1164
                        RootShare: rootShare[:],
46✔
1165
                        ChildIndex: int64(
46✔
1166
                                newHtlc.AMP.Record.ChildIndex(),
46✔
1167
                        ),
46✔
1168
                        Hash: newHtlc.AMP.Hash[:],
46✔
1169
                }
46✔
1170

46✔
1171
                if newHtlc.AMP.Preimage != nil {
66✔
1172
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
20✔
1173
                }
20✔
1174

1175
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
46✔
1176
                if err != nil {
46✔
1177
                        return err
×
1178
                }
×
1179
        }
1180

1181
        return nil
596✔
1182
}
1183

1184
// ResolveHtlc marks an htlc as resolved with the given state.
1185
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
1186
        state HtlcState, resolveTime time.Time) error {
587✔
1187

587✔
1188
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
587✔
1189
                HtlcID: int64(circuitKey.HtlcID),
587✔
1190
                ChanID: strconv.FormatUint(
587✔
1191
                        circuitKey.ChanID.ToUint64(), 10,
587✔
1192
                ),
587✔
1193
                InvoiceID:   int64(s.invoice.AddIndex),
587✔
1194
                State:       int16(state),
587✔
1195
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
587✔
1196
        })
587✔
1197
}
587✔
1198

1199
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1200
// identified by the setID.
1201
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
1202
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
12✔
1203

12✔
1204
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
12✔
1205
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
12✔
1206
                        InvoiceID: int64(s.invoice.AddIndex),
12✔
1207
                        SetID:     setID[:],
12✔
1208
                        HtlcID:    int64(circuitKey.HtlcID),
12✔
1209
                        Preimage:  preimage[:],
12✔
1210
                        ChanID: strconv.FormatUint(
12✔
1211
                                circuitKey.ChanID.ToUint64(), 10,
12✔
1212
                        ),
12✔
1213
                },
12✔
1214
        )
12✔
1215
        if err != nil {
12✔
1216
                return err
×
1217
        }
×
1218

1219
        rowsAffected, err := result.RowsAffected()
12✔
1220
        if err != nil {
12✔
1221
                return err
×
1222
        }
×
1223
        if rowsAffected == 0 {
12✔
1224
                return ErrInvoiceNotFound
×
1225
        }
×
1226

1227
        return nil
12✔
1228
}
1229

1230
// UpdateInvoiceState updates the invoice state to the new state.
1231
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
1232
        newState ContractState, preimage *lntypes.Preimage) error {
398✔
1233

398✔
1234
        var (
398✔
1235
                settleIndex sql.NullInt64
398✔
1236
                settledAt   sql.NullTime
398✔
1237
        )
398✔
1238

398✔
1239
        switch newState {
398✔
1240
        case ContractSettled:
316✔
1241
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
316✔
1242
                if err != nil {
316✔
1243
                        return err
×
1244
                }
×
1245

1246
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
316✔
1247

316✔
1248
                // If the invoice is settled, we'll also update the settle time.
316✔
1249
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
316✔
1250

316✔
1251
                err = s.db.OnInvoiceSettled(
316✔
1252
                        s.ctx, sqlc.OnInvoiceSettledParams{
316✔
1253
                                AddedAt:   s.updateTime.UTC(),
316✔
1254
                                InvoiceID: int64(s.invoice.AddIndex),
316✔
1255
                        },
316✔
1256
                )
316✔
1257
                if err != nil {
316✔
1258
                        return err
×
1259
                }
×
1260

1261
        case ContractCanceled:
64✔
1262
                err := s.db.OnInvoiceCanceled(
64✔
1263
                        s.ctx, sqlc.OnInvoiceCanceledParams{
64✔
1264
                                AddedAt:   s.updateTime.UTC(),
64✔
1265
                                InvoiceID: int64(s.invoice.AddIndex),
64✔
1266
                        },
64✔
1267
                )
64✔
1268
                if err != nil {
64✔
1269
                        return err
×
1270
                }
×
1271
        }
1272

1273
        params := sqlc.UpdateInvoiceStateParams{
398✔
1274
                ID:          int64(s.invoice.AddIndex),
398✔
1275
                State:       int16(newState),
398✔
1276
                SettleIndex: settleIndex,
398✔
1277
                SettledAt:   settledAt,
398✔
1278
        }
398✔
1279

398✔
1280
        if preimage != nil {
406✔
1281
                params.Preimage = preimage[:]
8✔
1282
        }
8✔
1283

1284
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
398✔
1285
        if err != nil {
398✔
UNCOV
1286
                return err
×
UNCOV
1287
        }
×
1288
        rowsAffected, err := result.RowsAffected()
398✔
1289
        if err != nil {
398✔
1290
                return err
×
1291
        }
×
1292

1293
        if rowsAffected == 0 {
398✔
1294
                return ErrInvoiceNotFound
×
1295
        }
×
1296

1297
        if settleIndex.Valid {
714✔
1298
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
316✔
1299
                s.invoice.SettleDate = s.updateTime
316✔
1300
        }
316✔
1301

1302
        return nil
398✔
1303
}
1304

1305
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
1306
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
1307
        amtPaid lnwire.MilliSatoshi) error {
634✔
1308

634✔
1309
        _, err := s.db.UpdateInvoiceAmountPaid(
634✔
1310
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
634✔
1311
                        ID:             int64(s.invoice.AddIndex),
634✔
1312
                        AmountPaidMsat: int64(amtPaid),
634✔
1313
                },
634✔
1314
        )
634✔
1315

634✔
1316
        return err
634✔
1317
}
634✔
1318

1319
// UpdateAmpState updates the state of the AMP sub invoice identified by the
1320
// setID.
1321
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
1322
        newState InvoiceStateAMP, _ models.CircuitKey) error {
88✔
1323

88✔
1324
        var (
88✔
1325
                settleIndex sql.NullInt64
88✔
1326
                settledAt   sql.NullTime
88✔
1327
        )
88✔
1328

88✔
1329
        switch newState.State {
88✔
1330
        case HtlcStateSettled:
22✔
1331
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
22✔
1332
                if err != nil {
22✔
1333
                        return err
×
1334
                }
×
1335

1336
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
22✔
1337

22✔
1338
                // If the invoice is settled, we'll also update the settle time.
22✔
1339
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
22✔
1340

22✔
1341
                err = s.db.OnAMPSubInvoiceSettled(
22✔
1342
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
22✔
1343
                                AddedAt:   s.updateTime.UTC(),
22✔
1344
                                InvoiceID: int64(s.invoice.AddIndex),
22✔
1345
                                SetID:     setID[:],
22✔
1346
                        },
22✔
1347
                )
22✔
1348
                if err != nil {
22✔
1349
                        return err
×
1350
                }
×
1351

1352
        case HtlcStateCanceled:
22✔
1353
                err := s.db.OnAMPSubInvoiceCanceled(
22✔
1354
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
22✔
1355
                                AddedAt:   s.updateTime.UTC(),
22✔
1356
                                InvoiceID: int64(s.invoice.AddIndex),
22✔
1357
                                SetID:     setID[:],
22✔
1358
                        },
22✔
1359
                )
22✔
1360
                if err != nil {
22✔
1361
                        return err
×
1362
                }
×
1363
        }
1364

1365
        err := s.db.UpdateAMPSubInvoiceState(
88✔
1366
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
88✔
1367
                        SetID:       setID[:],
88✔
1368
                        State:       int16(newState.State),
88✔
1369
                        SettleIndex: settleIndex,
88✔
1370
                        SettledAt:   settledAt,
88✔
1371
                },
88✔
1372
        )
88✔
1373
        if err != nil {
88✔
1374
                return err
×
1375
        }
×
1376

1377
        if settleIndex.Valid {
110✔
1378
                updatedState := s.invoice.AMPState[setID]
22✔
1379
                updatedState.SettleIndex = uint64(settleIndex.Int64)
22✔
1380
                updatedState.SettleDate = s.updateTime.UTC()
22✔
1381
                s.invoice.AMPState[setID] = updatedState
22✔
1382
        }
22✔
1383

1384
        return nil
88✔
1385
}
1386

1387
// Finalize finalizes the update before it is written to the database. Note that
1388
// we don't use this directly in the SQL implementation, so the function is just
1389
// a stub.
1390
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
700✔
1391
        return nil
700✔
1392
}
700✔
1393

1394
// UpdateInvoice attempts to update an invoice corresponding to the passed
1395
// reference. If an invoice matching the passed reference doesn't exist within
1396
// the database, then the action will fail with  ErrInvoiceNotFound error.
1397
//
1398
// The update is performed inside the same database transaction that fetches the
1399
// invoice and is therefore atomic. The fields to update are controlled by the
1400
// supplied callback.
1401
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
1402
        setID *SetID, callback InvoiceUpdateCallback) (
1403
        *Invoice, error) {
768✔
1404

768✔
1405
        var updatedInvoice *Invoice
768✔
1406

768✔
1407
        txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
768✔
1408
        txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
1,537✔
1409
                switch {
769✔
1410
                // For the default case we fetch all HTLCs.
1411
                case setID == nil:
686✔
1412
                        ref.refModifier = DefaultModifier
686✔
1413

1414
                // If the setID is the blank but NOT nil, we set the
1415
                // refModifier to HtlcSetBlankModifier to fetch no HTLC for the
1416
                // AMP invoice.
1417
                case *setID == BlankPayAddr:
×
1418
                        ref.refModifier = HtlcSetBlankModifier
×
1419

1420
                // A setID is provided, we use the refModifier to fetch only
1421
                // the HTLCs for the given setID and also make sure we add the
1422
                // setID to the ref.
1423
                default:
83✔
1424
                        var setIDBytes [32]byte
83✔
1425
                        copy(setIDBytes[:], setID[:])
83✔
1426
                        ref.setID = &setIDBytes
83✔
1427

83✔
1428
                        // We only fetch the HTLCs for the given setID.
83✔
1429
                        ref.refModifier = HtlcSetOnlyModifier
83✔
1430
                }
1431

1432
                invoice, err := fetchInvoice(ctx, db, ref)
769✔
1433
                if err != nil {
777✔
1434
                        return err
8✔
1435
                }
8✔
1436

1437
                updateTime := i.clock.Now()
761✔
1438
                updater := &sqlInvoiceUpdater{
761✔
1439
                        db:         db,
761✔
1440
                        ctx:        ctx,
761✔
1441
                        invoice:    invoice,
761✔
1442
                        updateTime: updateTime,
761✔
1443
                }
761✔
1444

761✔
1445
                payHash := ref.PayHash()
761✔
1446
                updatedInvoice, err = UpdateInvoice(
761✔
1447
                        payHash, invoice, updateTime, callback, updater,
761✔
1448
                )
761✔
1449

761✔
1450
                return err
761✔
1451
        }, func() {})
769✔
1452
        if txErr != nil {
800✔
1453
                // If the invoice is already settled, we'll return the
32✔
1454
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
32✔
1455
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
44✔
1456
                        return updatedInvoice, txErr
12✔
1457
                }
12✔
1458

1459
                return nil, txErr
20✔
1460
        }
1461

1462
        return updatedInvoice, nil
736✔
1463
}
1464

1465
// DeleteInvoice attempts to delete the passed invoices and all their related
1466
// data from the database in one transaction.
1467
func (i *SQLStore) DeleteInvoice(ctx context.Context,
1468
        invoicesToDelete []InvoiceDeleteRef) error {
12✔
1469

12✔
1470
        // All the InvoiceDeleteRef instances include the add index of the
12✔
1471
        // invoice. The rest was added to ensure that the invoices were deleted
12✔
1472
        // properly in the kv database. When we have fully migrated we can
12✔
1473
        // remove the rest of the fields.
12✔
1474
        for _, ref := range invoicesToDelete {
40✔
1475
                if ref.AddIndex == 0 {
28✔
1476
                        return fmt.Errorf("unable to delete invoice using a "+
×
1477
                                "ref without AddIndex set: %v", ref)
×
1478
                }
×
1479
        }
1480

1481
        var writeTxOpt SQLInvoiceQueriesTxOptions
12✔
1482
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
24✔
1483
                for _, ref := range invoicesToDelete {
34✔
1484
                        params := sqlc.DeleteInvoiceParams{
22✔
1485
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
22✔
1486
                        }
22✔
1487

22✔
1488
                        if ref.SettleIndex != 0 {
28✔
1489
                                params.SettleIndex = sqldb.SQLInt64(
6✔
1490
                                        ref.SettleIndex,
6✔
1491
                                )
6✔
1492
                        }
6✔
1493

1494
                        if ref.PayHash != lntypes.ZeroHash {
44✔
1495
                                params.Hash = ref.PayHash[:]
22✔
1496
                        }
22✔
1497

1498
                        result, err := db.DeleteInvoice(ctx, params)
22✔
1499
                        if err != nil {
22✔
1500
                                return fmt.Errorf("unable to delete "+
×
1501
                                        "invoice(%v): %w", ref.AddIndex, err)
×
1502
                        }
×
1503
                        rowsAffected, err := result.RowsAffected()
22✔
1504
                        if err != nil {
22✔
1505
                                return fmt.Errorf("unable to get rows "+
×
1506
                                        "affected: %w", err)
×
1507
                        }
×
1508
                        if rowsAffected == 0 {
28✔
1509
                                return fmt.Errorf("%w: %v",
6✔
1510
                                        ErrInvoiceNotFound, ref.AddIndex)
6✔
1511
                        }
6✔
1512
                }
1513

1514
                return nil
6✔
1515
        }, func() {})
12✔
1516

1517
        if err != nil {
18✔
1518
                return fmt.Errorf("unable to delete invoices: %w", err)
6✔
1519
        }
6✔
1520

1521
        return nil
6✔
1522
}
1523

1524
// DeleteCanceledInvoices removes all canceled invoices from the database.
1525
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
6✔
1526
        var writeTxOpt SQLInvoiceQueriesTxOptions
6✔
1527
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
12✔
1528
                _, err := db.DeleteCanceledInvoices(ctx)
6✔
1529
                if err != nil {
6✔
1530
                        return fmt.Errorf("unable to delete canceled "+
×
1531
                                "invoices: %w", err)
×
1532
                }
×
1533

1534
                return nil
6✔
1535
        }, func() {})
6✔
1536
        if err != nil {
6✔
1537
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1538
        }
×
1539

1540
        return nil
6✔
1541
}
1542

1543
// fetchInvoiceData fetches additional data for the given invoice. If the
1544
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
1545
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
1546
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
1547
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
1548
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
1549
        *Invoice, error) {
22,719✔
1550

22,719✔
1551
        // Unmarshal the common data.
22,719✔
1552
        hash, invoice, err := unmarshalInvoice(row)
22,719✔
1553
        if err != nil {
22,719✔
1554
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1555
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1556
        }
×
1557

1558
        // Fetch the invoice features.
1559
        features, err := getInvoiceFeatures(ctx, db, row.ID)
22,719✔
1560
        if err != nil {
22,719✔
1561
                return nil, nil, err
×
1562
        }
×
1563

1564
        invoice.Terms.Features = features
22,719✔
1565

22,719✔
1566
        // If this is an AMP invoice, we'll need fetch the AMP state along
22,719✔
1567
        // with the HTLCs (if requested).
22,719✔
1568
        if invoice.IsAMP() {
33,892✔
1569
                invoiceID := int64(invoice.AddIndex)
11,173✔
1570
                ampState, ampHtlcs, err := fetchAmpState(
11,173✔
1571
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
11,173✔
1572
                )
11,173✔
1573
                if err != nil {
11,173✔
1574
                        return nil, nil, err
×
1575
                }
×
1576

1577
                invoice.AMPState = ampState
11,173✔
1578
                invoice.Htlcs = ampHtlcs
11,173✔
1579

11,173✔
1580
                return hash, invoice, nil
11,173✔
1581
        }
1582

1583
        // Otherwise simply fetch the invoice HTLCs.
1584
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
11,546✔
1585
        if err != nil {
11,546✔
1586
                return nil, nil, err
×
1587
        }
×
1588

1589
        if len(htlcs) > 0 {
19,272✔
1590
                invoice.Htlcs = htlcs
7,726✔
1591
        }
7,726✔
1592

1593
        return hash, invoice, nil
11,546✔
1594
}
1595

1596
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1597
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
1598
        invoiceID int64) (*lnwire.FeatureVector, error) {
22,719✔
1599

22,719✔
1600
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
22,719✔
1601
        if err != nil {
22,719✔
1602
                return nil, fmt.Errorf("unable to get invoice features: %w",
×
1603
                        err)
×
1604
        }
×
1605

1606
        features := lnwire.EmptyFeatureVector()
22,719✔
1607
        for _, feature := range rows {
38,478✔
1608
                features.Set(lnwire.FeatureBit(feature.Feature))
15,759✔
1609
        }
15,759✔
1610

1611
        return features, nil
22,719✔
1612
}
1613

1614
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
1615
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
1616
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
11,546✔
1617

11,546✔
1618
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
11,546✔
1619
        if err != nil {
11,546✔
1620
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
×
1621
        }
×
1622

1623
        // We have no htlcs to unmarshal.
1624
        if len(htlcRows) == 0 {
15,366✔
1625
                return nil, nil
3,820✔
1626
        }
3,820✔
1627

1628
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
7,726✔
1629
        if err != nil {
7,726✔
1630
                return nil, fmt.Errorf("unable to get custom records for "+
×
1631
                        "invoice htlcs: %w", err)
×
1632
        }
×
1633

1634
        cr := make(map[int64]record.CustomSet, len(crRows))
7,726✔
1635
        for _, row := range crRows {
46,440✔
1636
                if _, ok := cr[row.HtlcID]; !ok {
52,808✔
1637
                        cr[row.HtlcID] = make(record.CustomSet)
14,094✔
1638
                }
14,094✔
1639

1640
                value := row.Value
38,714✔
1641
                if value == nil {
38,715✔
1642
                        value = []byte{}
1✔
1643
                }
1✔
1644
                cr[row.HtlcID][uint64(row.Key)] = value
38,714✔
1645
        }
1646

1647
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
7,726✔
1648

7,726✔
1649
        for _, row := range htlcRows {
26,969✔
1650
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
19,243✔
1651
                if err != nil {
19,243✔
1652
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1653
                                "htlc(%d): %w", row.ID, err)
×
1654
                }
×
1655

1656
                if customRecords, ok := cr[row.ID]; ok {
33,337✔
1657
                        htlc.CustomRecords = customRecords
14,094✔
1658
                } else {
19,243✔
1659
                        htlc.CustomRecords = make(record.CustomSet)
5,149✔
1660
                }
5,149✔
1661

1662
                htlcs[circuiteKey] = htlc
19,243✔
1663
        }
1664

1665
        return htlcs, nil
7,726✔
1666
}
1667

1668
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1669
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
1670
        error) {
22,719✔
1671

22,719✔
1672
        var (
22,719✔
1673
                settleIndex    int64
22,719✔
1674
                settledAt      time.Time
22,719✔
1675
                memo           []byte
22,719✔
1676
                paymentRequest []byte
22,719✔
1677
                preimage       *lntypes.Preimage
22,719✔
1678
                paymentAddr    [32]byte
22,719✔
1679
        )
22,719✔
1680

22,719✔
1681
        hash, err := lntypes.MakeHash(row.Hash)
22,719✔
1682
        if err != nil {
22,719✔
1683
                return nil, nil, err
×
1684
        }
×
1685

1686
        if row.SettleIndex.Valid {
28,669✔
1687
                settleIndex = row.SettleIndex.Int64
5,950✔
1688
        }
5,950✔
1689

1690
        if row.SettledAt.Valid {
28,669✔
1691
                settledAt = row.SettledAt.Time.Local()
5,950✔
1692
        }
5,950✔
1693

1694
        if row.Memo.Valid {
44,037✔
1695
                memo = []byte(row.Memo.String)
21,318✔
1696
        }
21,318✔
1697

1698
        // Keysend payments will have this field empty.
1699
        if row.PaymentRequest.Valid {
43,558✔
1700
                paymentRequest = []byte(row.PaymentRequest.String)
20,839✔
1701
        } else {
22,719✔
1702
                paymentRequest = []byte{}
1,880✔
1703
        }
1,880✔
1704

1705
        // We may not have the preimage if this a hodl invoice.
1706
        if row.Preimage != nil {
45,301✔
1707
                preimage = &lntypes.Preimage{}
22,582✔
1708
                copy(preimage[:], row.Preimage)
22,582✔
1709
        }
22,582✔
1710

1711
        copy(paymentAddr[:], row.PaymentAddr)
22,719✔
1712

22,719✔
1713
        var cltvDelta int32
22,719✔
1714
        if row.CltvDelta.Valid {
45,438✔
1715
                cltvDelta = row.CltvDelta.Int32
22,719✔
1716
        }
22,719✔
1717

1718
        expiry := time.Duration(row.Expiry) * time.Second
22,719✔
1719

22,719✔
1720
        invoice := &Invoice{
22,719✔
1721
                SettleIndex:    uint64(settleIndex),
22,719✔
1722
                SettleDate:     settledAt,
22,719✔
1723
                Memo:           memo,
22,719✔
1724
                PaymentRequest: paymentRequest,
22,719✔
1725
                CreationDate:   row.CreatedAt.Local(),
22,719✔
1726
                Terms: ContractTerm{
22,719✔
1727
                        FinalCltvDelta:  cltvDelta,
22,719✔
1728
                        Expiry:          expiry,
22,719✔
1729
                        PaymentPreimage: preimage,
22,719✔
1730
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
22,719✔
1731
                        PaymentAddr:     paymentAddr,
22,719✔
1732
                },
22,719✔
1733
                AddIndex:    uint64(row.ID),
22,719✔
1734
                State:       ContractState(row.State),
22,719✔
1735
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
22,719✔
1736
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
22,719✔
1737
                AMPState:    AMPInvoiceState{},
22,719✔
1738
                HodlInvoice: row.IsHodl,
22,719✔
1739
        }
22,719✔
1740

22,719✔
1741
        return &hash, invoice, nil
22,719✔
1742
}
1743

1744
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
1745
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
1746
        *InvoiceHTLC, error) {
19,243✔
1747

19,243✔
1748
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
19,243✔
1749
        if err != nil {
19,243✔
1750
                return CircuitKey{}, nil, err
×
1751
        }
×
1752

1753
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
19,243✔
1754

19,243✔
1755
        if row.HtlcID < 0 {
19,243✔
1756
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
×
1757
                        "value: %v", row.HtlcID)
×
1758
        }
×
1759

1760
        htlcID := uint64(row.HtlcID)
19,243✔
1761

19,243✔
1762
        circuitKey := CircuitKey{
19,243✔
1763
                ChanID: chanID,
19,243✔
1764
                HtlcID: htlcID,
19,243✔
1765
        }
19,243✔
1766

19,243✔
1767
        htlc := &InvoiceHTLC{
19,243✔
1768
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
19,243✔
1769
                AcceptHeight: uint32(row.AcceptHeight),
19,243✔
1770
                AcceptTime:   row.AcceptTime.Local(),
19,243✔
1771
                Expiry:       uint32(row.ExpiryHeight),
19,243✔
1772
                State:        HtlcState(row.State),
19,243✔
1773
        }
19,243✔
1774

19,243✔
1775
        if row.TotalMppMsat.Valid {
34,495✔
1776
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
15,252✔
1777
        }
15,252✔
1778

1779
        if row.ResolveTime.Valid {
32,018✔
1780
                htlc.ResolveTime = row.ResolveTime.Time.Local()
12,775✔
1781
        }
12,775✔
1782

1783
        return circuitKey, htlc, nil
19,243✔
1784
}
1785

1786
// queryWithLimit is a helper that pages through a database query using a
// limit and an offset. The query function is called with the offsets 0,
// limit, 2*limit, ... and must return the number of rows it received for
// that page, or an error.
//
// Paging stops when a page holds fewer than limit rows, and nil is returned.
// It also stops at the first error returned by query, which is passed back
// unchanged.
//
// limit must be positive. With a zero or negative limit no page can ever be
// "short" and the offset never advances, which would loop forever. Such a
// limit is therefore rejected before query is called.
func queryWithLimit(query func(int) (int, error), limit int) error {
	if limit <= 0 {
		return fmt.Errorf("invalid query limit: %d", limit)
	}

	for offset := 0; ; offset += limit {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// A short page means there is nothing left to fetch.
		if rows < limit {
			return nil
		}
	}
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc