• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 12203027929

06 Dec 2024 05:09PM UTC coverage: 59.0% (+9.2%) from 49.807%
12203027929

Pull #8831

github

bhandras
docs: update release notes for 0.19.0
Pull Request #8831: invoices: migrate KV invoices to native SQL for users of KV SQL backends

513 of 692 new or added lines in 12 files covered. (74.13%)

61 existing lines in 14 files now uncovered.

133954 of 227042 relevant lines covered (59.0%)

19694.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.15
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
// defaultQueryPaginationLimit is the default number of rows requested per
// page (the SQL LIMIT) by store queries that paginate their results.
const defaultQueryPaginationLimit = 100
28

29
// SQLInvoiceQueries is an interface that defines the set of operations that can
30
// be executed against the invoice SQL database.
31
type SQLInvoiceQueries interface { //nolint:interfacebloat
32
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
33
                error)
34

35
        InsertInvoiceFeature(ctx context.Context,
36
                arg sqlc.InsertInvoiceFeatureParams) error
37

38
        InsertInvoiceHTLC(ctx context.Context,
39
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
40

41
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
42
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
43

44
        FilterInvoices(ctx context.Context,
45
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
46

47
        GetInvoice(ctx context.Context,
48
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
49

50
        GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
51
                error)
52

53
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
54
                error)
55

56
        GetInvoiceFeatures(ctx context.Context,
57
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
58

59
        GetInvoiceHTLCCustomRecords(ctx context.Context,
60
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
61

62
        GetInvoiceHTLCs(ctx context.Context,
63
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
64

65
        UpdateInvoiceState(ctx context.Context,
66
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
67

68
        UpdateInvoiceAmountPaid(ctx context.Context,
69
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
70

71
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
72

73
        UpdateInvoiceHTLC(ctx context.Context,
74
                arg sqlc.UpdateInvoiceHTLCParams) error
75

76
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
77
                sql.Result, error)
78

79
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
80

81
        // AMP sub invoice specific methods.
82
        UpsertAMPSubInvoice(ctx context.Context,
83
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
84

85
        InsertAMPSubInvoice(ctx context.Context,
86
                arg sqlc.InsertAMPSubInvoiceParams) error
87

88
        UpdateAMPSubInvoiceState(ctx context.Context,
89
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
90

91
        InsertAMPSubInvoiceHTLC(ctx context.Context,
92
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
93

94
        FetchAMPSubInvoices(ctx context.Context,
95
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
96
                error)
97

98
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
99
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
100
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
101

102
        FetchSettledAMPSubInvoices(ctx context.Context,
103
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
104
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
105

106
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
107
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
108
                error)
109

110
        // Invoice events specific methods.
111
        OnInvoiceCreated(ctx context.Context,
112
                arg sqlc.OnInvoiceCreatedParams) error
113

114
        OnInvoiceCanceled(ctx context.Context,
115
                arg sqlc.OnInvoiceCanceledParams) error
116

117
        OnInvoiceSettled(ctx context.Context,
118
                arg sqlc.OnInvoiceSettledParams) error
119

120
        OnAMPSubInvoiceCreated(ctx context.Context,
121
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
122

123
        OnAMPSubInvoiceCanceled(ctx context.Context,
124
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
125

126
        OnAMPSubInvoiceSettled(ctx context.Context,
127
                arg sqlc.OnAMPSubInvoiceSettledParams) error
128

129
        // Migration specific methods.
130
        InsertInvoicePaymentHashAndKey(ctx context.Context,
131
                arg sqlc.InsertInvoicePaymentHashAndKeyParams) error
132

133
        SetInvoicePaymentHashAddIndex(ctx context.Context,
134
                arg sqlc.SetInvoicePaymentHashAddIndexParams) error
135

136
        GetInvoicePaymentHashByAddIndex(ctx context.Context,
137
                addIndex sql.NullInt64) ([]byte, error)
138

139
        ClearInvoiceHashIndex(ctx context.Context) error
140
}
141

142
// Compile-time check that *SQLStore implements the InvoiceDB interface.
var (
	_ InvoiceDB = (*SQLStore)(nil)
)
143

144
// SQLInvoiceQueriesTxOptions is the transaction option set understood by
// SQLInvoiceQueries. Its zero value is not read only.
type SQLInvoiceQueriesTxOptions struct {
	// readOnly is true when the transaction must be opened read only.
	readOnly bool
}

// ReadOnly reports whether the transaction should be read only.
//
// NOTE: This implements the TxOptions.
func (o *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
	isReadOnly := o.readOnly

	return isReadOnly
}

// NewSQLInvoiceQueryReadTx returns an option set that requests a read-only
// transaction.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
	var opts SQLInvoiceQueriesTxOptions
	opts.readOnly = true

	return opts
}
164

165
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
166
// of batched database operations.
167
type BatchedSQLInvoiceQueries interface {
168
        SQLInvoiceQueries
169

170
        sqldb.BatchedTx[SQLInvoiceQueries]
171
}
172

173
// SQLStore represents a storage backend.
174
type SQLStore struct {
175
        db    BatchedSQLInvoiceQueries
176
        clock clock.Clock
177
        opts  SQLStoreOptions
178
}
179

180
// SQLStoreOptions holds the options for the SQL store.
181
type SQLStoreOptions struct {
182
        paginationLimit int
183
}
184

185
// defaultSQLStoreOptions returns the default options for the SQL store.
186
func defaultSQLStoreOptions() SQLStoreOptions {
506✔
187
        return SQLStoreOptions{
506✔
188
                paginationLimit: defaultQueryPaginationLimit,
506✔
189
        }
506✔
190
}
506✔
191

192
// SQLStoreOption is a functional option that can be used to optionally modify
193
// the behavior of the SQL store.
194
type SQLStoreOption func(*SQLStoreOptions)
195

196
// WithPaginationLimit sets the pagination limit for the SQL store queries that
197
// paginate results.
198
func WithPaginationLimit(limit int) SQLStoreOption {
50✔
199
        return func(o *SQLStoreOptions) {
100✔
200
                o.paginationLimit = limit
50✔
201
        }
50✔
202
}
203

204
// NewSQLStore creates a new SQLStore instance given a open
205
// BatchedSQLInvoiceQueries storage backend.
206
func NewSQLStore(db BatchedSQLInvoiceQueries,
207
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
506✔
208

506✔
209
        opts := defaultSQLStoreOptions()
506✔
210
        for _, applyOption := range options {
556✔
211
                applyOption(&opts)
50✔
212
        }
50✔
213

214
        return &SQLStore{
506✔
215
                db:    db,
506✔
216
                clock: clock,
506✔
217
                opts:  opts,
506✔
218
        }
506✔
219
}
220

221
// makeInsertInvoiceParams maps an invoice and its payment hash onto the
// parameters of the InsertInvoice query.
//
// An error is returned if the invoice carries the all-zeroes preimage
// (UnknownPreimage), which must not be stored.
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
	sqlc.InsertInvoiceParams, error) {

	// Precompute the payment request hash so we can use it in the query.
	// Keysend invoices don't have a payment request, in which case the hash
	// is left nil. sha256.Sum256 computes the digest in a single call,
	// avoiding the unchecked Write of the streaming hash.Hash API.
	var paymentRequestHash []byte
	if len(invoice.PaymentRequest) > 0 {
		digest := sha256.Sum256(invoice.PaymentRequest)
		paymentRequestHash = digest[:]
	}

	params := sqlc.InsertInvoiceParams{
		Hash:       paymentHash[:],
		AmountMsat: int64(invoice.Terms.Value),
		CltvDelta: sqldb.SQLInt32(
			invoice.Terms.FinalCltvDelta,
		),
		Expiry: int32(invoice.Terms.Expiry.Seconds()),
		// Note: keysend invoices don't have a payment request.
		PaymentRequest: sqldb.SQLStr(string(
			invoice.PaymentRequest),
		),
		PaymentRequestHash: paymentRequestHash,
		State:              int16(invoice.State),
		AmountPaidMsat:     int64(invoice.AmtPaid),
		IsAmp:              invoice.IsAMP(),
		IsHodl:             invoice.HodlInvoice,
		IsKeysend:          invoice.IsKeysend(),
		CreatedAt:          invoice.CreationDate.UTC(),
	}

	if invoice.Memo != nil {
		// Store the memo as a nullable string in the database. Note
		// that for compatibility reasons, we store the value as a valid
		// string even if it's empty.
		params.Memo = sql.NullString{
			String: string(invoice.Memo),
			Valid:  true,
		}
	}

	// Some invoices may not have a preimage, like in the case of HODL
	// invoices.
	if invoice.Terms.PaymentPreimage != nil {
		preimage := *invoice.Terms.PaymentPreimage
		if preimage == UnknownPreimage {
			return sqlc.InsertInvoiceParams{},
				errors.New("cannot use all-zeroes preimage")
		}
		params.Preimage = preimage[:]
	}

	// Some non MPP payments may have the default (invalid) value, which is
	// not stored.
	if invoice.Terms.PaymentAddr != BlankPayAddr {
		params.PaymentAddr = invoice.Terms.PaymentAddr[:]
	}

	return params, nil
}
280

281
// AddInvoice validates newInvoice and stores it in one write transaction,
// together with its feature bits and an invoice-created event. If the
// insertion violates a unique constraint (for example, because the payment
// hash is already stored), the whole insertion is rolled back and
// ErrDuplicateInvoice is returned, in line with the policy banning duplicate
// payment hashes.
//
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
func (i *SQLStore) AddInvoice(ctx context.Context,
	newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {

	// Refuse to persist an invoice that does not pass validation.
	err := ValidateInvoice(newInvoice, paymentHash)
	if err != nil {
		return 0, err
	}

	params, err := makeInsertInvoiceParams(newInvoice, paymentHash)
	if err != nil {
		return 0, err
	}

	// The zero-valued options request a transaction that is not read only.
	var (
		txOpts   SQLInvoiceQueriesTxOptions
		addIndex int64
	)
	err = i.db.ExecTx(ctx, &txOpts, func(q SQLInvoiceQueries) error {
		id, err := q.InsertInvoice(ctx, params)
		if err != nil {
			return fmt.Errorf("unable to insert invoice: %w", err)
		}
		addIndex = id

		// TODO(positiveblue): if invoices do not have custom features
		// maybe just store the "invoice type" and populate the features
		// based on that.
		for feature := range newInvoice.Terms.Features.Features() {
			featureParams := sqlc.InsertInvoiceFeatureParams{
				InvoiceID: id,
				Feature:   int32(feature),
			}

			err := q.InsertInvoiceFeature(ctx, featureParams)
			if err != nil {
				return fmt.Errorf("unable to insert invoice "+
					"feature(%v): %w", feature, err)
			}
		}

		// Record the creation event for the freshly inserted invoice.
		return q.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
			AddedAt:   newInvoice.CreationDate.UTC(),
			InvoiceID: id,
		})
	}, func() {})

	if err == nil {
		newInvoice.AddIndex = uint64(addIndex)

		return newInvoice.AddIndex, nil
	}

	// Surface unique constraint violations as the dedicated sentinel error.
	var uniqueErr *sqldb.ErrSQLUniqueConstraintViolation
	if errors.As(sqldb.MapSQLError(err), &uniqueErr) {
		return 0, ErrDuplicateInvoice
	}

	return 0, fmt.Errorf("unable to add invoice(%v): %w", paymentHash, err)
}
353

354
// getInvoiceByRef fetches the invoice with the given reference. The reference
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
//
// ErrInvoiceNotFound is returned if the reference is empty or matches no
// invoice. An error is returned if the reference matches more than one
// invoice. Any other query error is returned (wrapped on the non hash-only
// path) instead of being reported as a missing invoice.
func getInvoiceByRef(ctx context.Context,
	db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {

	// If the reference is empty, we can't look up the invoice.
	if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
		return sqlc.Invoice{}, ErrInvoiceNotFound
	}

	// If the reference is a hash only, we can look up the invoice directly
	// by the payment hash which is faster.
	if ref.IsHashOnly() {
		invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
		if errors.Is(err, sql.ErrNoRows) {
			return sqlc.Invoice{}, ErrInvoiceNotFound
		}

		return invoice, err
	}

	// Otherwise the reference may include more fields, so we'll need to
	// assemble the query parameters based on the fields that are set.
	var params sqlc.GetInvoiceParams

	if ref.PayHash() != nil {
		params.Hash = ref.PayHash()[:]
	}

	// Newer invoices (0.11 and up) are indexed by payment address in
	// addition to payment hash, but pre 0.8 invoices do not have one at
	// all. Only allow lookups for payment address if it is not a blank
	// payment address, which is a special-cased value for legacy keysend
	// invoices.
	if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
		params.PaymentAddr = ref.PayAddr()[:]
	}

	// If the reference has a set ID we'll fetch the invoice which has the
	// corresponding AMP sub invoice.
	if ref.SetID() != nil {
		params.SetID = ref.SetID()[:]
	}

	var (
		rows []sqlc.Invoice
		err  error
	)

	// We need to split the query based on how we intend to look up the
	// invoice. If only the set ID is given then we want to have an exact
	// match on the set ID. If other fields are given, we want to match on
	// those fields and the set ID but with a less strict join condition.
	if params.Hash == nil && params.PaymentAddr == nil &&
		params.SetID != nil {

		rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
	} else {
		rows, err = db.GetInvoice(ctx, params)
	}

	// Inspect the query error before the row count. A failed query
	// typically yields no rows, so checking the count first would report
	// any database failure as ErrInvoiceNotFound and hide the real cause
	// from callers. A "no rows" error is treated as not found, matching the
	// hash-only lookup above.
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
			err)
	}

	switch len(rows) {
	case 0:
		return sqlc.Invoice{}, ErrInvoiceNotFound

	case 1:
		return rows[0], nil

	default:
		// In case the reference is ambiguous, meaning it matches more
		// than one invoice, we'll return an error.
		return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
			"%s: %s", ref.String(), spew.Sdump(rows))
	}
}
432

433
// fetchInvoice loads the invoice identified by ref, together with the AMP
// HTLC data selected by the reference's modifier:
//   - DefaultModifier: all AMP HTLCs are loaded.
//   - HtlcSetOnlyModifier: only the HTLCs of ref's set ID are loaded; the set
//     ID is mandatory.
//   - HtlcSetBlankModifier: no HTLCs are loaded.
//
// Any other modifier is rejected with an error.
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
	*Invoice, error) {

	sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
	if err != nil {
		return nil, err
	}

	var (
		setID         *[32]byte
		fetchAmpHtlcs = true
	)
	switch modifier := ref.Modifier(); modifier {
	case DefaultModifier:
		// Keep the defaults: every AMP HTLC, no set filter.

	case HtlcSetOnlyModifier:
		setID = ref.SetID()
		if setID == nil {
			return nil, errors.New("set ID is required to use " +
				"the HTLC set only modifier")
		}

	case HtlcSetBlankModifier:
		fetchAmpHtlcs = false

	default:
		return nil, fmt.Errorf("unknown invoice ref modifier: %v",
			modifier)
	}

	// Load the remaining invoice data and assemble the invoice struct.
	_, invoice, err := fetchInvoiceData(
		ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
	)
	if err != nil {
		return nil, err
	}

	return invoice, nil
}
488

489
// fetchAmpState fetches the AMP state for the invoice with the given ID.
// Optional setID can be provided to fetch the state for a specific AMP HTLC
// set. If setID is nil then we'll fetch the state for all AMP sub invoices.
//
// If fetchHtlcs is false, only the per-set state is returned and the HTLC set
// is nil. Otherwise the HTLCs (with their custom records) are fetched as well.
// Each set's InvoiceKeys and AmtPaid are then derived from the HTLCs that
// belong to it, and canceled HTLCs do not count towards AmtPaid.
//
//nolint:funlen
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
	setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
	HTLCSet, error) {

	var paramSetID []byte
	if setID != nil {
		paramSetID = setID[:]
	}

	// First fetch all the AMP sub invoices for this invoice or the one
	// matching the provided set ID.
	ampInvoiceRows, err := db.FetchAMPSubInvoices(
		ctx, sqlc.FetchAMPSubInvoicesParams{
			InvoiceID: invoiceID,
			SetID:     paramSetID,
		},
	)
	if err != nil {
		return nil, nil, err
	}

	ampState := make(map[SetID]InvoiceStateAMP, len(ampInvoiceRows))
	for _, row := range ampInvoiceRows {
		var rowSetID [32]byte

		if len(row.SetID) != 32 {
			return nil, nil, fmt.Errorf("invalid set id length: %d",
				len(row.SetID))
		}

		var settleDate time.Time
		if row.SettledAt.Valid {
			settleDate = row.SettledAt.Time.Local()
		}

		copy(rowSetID[:], row.SetID)
		ampState[rowSetID] = InvoiceStateAMP{
			State:       HtlcState(row.State),
			SettleIndex: uint64(row.SettleIndex.Int64),
			SettleDate:  settleDate,
			InvoiceKeys: make(map[models.CircuitKey]struct{}),
		}
	}

	if !fetchHtlcs {
		return ampState, nil, nil
	}

	customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to get custom records for "+
			"invoice HTLCs: %w", err)
	}

	// Group the custom records by the HTLC they belong to.
	customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
	for _, row := range customRecordRows {
		records, ok := customRecords[row.HtlcID]
		if !ok {
			records = make(record.CustomSet)
			customRecords[row.HtlcID] = records
		}

		value := row.Value
		if value == nil {
			value = []byte{}
		}

		records[uint64(row.Key)] = value
	}

	// Now fetch all the AMP HTLCs for this invoice or the one matching the
	// provided set ID.
	ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
		ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
			InvoiceID: invoiceID,
			SetID:     paramSetID,
		},
	)
	if err != nil {
		return nil, nil, err
	}

	ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC, len(ampHtlcRows))
	for _, row := range ampHtlcRows {
		uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
		if err != nil {
			return nil, nil, err
		}

		chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)

		if row.HtlcID < 0 {
			return nil, nil, fmt.Errorf("invalid HTLC ID "+
				"value: %v", row.HtlcID)
		}

		htlcID := uint64(row.HtlcID)

		circuitKey := CircuitKey{
			ChanID: chanID,
			HtlcID: htlcID,
		}

		htlc := &InvoiceHTLC{
			Amt:          lnwire.MilliSatoshi(row.AmountMsat),
			AcceptHeight: uint32(row.AcceptHeight),
			AcceptTime:   row.AcceptTime.Local(),
			Expiry:       uint32(row.ExpiryHeight),
			State:        HtlcState(row.State),
		}

		if row.TotalMppMsat.Valid {
			htlc.MppTotalAmt = lnwire.MilliSatoshi(
				row.TotalMppMsat.Int64,
			)
		}

		if row.ResolveTime.Valid {
			htlc.ResolveTime = row.ResolveTime.Time.Local()
		}

		var (
			rootShare [32]byte
			rowSetID  [32]byte
		)

		if len(row.RootShare) != 32 {
			return nil, nil, fmt.Errorf("invalid root share "+
				"length: %d", len(row.RootShare))
		}
		copy(rootShare[:], row.RootShare)

		if len(row.SetID) != 32 {
			return nil, nil, fmt.Errorf("invalid set ID length: %d",
				len(row.SetID))
		}
		copy(rowSetID[:], row.SetID)

		if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
			return nil, nil, fmt.Errorf("invalid child index "+
				"value: %v", row.ChildIndex)
		}

		ampRecord := record.NewAMP(
			rootShare, rowSetID, uint32(row.ChildIndex),
		)

		htlc.AMP = &InvoiceHtlcAMPData{
			Record: *ampRecord,
		}

		if len(row.Hash) != 32 {
			return nil, nil, fmt.Errorf("invalid hash length: %d",
				len(row.Hash))
		}
		copy(htlc.AMP.Hash[:], row.Hash)

		if row.Preimage != nil {
			preimage, err := lntypes.MakePreimage(row.Preimage)
			if err != nil {
				return nil, nil, err
			}

			htlc.AMP.Preimage = &preimage
		}

		// Attach the HTLC's custom records, or an empty set if it has
		// none.
		if records, ok := customRecords[row.ID]; ok {
			htlc.CustomRecords = records
		} else {
			htlc.CustomRecords = make(record.CustomSet)
		}

		ampHtlcs[circuitKey] = htlc
	}

	// Attribute every HTLC to its AMP set in a single pass. This replaces a
	// nested loop over sets × HTLCs. Each set in ampState starts with an
	// empty InvoiceKeys map and a zero AmtPaid, so accumulating in place
	// gives the same result as recomputing each set from scratch. HTLCs
	// whose set is not present in ampState are ignored, as before.
	for key, htlc := range ampHtlcs {
		htlcSetID := htlc.AMP.Record.SetID()

		setState, ok := ampState[htlcSetID]
		if !ok {
			continue
		}

		// InvoiceKeys is a map, so this write is visible through the
		// stored state; AmtPaid is a value and is written back below.
		setState.InvoiceKeys[key] = struct{}{}
		if htlc.State != HtlcStateCanceled {
			setState.AmtPaid += htlc.Amt
		}

		ampState[htlcSetID] = setState
	}

	return ampState, ampHtlcs, nil
}
697

698
// LookupInvoice returns the complete invoice matching ref, read inside a
// read-only transaction. The reference may be a payment hash, a payment
// address, or a set ID for an AMP sub invoice. ErrInvoiceNotFound is returned
// when no invoice matches the reference.
func (i *SQLStore) LookupInvoice(ctx context.Context,
	ref InvoiceRef) (Invoice, error) {

	var found *Invoice

	readTxOpt := NewSQLInvoiceQueryReadTx()
	txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
		invoice, err := fetchInvoice(ctx, db, ref)
		found = invoice

		return err
	}, func() {})
	if txErr != nil {
		return Invoice{}, txErr
	}

	return *found, nil
}
723

724
// FetchPendingInvoices returns all the invoices that are currently in a
725
// "pending" state. An invoice is pending if it has been created but not yet
726
// settled or canceled.
727
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
728
        map[lntypes.Hash]Invoice, error) {
256✔
729

256✔
730
        var invoices map[lntypes.Hash]Invoice
256✔
731

256✔
732
        readTxOpt := NewSQLInvoiceQueryReadTx()
256✔
733
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
512✔
734
                return queryWithLimit(func(offset int) (int, error) {
518✔
735
                        params := sqlc.FilterInvoicesParams{
262✔
736
                                PendingOnly: true,
262✔
737
                                NumOffset:   int32(offset),
262✔
738
                                NumLimit:    int32(i.opts.paginationLimit),
262✔
739
                                Reverse:     false,
262✔
740
                        }
262✔
741

262✔
742
                        rows, err := db.FilterInvoices(ctx, params)
262✔
743
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
262✔
744
                                return 0, fmt.Errorf("unable to get invoices "+
×
745
                                        "from db: %w", err)
×
746
                        }
×
747

748
                        // Load all the information for the invoices.
749
                        for _, row := range rows {
302✔
750
                                hash, invoice, err := fetchInvoiceData(
40✔
751
                                        ctx, db, row, nil, true,
40✔
752
                                )
40✔
753
                                if err != nil {
40✔
754
                                        return 0, err
×
755
                                }
×
756

757
                                invoices[*hash] = *invoice
40✔
758
                        }
759

760
                        return len(rows), nil
262✔
761
                }, i.opts.paginationLimit)
762
        }, func() {
256✔
763
                invoices = make(map[lntypes.Hash]Invoice)
256✔
764
        })
256✔
765
        if err != nil {
256✔
766
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
×
767
                        err)
×
768
        }
×
769

770
        return invoices, nil
256✔
771
}
772

773
// InvoicesSettledSince can be used by callers to catch up any settled invoices
774
// they missed within the settled invoice time series. We'll return all known
775
// settled invoice that have a settle index higher than the passed idx.
776
//
777
// NOTE: The index starts from 1. As a result we enforce that specifying a value
778
// below the starting index value is a noop.
779
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
780
        []Invoice, error) {
42✔
781

42✔
782
        var invoices []Invoice
42✔
783

42✔
784
        if idx == 0 {
78✔
785
                return invoices, nil
36✔
786
        }
36✔
787

788
        readTxOpt := NewSQLInvoiceQueryReadTx()
6✔
789
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
12✔
790
                err := queryWithLimit(func(offset int) (int, error) {
18✔
791
                        params := sqlc.FilterInvoicesParams{
12✔
792
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
12✔
793
                                NumOffset:      int32(offset),
12✔
794
                                NumLimit:       int32(i.opts.paginationLimit),
12✔
795
                                Reverse:        false,
12✔
796
                        }
12✔
797

12✔
798
                        rows, err := db.FilterInvoices(ctx, params)
12✔
799
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
12✔
800
                                return 0, fmt.Errorf("unable to get invoices "+
×
801
                                        "from db: %w", err)
×
802
                        }
×
803

804
                        // Load all the information for the invoices.
805
                        for _, row := range rows {
30✔
806
                                _, invoice, err := fetchInvoiceData(
18✔
807
                                        ctx, db, row, nil, true,
18✔
808
                                )
18✔
809
                                if err != nil {
18✔
810
                                        return 0, fmt.Errorf("unable to fetch "+
×
811
                                                "invoice(id=%d) from db: %w",
×
812
                                                row.ID, err)
×
813
                                }
×
814

815
                                invoices = append(invoices, *invoice)
18✔
816
                        }
817

818
                        return len(rows), nil
12✔
819
                }, i.opts.paginationLimit)
820
                if err != nil {
6✔
821
                        return err
×
822
                }
×
823

824
                // Now fetch all the AMP sub invoices that were settled since
825
                // the provided index.
826
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
6✔
827
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
6✔
828
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
6✔
829
                        },
6✔
830
                )
6✔
831
                if err != nil {
6✔
832
                        return err
×
833
                }
×
834

835
                for _, ampInvoice := range ampInvoices {
10✔
836
                        // Convert the row to a sqlc.Invoice so we can use the
4✔
837
                        // existing fetchInvoiceData function.
4✔
838
                        sqlInvoice := sqlc.Invoice{
4✔
839
                                ID:             ampInvoice.ID,
4✔
840
                                Hash:           ampInvoice.Hash,
4✔
841
                                Preimage:       ampInvoice.Preimage,
4✔
842
                                SettleIndex:    ampInvoice.AmpSettleIndex,
4✔
843
                                SettledAt:      ampInvoice.AmpSettledAt,
4✔
844
                                Memo:           ampInvoice.Memo,
4✔
845
                                AmountMsat:     ampInvoice.AmountMsat,
4✔
846
                                CltvDelta:      ampInvoice.CltvDelta,
4✔
847
                                Expiry:         ampInvoice.Expiry,
4✔
848
                                PaymentAddr:    ampInvoice.PaymentAddr,
4✔
849
                                PaymentRequest: ampInvoice.PaymentRequest,
4✔
850
                                State:          ampInvoice.State,
4✔
851
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
4✔
852
                                IsAmp:          ampInvoice.IsAmp,
4✔
853
                                IsHodl:         ampInvoice.IsHodl,
4✔
854
                                IsKeysend:      ampInvoice.IsKeysend,
4✔
855
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
4✔
856
                        }
4✔
857

4✔
858
                        // Fetch the state and HTLCs for this AMP sub invoice.
4✔
859
                        _, invoice, err := fetchInvoiceData(
4✔
860
                                ctx, db, sqlInvoice,
4✔
861
                                (*[32]byte)(ampInvoice.SetID), true,
4✔
862
                        )
4✔
863
                        if err != nil {
4✔
864
                                return fmt.Errorf("unable to fetch "+
×
865
                                        "AMP invoice(id=%d) from db: %w",
×
866
                                        ampInvoice.ID, err)
×
867
                        }
×
868

869
                        invoices = append(invoices, *invoice)
4✔
870
                }
871

872
                return nil
6✔
873
        }, func() {
6✔
874
                invoices = nil
6✔
875
        })
6✔
876
        if err != nil {
6✔
877
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
878
                        "index (excluding) %d: %w", idx, err)
×
879
        }
×
880

881
        return invoices, nil
6✔
882
}
883

884
// InvoicesAddedSince can be used by callers to seek into the event time series
885
// of all the invoices added in the database. This method will return all
886
// invoices with an add index greater than the specified idx.
887
//
888
// NOTE: The index starts from 1. As a result we enforce that specifying a value
889
// below the starting index value is a noop.
890
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
891
        []Invoice, error) {
40✔
892

40✔
893
        var result []Invoice
40✔
894

40✔
895
        if idx == 0 {
74✔
896
                return result, nil
34✔
897
        }
34✔
898

899
        readTxOpt := NewSQLInvoiceQueryReadTx()
6✔
900
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
12✔
901
                return queryWithLimit(func(offset int) (int, error) {
30✔
902
                        params := sqlc.FilterInvoicesParams{
24✔
903
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
24✔
904
                                NumOffset:   int32(offset),
24✔
905
                                NumLimit:    int32(i.opts.paginationLimit),
24✔
906
                                Reverse:     false,
24✔
907
                        }
24✔
908

24✔
909
                        rows, err := db.FilterInvoices(ctx, params)
24✔
910
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
24✔
911
                                return 0, fmt.Errorf("unable to get invoices "+
×
912
                                        "from db: %w", err)
×
913
                        }
×
914

915
                        // Load all the information for the invoices.
916
                        for _, row := range rows {
82✔
917
                                _, invoice, err := fetchInvoiceData(
58✔
918
                                        ctx, db, row, nil, true,
58✔
919
                                )
58✔
920
                                if err != nil {
58✔
921
                                        return 0, err
×
922
                                }
×
923

924
                                result = append(result, *invoice)
58✔
925
                        }
926

927
                        return len(rows), nil
24✔
928
                }, i.opts.paginationLimit)
929
        }, func() {
6✔
930
                result = nil
6✔
931
        })
6✔
932

933
        if err != nil {
6✔
934
                return nil, fmt.Errorf("unable to get invoices added since "+
×
935
                        "index %d: %w", idx, err)
×
936
        }
×
937

938
        return result, nil
6✔
939
}
940

941
// QueryInvoices allows a caller to query the invoice database for invoices
942
// within the specified add index range.
943
func (i *SQLStore) QueryInvoices(ctx context.Context,
944
        q InvoiceQuery) (InvoiceSlice, error) {
90✔
945

90✔
946
        var invoices []Invoice
90✔
947

90✔
948
        if q.NumMaxInvoices == 0 {
90✔
949
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
950
                        "be non-zero")
×
951
        }
×
952

953
        readTxOpt := NewSQLInvoiceQueryReadTx()
90✔
954
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
180✔
955
                return queryWithLimit(func(offset int) (int, error) {
486✔
956
                        params := sqlc.FilterInvoicesParams{
396✔
957
                                NumOffset:   int32(offset),
396✔
958
                                NumLimit:    int32(i.opts.paginationLimit),
396✔
959
                                PendingOnly: q.PendingOnly,
396✔
960
                                Reverse:     q.Reversed,
396✔
961
                        }
396✔
962

396✔
963
                        if q.Reversed {
510✔
964
                                // If the index offset was not set, we want to
114✔
965
                                // fetch from the lastest invoice.
114✔
966
                                if q.IndexOffset == 0 {
162✔
967
                                        params.AddIndexLet = sqldb.SQLInt64(
48✔
968
                                                int64(math.MaxInt64),
48✔
969
                                        )
48✔
970
                                } else {
114✔
971
                                        // The invoice with index offset id must
66✔
972
                                        // not be included in the results.
66✔
973
                                        params.AddIndexLet = sqldb.SQLInt64(
66✔
974
                                                q.IndexOffset - 1,
66✔
975
                                        )
66✔
976
                                }
66✔
977
                        } else {
282✔
978
                                // The invoice with index offset id must not be
282✔
979
                                // included in the results.
282✔
980
                                params.AddIndexGet = sqldb.SQLInt64(
282✔
981
                                        q.IndexOffset + 1,
282✔
982
                                )
282✔
983
                        }
282✔
984

985
                        if q.CreationDateStart != 0 {
458✔
986
                                params.CreatedAfter = sqldb.SQLTime(
62✔
987
                                        time.Unix(q.CreationDateStart, 0).UTC(),
62✔
988
                                )
62✔
989
                        }
62✔
990

991
                        if q.CreationDateEnd != 0 {
458✔
992
                                // We need to add 1 to the end date as we're
62✔
993
                                // checking less than the end date in SQL.
62✔
994
                                params.CreatedBefore = sqldb.SQLTime(
62✔
995
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
62✔
996
                                )
62✔
997
                        }
62✔
998

999
                        rows, err := db.FilterInvoices(ctx, params)
396✔
1000
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
396✔
1001
                                return 0, fmt.Errorf("unable to get invoices "+
×
1002
                                        "from db: %w", err)
×
1003
                        }
×
1004

1005
                        // Load all the information for the invoices.
1006
                        for _, row := range rows {
1,418✔
1007
                                _, invoice, err := fetchInvoiceData(
1,022✔
1008
                                        ctx, db, row, nil, true,
1,022✔
1009
                                )
1,022✔
1010
                                if err != nil {
1,022✔
1011
                                        return 0, err
×
1012
                                }
×
1013

1014
                                invoices = append(invoices, *invoice)
1,022✔
1015

1,022✔
1016
                                if len(invoices) == int(q.NumMaxInvoices) {
1,050✔
1017
                                        return 0, nil
28✔
1018
                                }
28✔
1019
                        }
1020

1021
                        return len(rows), nil
368✔
1022
                }, i.opts.paginationLimit)
1023
        }, func() {
90✔
1024
                invoices = nil
90✔
1025
        })
90✔
1026
        if err != nil {
90✔
1027
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
1028
                        "invoices: %w", err)
×
1029
        }
×
1030

1031
        if len(invoices) == 0 {
102✔
1032
                return InvoiceSlice{
12✔
1033
                        InvoiceQuery: q,
12✔
1034
                }, nil
12✔
1035
        }
12✔
1036

1037
        // If we iterated through the add index in reverse order, then
1038
        // we'll need to reverse the slice of invoices to return them in
1039
        // forward order.
1040
        if q.Reversed {
102✔
1041
                numInvoices := len(invoices)
24✔
1042
                for i := 0; i < numInvoices/2; i++ {
162✔
1043
                        reverse := numInvoices - i - 1
138✔
1044
                        invoices[i], invoices[reverse] =
138✔
1045
                                invoices[reverse], invoices[i]
138✔
1046
                }
138✔
1047
        }
1048

1049
        res := InvoiceSlice{
78✔
1050
                InvoiceQuery:     q,
78✔
1051
                Invoices:         invoices,
78✔
1052
                FirstIndexOffset: invoices[0].AddIndex,
78✔
1053
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
78✔
1054
        }
78✔
1055

78✔
1056
        return res, nil
78✔
1057
}
1058

1059
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
1060
// a SQL database as the backend.
1061
type sqlInvoiceUpdater struct {
1062
        db         SQLInvoiceQueries
1063
        ctx        context.Context //nolint:containedctx
1064
        invoice    *Invoice
1065
        updateTime time.Time
1066
}
1067

1068
// AddHtlc adds a new htlc to the invoice.
1069
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
1070
        newHtlc *InvoiceHTLC) error {
586✔
1071

586✔
1072
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
586✔
1073
                s.ctx, sqlc.InsertInvoiceHTLCParams{
586✔
1074
                        HtlcID: int64(circuitKey.HtlcID),
586✔
1075
                        ChanID: strconv.FormatUint(
586✔
1076
                                circuitKey.ChanID.ToUint64(), 10,
586✔
1077
                        ),
586✔
1078
                        AmountMsat: int64(newHtlc.Amt),
586✔
1079
                        TotalMppMsat: sql.NullInt64{
586✔
1080
                                Int64: int64(newHtlc.MppTotalAmt),
586✔
1081
                                Valid: newHtlc.MppTotalAmt != 0,
586✔
1082
                        },
586✔
1083
                        AcceptHeight: int32(newHtlc.AcceptHeight),
586✔
1084
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
586✔
1085
                        ExpiryHeight: int32(newHtlc.Expiry),
586✔
1086
                        State:        int16(newHtlc.State),
586✔
1087
                        InvoiceID:    int64(s.invoice.AddIndex),
586✔
1088
                },
586✔
1089
        )
586✔
1090
        if err != nil {
586✔
1091
                return err
×
1092
        }
×
1093

1094
        for key, value := range newHtlc.CustomRecords {
598✔
1095
                err = s.db.InsertInvoiceHTLCCustomRecord(
12✔
1096
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
12✔
1097
                                // TODO(bhandras): schema might be wrong here
12✔
1098
                                // as the custom record key is an uint64.
12✔
1099
                                Key:    int64(key),
12✔
1100
                                Value:  value,
12✔
1101
                                HtlcID: htlcPrimaryKeyID,
12✔
1102
                        },
12✔
1103
                )
12✔
1104
                if err != nil {
12✔
1105
                        return err
×
1106
                }
×
1107
        }
1108

1109
        if newHtlc.AMP != nil {
624✔
1110
                setID := newHtlc.AMP.Record.SetID()
38✔
1111

38✔
1112
                upsertResult, err := s.db.UpsertAMPSubInvoice(
38✔
1113
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
38✔
1114
                                SetID:     setID[:],
38✔
1115
                                CreatedAt: s.updateTime.UTC(),
38✔
1116
                                InvoiceID: int64(s.invoice.AddIndex),
38✔
1117
                        },
38✔
1118
                )
38✔
1119
                if err != nil {
40✔
1120
                        mappedSQLErr := sqldb.MapSQLError(err)
2✔
1121
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
2✔
1122
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
4✔
1123
                                return ErrDuplicateSetID{
2✔
1124
                                        SetID: setID,
2✔
1125
                                }
2✔
1126
                        }
2✔
1127

1128
                        return err
×
1129
                }
1130

1131
                // If we're just inserting the AMP invoice, we'll get a non
1132
                // zero rows affected count.
1133
                rowsAffected, err := upsertResult.RowsAffected()
36✔
1134
                if err != nil {
36✔
1135
                        return err
×
1136
                }
×
1137
                if rowsAffected != 0 {
62✔
1138
                        // If we're inserting a new AMP invoice, we'll also
26✔
1139
                        // insert a new invoice event.
26✔
1140
                        err = s.db.OnAMPSubInvoiceCreated(
26✔
1141
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
26✔
1142
                                        AddedAt:   s.updateTime.UTC(),
26✔
1143
                                        InvoiceID: int64(s.invoice.AddIndex),
26✔
1144
                                        SetID:     setID[:],
26✔
1145
                                },
26✔
1146
                        )
26✔
1147
                        if err != nil {
26✔
1148
                                return err
×
1149
                        }
×
1150
                }
1151

1152
                rootShare := newHtlc.AMP.Record.RootShare()
36✔
1153

36✔
1154
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
36✔
1155
                        InvoiceID: int64(s.invoice.AddIndex),
36✔
1156
                        SetID:     setID[:],
36✔
1157
                        HtlcID:    htlcPrimaryKeyID,
36✔
1158
                        RootShare: rootShare[:],
36✔
1159
                        ChildIndex: int64(
36✔
1160
                                newHtlc.AMP.Record.ChildIndex(),
36✔
1161
                        ),
36✔
1162
                        Hash: newHtlc.AMP.Hash[:],
36✔
1163
                }
36✔
1164

36✔
1165
                if newHtlc.AMP.Preimage != nil {
56✔
1166
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
20✔
1167
                }
20✔
1168

1169
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
36✔
1170
                if err != nil {
36✔
1171
                        return err
×
1172
                }
×
1173
        }
1174

1175
        return nil
584✔
1176
}
1177

1178
// ResolveHtlc marks an htlc as resolved with the given state.
1179
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
1180
        state HtlcState, resolveTime time.Time) error {
574✔
1181

574✔
1182
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
574✔
1183
                HtlcID: int64(circuitKey.HtlcID),
574✔
1184
                ChanID: strconv.FormatUint(
574✔
1185
                        circuitKey.ChanID.ToUint64(), 10,
574✔
1186
                ),
574✔
1187
                InvoiceID:   int64(s.invoice.AddIndex),
574✔
1188
                State:       int16(state),
574✔
1189
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
574✔
1190
        })
574✔
1191
}
574✔
1192

1193
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1194
// identified by the setID.
1195
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
1196
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
12✔
1197

12✔
1198
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
12✔
1199
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
12✔
1200
                        InvoiceID: int64(s.invoice.AddIndex),
12✔
1201
                        SetID:     setID[:],
12✔
1202
                        HtlcID:    int64(circuitKey.HtlcID),
12✔
1203
                        Preimage:  preimage[:],
12✔
1204
                        ChanID: strconv.FormatUint(
12✔
1205
                                circuitKey.ChanID.ToUint64(), 10,
12✔
1206
                        ),
12✔
1207
                },
12✔
1208
        )
12✔
1209
        if err != nil {
12✔
1210
                return err
×
1211
        }
×
1212

1213
        rowsAffected, err := result.RowsAffected()
12✔
1214
        if err != nil {
12✔
1215
                return err
×
1216
        }
×
1217
        if rowsAffected == 0 {
12✔
1218
                return ErrInvoiceNotFound
×
1219
        }
×
1220

1221
        return nil
12✔
1222
}
1223

1224
// UpdateInvoiceState updates the invoice state to the new state.
1225
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
1226
        newState ContractState, preimage *lntypes.Preimage) error {
392✔
1227

392✔
1228
        var (
392✔
1229
                settleIndex sql.NullInt64
392✔
1230
                settledAt   sql.NullTime
392✔
1231
        )
392✔
1232

392✔
1233
        switch newState {
392✔
1234
        case ContractSettled:
316✔
1235
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
316✔
1236
                if err != nil {
316✔
1237
                        return err
×
1238
                }
×
1239

1240
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
316✔
1241

316✔
1242
                // If the invoice is settled, we'll also update the settle time.
316✔
1243
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
316✔
1244

316✔
1245
                err = s.db.OnInvoiceSettled(
316✔
1246
                        s.ctx, sqlc.OnInvoiceSettledParams{
316✔
1247
                                AddedAt:   s.updateTime.UTC(),
316✔
1248
                                InvoiceID: int64(s.invoice.AddIndex),
316✔
1249
                        },
316✔
1250
                )
316✔
1251
                if err != nil {
316✔
1252
                        return err
×
1253
                }
×
1254

1255
        case ContractCanceled:
58✔
1256
                err := s.db.OnInvoiceCanceled(
58✔
1257
                        s.ctx, sqlc.OnInvoiceCanceledParams{
58✔
1258
                                AddedAt:   s.updateTime.UTC(),
58✔
1259
                                InvoiceID: int64(s.invoice.AddIndex),
58✔
1260
                        },
58✔
1261
                )
58✔
1262
                if err != nil {
58✔
1263
                        return err
×
1264
                }
×
1265
        }
1266

1267
        params := sqlc.UpdateInvoiceStateParams{
392✔
1268
                ID:          int64(s.invoice.AddIndex),
392✔
1269
                State:       int16(newState),
392✔
1270
                SettleIndex: settleIndex,
392✔
1271
                SettledAt:   settledAt,
392✔
1272
        }
392✔
1273

392✔
1274
        if preimage != nil {
400✔
1275
                params.Preimage = preimage[:]
8✔
1276
        }
8✔
1277

1278
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
392✔
1279
        if err != nil {
392✔
1280
                return err
×
1281
        }
×
1282
        rowsAffected, err := result.RowsAffected()
392✔
1283
        if err != nil {
392✔
1284
                return err
×
1285
        }
×
1286

1287
        if rowsAffected == 0 {
392✔
1288
                return ErrInvoiceNotFound
×
1289
        }
×
1290

1291
        if settleIndex.Valid {
708✔
1292
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
316✔
1293
                s.invoice.SettleDate = s.updateTime
316✔
1294
        }
316✔
1295

1296
        return nil
392✔
1297
}
1298

1299
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
1300
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
1301
        amtPaid lnwire.MilliSatoshi) error {
608✔
1302

608✔
1303
        _, err := s.db.UpdateInvoiceAmountPaid(
608✔
1304
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
608✔
1305
                        ID:             int64(s.invoice.AddIndex),
608✔
1306
                        AmountPaidMsat: int64(amtPaid),
608✔
1307
                },
608✔
1308
        )
608✔
1309

608✔
1310
        return err
608✔
1311
}
608✔
1312

1313
// UpdateAmpState updates the state of the AMP sub invoice identified by the
1314
// setID.
1315
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
1316
        newState InvoiceStateAMP, _ models.CircuitKey) error {
64✔
1317

64✔
1318
        var (
64✔
1319
                settleIndex sql.NullInt64
64✔
1320
                settledAt   sql.NullTime
64✔
1321
        )
64✔
1322

64✔
1323
        switch newState.State {
64✔
1324
        case HtlcStateSettled:
22✔
1325
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
22✔
1326
                if err != nil {
22✔
1327
                        return err
×
1328
                }
×
1329

1330
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
22✔
1331

22✔
1332
                // If the invoice is settled, we'll also update the settle time.
22✔
1333
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
22✔
1334

22✔
1335
                err = s.db.OnAMPSubInvoiceSettled(
22✔
1336
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
22✔
1337
                                AddedAt:   s.updateTime.UTC(),
22✔
1338
                                InvoiceID: int64(s.invoice.AddIndex),
22✔
1339
                                SetID:     setID[:],
22✔
1340
                        },
22✔
1341
                )
22✔
1342
                if err != nil {
22✔
1343
                        return err
×
1344
                }
×
1345

1346
        case HtlcStateCanceled:
6✔
1347
                err := s.db.OnAMPSubInvoiceCanceled(
6✔
1348
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
6✔
1349
                                AddedAt:   s.updateTime.UTC(),
6✔
1350
                                InvoiceID: int64(s.invoice.AddIndex),
6✔
1351
                                SetID:     setID[:],
6✔
1352
                        },
6✔
1353
                )
6✔
1354
                if err != nil {
6✔
1355
                        return err
×
1356
                }
×
1357
        }
1358

1359
        err := s.db.UpdateAMPSubInvoiceState(
64✔
1360
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
64✔
1361
                        SetID:       setID[:],
64✔
1362
                        State:       int16(newState.State),
64✔
1363
                        SettleIndex: settleIndex,
64✔
1364
                        SettledAt:   settledAt,
64✔
1365
                },
64✔
1366
        )
64✔
1367
        if err != nil {
64✔
1368
                return err
×
1369
        }
×
1370

1371
        if settleIndex.Valid {
86✔
1372
                updatedState := s.invoice.AMPState[setID]
22✔
1373
                updatedState.SettleIndex = uint64(settleIndex.Int64)
22✔
1374
                updatedState.SettleDate = s.updateTime.UTC()
22✔
1375
                s.invoice.AMPState[setID] = updatedState
22✔
1376
        }
22✔
1377

1378
        return nil
64✔
1379
}
1380

1381
// Finalize finalizes the update before it is written to the database. Note that
1382
// we don't use this directly in the SQL implementation, so the function is just
1383
// a stub.
1384
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
672✔
1385
        return nil
672✔
1386
}
672✔
1387

1388
// UpdateInvoice attempts to update an invoice corresponding to the passed
1389
// reference. If an invoice matching the passed reference doesn't exist within
1390
// the database, then the action will fail with  ErrInvoiceNotFound error.
1391
//
1392
// The update is performed inside the same database transaction that fetches the
1393
// invoice and is therefore atomic. The fields to update are controlled by the
1394
// supplied callback.
1395
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
1396
        setID *SetID, callback InvoiceUpdateCallback) (
1397
        *Invoice, error) {
736✔
1398

736✔
1399
        var updatedInvoice *Invoice
736✔
1400

736✔
1401
        txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
736✔
1402
        txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
1,472✔
1403
                if setID != nil {
798✔
1404
                        // Make sure to use the set ID if this is an AMP update.
62✔
1405
                        var setIDBytes [32]byte
62✔
1406
                        copy(setIDBytes[:], setID[:])
62✔
1407
                        ref.setID = &setIDBytes
62✔
1408

62✔
1409
                        // If we're updating an AMP invoice, we'll also only
62✔
1410
                        // need to fetch the HTLCs for the given set ID.
62✔
1411
                        ref.refModifier = HtlcSetOnlyModifier
62✔
1412
                }
62✔
1413

1414
                invoice, err := fetchInvoice(ctx, db, ref)
736✔
1415
                if err != nil {
744✔
1416
                        return err
8✔
1417
                }
8✔
1418

1419
                updateTime := i.clock.Now()
728✔
1420
                updater := &sqlInvoiceUpdater{
728✔
1421
                        db:         db,
728✔
1422
                        ctx:        ctx,
728✔
1423
                        invoice:    invoice,
728✔
1424
                        updateTime: updateTime,
728✔
1425
                }
728✔
1426

728✔
1427
                payHash := ref.PayHash()
728✔
1428
                updatedInvoice, err = UpdateInvoice(
728✔
1429
                        payHash, invoice, updateTime, callback, updater,
728✔
1430
                )
728✔
1431

728✔
1432
                return err
728✔
1433
        }, func() {})
736✔
1434
        if txErr != nil {
768✔
1435
                // If the invoice is already settled, we'll return the
32✔
1436
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
32✔
1437
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
44✔
1438
                        return updatedInvoice, txErr
12✔
1439
                }
12✔
1440

1441
                return nil, txErr
20✔
1442
        }
1443

1444
        return updatedInvoice, nil
704✔
1445
}
1446

1447
// DeleteInvoice attempts to delete the passed invoices and all their related
1448
// data from the database in one transaction.
1449
func (i *SQLStore) DeleteInvoice(ctx context.Context,
1450
        invoicesToDelete []InvoiceDeleteRef) error {
12✔
1451

12✔
1452
        // All the InvoiceDeleteRef instances include the add index of the
12✔
1453
        // invoice. The rest was added to ensure that the invoices were deleted
12✔
1454
        // properly in the kv database. When we have fully migrated we can
12✔
1455
        // remove the rest of the fields.
12✔
1456
        for _, ref := range invoicesToDelete {
40✔
1457
                if ref.AddIndex == 0 {
28✔
1458
                        return fmt.Errorf("unable to delete invoice using a "+
×
1459
                                "ref without AddIndex set: %v", ref)
×
1460
                }
×
1461
        }
1462

1463
        var writeTxOpt SQLInvoiceQueriesTxOptions
12✔
1464
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
24✔
1465
                for _, ref := range invoicesToDelete {
34✔
1466
                        params := sqlc.DeleteInvoiceParams{
22✔
1467
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
22✔
1468
                        }
22✔
1469

22✔
1470
                        if ref.SettleIndex != 0 {
28✔
1471
                                params.SettleIndex = sqldb.SQLInt64(
6✔
1472
                                        ref.SettleIndex,
6✔
1473
                                )
6✔
1474
                        }
6✔
1475

1476
                        if ref.PayHash != lntypes.ZeroHash {
44✔
1477
                                params.Hash = ref.PayHash[:]
22✔
1478
                        }
22✔
1479

1480
                        result, err := db.DeleteInvoice(ctx, params)
22✔
1481
                        if err != nil {
22✔
1482
                                return fmt.Errorf("unable to delete "+
×
1483
                                        "invoice(%v): %w", ref.AddIndex, err)
×
1484
                        }
×
1485
                        rowsAffected, err := result.RowsAffected()
22✔
1486
                        if err != nil {
22✔
1487
                                return fmt.Errorf("unable to get rows "+
×
1488
                                        "affected: %w", err)
×
1489
                        }
×
1490
                        if rowsAffected == 0 {
28✔
1491
                                return fmt.Errorf("%w: %v",
6✔
1492
                                        ErrInvoiceNotFound, ref.AddIndex)
6✔
1493
                        }
6✔
1494
                }
1495

1496
                return nil
6✔
1497
        }, func() {})
12✔
1498

1499
        if err != nil {
18✔
1500
                return fmt.Errorf("unable to delete invoices: %w", err)
6✔
1501
        }
6✔
1502

1503
        return nil
6✔
1504
}
1505

1506
// DeleteCanceledInvoices removes all canceled invoices from the database.
1507
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
6✔
1508
        var writeTxOpt SQLInvoiceQueriesTxOptions
6✔
1509
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
12✔
1510
                _, err := db.DeleteCanceledInvoices(ctx)
6✔
1511
                if err != nil {
6✔
1512
                        return fmt.Errorf("unable to delete canceled "+
×
1513
                                "invoices: %w", err)
×
1514
                }
×
1515

1516
                return nil
6✔
1517
        }, func() {})
6✔
1518
        if err != nil {
6✔
1519
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1520
        }
×
1521

1522
        return nil
6✔
1523
}
1524

1525
// fetchInvoiceData fetches additional data for the given invoice. If the
1526
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
1527
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
1528
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
1529
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
1530
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
1531
        *Invoice, error) {
22,652✔
1532

22,652✔
1533
        // Unmarshal the common data.
22,652✔
1534
        hash, invoice, err := unmarshalInvoice(row)
22,652✔
1535
        if err != nil {
22,652✔
1536
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1537
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1538
        }
×
1539

1540
        // Fetch the invoice features.
1541
        features, err := getInvoiceFeatures(ctx, db, row.ID)
22,652✔
1542
        if err != nil {
22,652✔
1543
                return nil, nil, err
×
1544
        }
×
1545

1546
        invoice.Terms.Features = features
22,652✔
1547

22,652✔
1548
        // If this is an AMP invoice, we'll need fetch the AMP state along
22,652✔
1549
        // with the HTLCs (if requested).
22,652✔
1550
        if invoice.IsAMP() {
32,776✔
1551
                invoiceID := int64(invoice.AddIndex)
10,124✔
1552
                ampState, ampHtlcs, err := fetchAmpState(
10,124✔
1553
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
10,124✔
1554
                )
10,124✔
1555
                if err != nil {
10,124✔
1556
                        return nil, nil, err
×
1557
                }
×
1558

1559
                invoice.AMPState = ampState
10,124✔
1560
                invoice.Htlcs = ampHtlcs
10,124✔
1561

10,124✔
1562
                return hash, invoice, nil
10,124✔
1563
        }
1564

1565
        // Otherwise simply fetch the invoice HTLCs.
1566
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
12,528✔
1567
        if err != nil {
12,528✔
1568
                return nil, nil, err
×
1569
        }
×
1570

1571
        if len(htlcs) > 0 {
20,878✔
1572
                invoice.Htlcs = htlcs
8,350✔
1573
                var amountPaid lnwire.MilliSatoshi
8,350✔
1574
                for _, htlc := range htlcs {
31,621✔
1575
                        if htlc.State == HtlcStateSettled {
32,556✔
1576
                                amountPaid += htlc.Amt
9,285✔
1577
                        }
9,285✔
1578
                }
1579
                invoice.AmtPaid = amountPaid
8,350✔
1580
        }
1581

1582
        return hash, invoice, nil
12,528✔
1583
}
1584

1585
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1586
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
1587
        invoiceID int64) (*lnwire.FeatureVector, error) {
22,652✔
1588

22,652✔
1589
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
22,652✔
1590
        if err != nil {
22,652✔
1591
                return nil, fmt.Errorf("unable to get invoice features: %w",
×
1592
                        err)
×
1593
        }
×
1594

1595
        features := lnwire.EmptyFeatureVector()
22,652✔
1596
        for _, feature := range rows {
38,464✔
1597
                features.Set(lnwire.FeatureBit(feature.Feature))
15,812✔
1598
        }
15,812✔
1599

1600
        return features, nil
22,652✔
1601
}
1602

1603
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
1604
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
1605
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
12,528✔
1606

12,528✔
1607
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
12,528✔
1608
        if err != nil {
12,528✔
1609
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
×
1610
        }
×
1611

1612
        // We have no htlcs to unmarshal.
1613
        if len(htlcRows) == 0 {
16,706✔
1614
                return nil, nil
4,178✔
1615
        }
4,178✔
1616

1617
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
8,350✔
1618
        if err != nil {
8,350✔
1619
                return nil, fmt.Errorf("unable to get custom records for "+
×
1620
                        "invoice htlcs: %w", err)
×
1621
        }
×
1622

1623
        cr := make(map[int64]record.CustomSet, len(crRows))
8,350✔
1624
        for _, row := range crRows {
56,218✔
1625
                if _, ok := cr[row.HtlcID]; !ok {
65,396✔
1626
                        cr[row.HtlcID] = make(record.CustomSet)
17,528✔
1627
                }
17,528✔
1628

1629
                value := row.Value
47,868✔
1630
                if value == nil {
47,869✔
1631
                        value = []byte{}
1✔
1632
                }
1✔
1633
                cr[row.HtlcID][uint64(row.Key)] = value
47,868✔
1634
        }
1635

1636
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
8,350✔
1637

8,350✔
1638
        for _, row := range htlcRows {
31,621✔
1639
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
23,271✔
1640
                if err != nil {
23,271✔
1641
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1642
                                "htlc(%d): %w", row.ID, err)
×
1643
                }
×
1644

1645
                if customRecords, ok := cr[row.ID]; ok {
40,799✔
1646
                        htlc.CustomRecords = customRecords
17,528✔
1647
                } else {
23,271✔
1648
                        htlc.CustomRecords = make(record.CustomSet)
5,743✔
1649
                }
5,743✔
1650

1651
                htlcs[circuiteKey] = htlc
23,271✔
1652
        }
1653

1654
        return htlcs, nil
8,350✔
1655
}
1656

1657
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1658
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
1659
        error) {
22,652✔
1660

22,652✔
1661
        var (
22,652✔
1662
                settleIndex    int64
22,652✔
1663
                settledAt      time.Time
22,652✔
1664
                memo           []byte
22,652✔
1665
                paymentRequest []byte
22,652✔
1666
                preimage       *lntypes.Preimage
22,652✔
1667
                paymentAddr    [32]byte
22,652✔
1668
        )
22,652✔
1669

22,652✔
1670
        hash, err := lntypes.MakeHash(row.Hash)
22,652✔
1671
        if err != nil {
22,652✔
1672
                return nil, nil, err
×
1673
        }
×
1674

1675
        if row.SettleIndex.Valid {
28,750✔
1676
                settleIndex = row.SettleIndex.Int64
6,098✔
1677
        }
6,098✔
1678

1679
        if row.SettledAt.Valid {
28,750✔
1680
                settledAt = row.SettledAt.Time.Local()
6,098✔
1681
        }
6,098✔
1682

1683
        if row.Memo.Valid {
43,970✔
1684
                memo = []byte(row.Memo.String)
21,318✔
1685
        }
21,318✔
1686

1687
        // Keysend payments will have this field empty.
1688
        if row.PaymentRequest.Valid {
43,405✔
1689
                paymentRequest = []byte(row.PaymentRequest.String)
20,753✔
1690
        } else {
22,652✔
1691
                paymentRequest = []byte{}
1,899✔
1692
        }
1,899✔
1693

1694
        // We may not have the preimage if this a hodl invoice.
1695
        if row.Preimage != nil {
45,216✔
1696
                preimage = &lntypes.Preimage{}
22,564✔
1697
                copy(preimage[:], row.Preimage)
22,564✔
1698
        }
22,564✔
1699

1700
        copy(paymentAddr[:], row.PaymentAddr)
22,652✔
1701

22,652✔
1702
        var cltvDelta int32
22,652✔
1703
        if row.CltvDelta.Valid {
45,304✔
1704
                cltvDelta = row.CltvDelta.Int32
22,652✔
1705
        }
22,652✔
1706

1707
        expiry := time.Duration(row.Expiry) * time.Second
22,652✔
1708

22,652✔
1709
        invoice := &Invoice{
22,652✔
1710
                SettleIndex:    uint64(settleIndex),
22,652✔
1711
                SettleDate:     settledAt,
22,652✔
1712
                Memo:           memo,
22,652✔
1713
                PaymentRequest: paymentRequest,
22,652✔
1714
                CreationDate:   row.CreatedAt.Local(),
22,652✔
1715
                Terms: ContractTerm{
22,652✔
1716
                        FinalCltvDelta:  cltvDelta,
22,652✔
1717
                        Expiry:          expiry,
22,652✔
1718
                        PaymentPreimage: preimage,
22,652✔
1719
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
22,652✔
1720
                        PaymentAddr:     paymentAddr,
22,652✔
1721
                },
22,652✔
1722
                AddIndex:    uint64(row.ID),
22,652✔
1723
                State:       ContractState(row.State),
22,652✔
1724
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
22,652✔
1725
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
22,652✔
1726
                AMPState:    AMPInvoiceState{},
22,652✔
1727
                HodlInvoice: row.IsHodl,
22,652✔
1728
        }
22,652✔
1729

22,652✔
1730
        return &hash, invoice, nil
22,652✔
1731
}
1732

1733
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
1734
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
1735
        *InvoiceHTLC, error) {
23,271✔
1736

23,271✔
1737
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
23,271✔
1738
        if err != nil {
23,271✔
1739
                return CircuitKey{}, nil, err
×
1740
        }
×
1741

1742
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
23,271✔
1743

23,271✔
1744
        if row.HtlcID < 0 {
23,271✔
1745
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
×
1746
                        "value: %v", row.HtlcID)
×
1747
        }
×
1748

1749
        htlcID := uint64(row.HtlcID)
23,271✔
1750

23,271✔
1751
        circuitKey := CircuitKey{
23,271✔
1752
                ChanID: chanID,
23,271✔
1753
                HtlcID: htlcID,
23,271✔
1754
        }
23,271✔
1755

23,271✔
1756
        htlc := &InvoiceHTLC{
23,271✔
1757
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
23,271✔
1758
                AcceptHeight: uint32(row.AcceptHeight),
23,271✔
1759
                AcceptTime:   row.AcceptTime.Local(),
23,271✔
1760
                Expiry:       uint32(row.ExpiryHeight),
23,271✔
1761
                State:        HtlcState(row.State),
23,271✔
1762
        }
23,271✔
1763

23,271✔
1764
        if row.TotalMppMsat.Valid {
42,751✔
1765
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
19,480✔
1766
        }
19,480✔
1767

1768
        if row.ResolveTime.Valid {
39,129✔
1769
                htlc.ResolveTime = row.ResolveTime.Time.Local()
15,858✔
1770
        }
15,858✔
1771

1772
        return circuitKey, htlc, nil
23,271✔
1773
}
1774

1775
// queryWithLimit drives a paginated query. It calls query with the offsets 0,
// limit, 2*limit, ... until a call reports fewer than limit rows, which marks
// the last page. query must return the number of rows it handled for the given
// offset. Any error from query is returned to the caller unchanged.
//
// limit must be positive. With a non-positive limit no page could ever be
// shorter than limit, so the loop would never terminate. Such a limit is
// therefore rejected with an error before query is called.
func queryWithLimit(query func(int) (int, error), limit int) error {
	if limit <= 0 {
		return fmt.Errorf("invalid query limit %d: must be positive",
			limit)
	}

	for offset := 0; ; offset += limit {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// A short page means there is nothing more to fetch.
		if rows < limit {
			return nil
		}
	}
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc