• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 14885001079

07 May 2025 01:46PM UTC coverage: 58.609% (-10.4%) from 68.992%
14885001079

Pull #9793

github

web-flow
Merge ecb9755e1 into 67a40c90a
Pull Request #9793: tlv: catch unhandled type in SizeBigSize

97439 of 166252 relevant lines covered (58.61%)

1.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
const (
	// invoiceProgressLogInterval is the interval used to rate-limit the
	// progress log output produced while processing invoices.
	invoiceProgressLogInterval = 30 * time.Second

	// defaultQueryPaginationLimit is the default page size, i.e. the value
	// placed in the LIMIT clause of the SQL queries that paginate.
	defaultQueryPaginationLimit = 100
)
32

33
// SQLInvoiceQueries is an interface that defines the set of operations that can
34
// be executed against the invoice SQL database.
35
type SQLInvoiceQueries interface { //nolint:interfacebloat
36
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
37
                error)
38

39
        // TODO(bhandras): remove this once migrations have been separated out.
40
        InsertMigratedInvoice(ctx context.Context,
41
                arg sqlc.InsertMigratedInvoiceParams) (int64, error)
42

43
        InsertInvoiceFeature(ctx context.Context,
44
                arg sqlc.InsertInvoiceFeatureParams) error
45

46
        InsertInvoiceHTLC(ctx context.Context,
47
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
48

49
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
50
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
51

52
        FilterInvoices(ctx context.Context,
53
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
54

55
        GetInvoice(ctx context.Context,
56
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
57

58
        GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
59
                error)
60

61
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
62
                error)
63

64
        GetInvoiceFeatures(ctx context.Context,
65
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
66

67
        GetInvoiceHTLCCustomRecords(ctx context.Context,
68
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
69

70
        GetInvoiceHTLCs(ctx context.Context,
71
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
72

73
        UpdateInvoiceState(ctx context.Context,
74
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
75

76
        UpdateInvoiceAmountPaid(ctx context.Context,
77
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
78

79
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
80

81
        UpdateInvoiceHTLC(ctx context.Context,
82
                arg sqlc.UpdateInvoiceHTLCParams) error
83

84
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
85
                sql.Result, error)
86

87
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
88

89
        // AMP sub invoice specific methods.
90
        UpsertAMPSubInvoice(ctx context.Context,
91
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
92

93
        // TODO(bhandras): remove this once migrations have been separated out.
94
        InsertAMPSubInvoice(ctx context.Context,
95
                arg sqlc.InsertAMPSubInvoiceParams) error
96

97
        UpdateAMPSubInvoiceState(ctx context.Context,
98
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
99

100
        InsertAMPSubInvoiceHTLC(ctx context.Context,
101
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
102

103
        FetchAMPSubInvoices(ctx context.Context,
104
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
105
                error)
106

107
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
108
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
109
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
110

111
        FetchSettledAMPSubInvoices(ctx context.Context,
112
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
113
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
114

115
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
116
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
117
                error)
118

119
        // Invoice events specific methods.
120
        OnInvoiceCreated(ctx context.Context,
121
                arg sqlc.OnInvoiceCreatedParams) error
122

123
        OnInvoiceCanceled(ctx context.Context,
124
                arg sqlc.OnInvoiceCanceledParams) error
125

126
        OnInvoiceSettled(ctx context.Context,
127
                arg sqlc.OnInvoiceSettledParams) error
128

129
        OnAMPSubInvoiceCreated(ctx context.Context,
130
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
131

132
        OnAMPSubInvoiceCanceled(ctx context.Context,
133
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
134

135
        OnAMPSubInvoiceSettled(ctx context.Context,
136
                arg sqlc.OnAMPSubInvoiceSettledParams) error
137

138
        // Migration specific methods.
139
        // TODO(bhandras): remove this once migrations have been separated out.
140
        InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
141
                arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
142

143
        SetKVInvoicePaymentHash(ctx context.Context,
144
                arg sqlc.SetKVInvoicePaymentHashParams) error
145

146
        GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
147
                []byte, error)
148

149
        ClearKVInvoiceHashIndex(ctx context.Context) error
150
}
151

152
// Compile-time check that *SQLStore implements the InvoiceDB interface.
var _ InvoiceDB = (*SQLStore)(nil)
153

154
// SQLInvoiceQueriesTxOptions is the set of transaction options understood by
// SQLInvoiceQueries. Its zero value requests a read-write transaction.
type SQLInvoiceQueriesTxOptions struct {
	// readOnly is set when the transaction only needs read access.
	readOnly bool
}

// ReadOnly reports whether a read-only transaction was requested.
//
// NOTE: This implements the TxOptions.
func (opts *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
	isReadOnly := opts.readOnly

	return isReadOnly
}

// NewSQLInvoiceQueryReadTx returns an option set that requests a read-only
// transaction.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
	var opts SQLInvoiceQueriesTxOptions
	opts.readOnly = true

	return opts
}
×
174

175
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
176
// of batched database operations.
177
type BatchedSQLInvoiceQueries interface {
178
        SQLInvoiceQueries
179

180
        sqldb.BatchedTx[SQLInvoiceQueries]
181
}
182

183
// SQLStore represents a storage backend.
184
type SQLStore struct {
185
        db    BatchedSQLInvoiceQueries
186
        clock clock.Clock
187
        opts  SQLStoreOptions
188
}
189

190
// SQLStoreOptions holds the options for the SQL store.
191
type SQLStoreOptions struct {
192
        paginationLimit int
193
}
194

195
// defaultSQLStoreOptions returns the default options for the SQL store.
196
func defaultSQLStoreOptions() SQLStoreOptions {
×
197
        return SQLStoreOptions{
×
198
                paginationLimit: defaultQueryPaginationLimit,
×
199
        }
×
200
}
×
201

202
// SQLStoreOption is a functional option that can be used to optionally modify
203
// the behavior of the SQL store.
204
type SQLStoreOption func(*SQLStoreOptions)
205

206
// WithPaginationLimit sets the pagination limit for the SQL store queries that
207
// paginate results.
208
func WithPaginationLimit(limit int) SQLStoreOption {
×
209
        return func(o *SQLStoreOptions) {
×
210
                o.paginationLimit = limit
×
211
        }
×
212
}
213

214
// NewSQLStore creates a new SQLStore instance given a open
215
// BatchedSQLInvoiceQueries storage backend.
216
func NewSQLStore(db BatchedSQLInvoiceQueries,
217
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
×
218

×
219
        opts := defaultSQLStoreOptions()
×
220
        for _, applyOption := range options {
×
221
                applyOption(&opts)
×
222
        }
×
223

224
        return &SQLStore{
×
225
                db:    db,
×
226
                clock: clock,
×
227
                opts:  opts,
×
228
        }
×
229
}
230

231
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
232
        sqlc.InsertInvoiceParams, error) {
×
233

×
234
        // Precompute the payment request hash so we can use it in the query.
×
235
        var paymentRequestHash []byte
×
236
        if len(invoice.PaymentRequest) > 0 {
×
237
                h := sha256.New()
×
238
                h.Write(invoice.PaymentRequest)
×
239
                paymentRequestHash = h.Sum(nil)
×
240
        }
×
241

242
        params := sqlc.InsertInvoiceParams{
×
243
                Hash:       paymentHash[:],
×
244
                AmountMsat: int64(invoice.Terms.Value),
×
245
                CltvDelta: sqldb.SQLInt32(
×
246
                        invoice.Terms.FinalCltvDelta,
×
247
                ),
×
248
                Expiry: int32(invoice.Terms.Expiry.Seconds()),
×
249
                // Note: keysend invoices don't have a payment request.
×
250
                PaymentRequest: sqldb.SQLStr(string(
×
251
                        invoice.PaymentRequest),
×
252
                ),
×
253
                PaymentRequestHash: paymentRequestHash,
×
254
                State:              int16(invoice.State),
×
255
                AmountPaidMsat:     int64(invoice.AmtPaid),
×
256
                IsAmp:              invoice.IsAMP(),
×
257
                IsHodl:             invoice.HodlInvoice,
×
258
                IsKeysend:          invoice.IsKeysend(),
×
259
                CreatedAt:          invoice.CreationDate.UTC(),
×
260
        }
×
261

×
262
        if invoice.Memo != nil {
×
263
                // Store the memo as a nullable string in the database. Note
×
264
                // that for compatibility reasons, we store the value as a valid
×
265
                // string even if it's empty.
×
266
                params.Memo = sql.NullString{
×
267
                        String: string(invoice.Memo),
×
268
                        Valid:  true,
×
269
                }
×
270
        }
×
271

272
        // Some invoices may not have a preimage, like in the case of HODL
273
        // invoices.
274
        if invoice.Terms.PaymentPreimage != nil {
×
275
                preimage := *invoice.Terms.PaymentPreimage
×
276
                if preimage == UnknownPreimage {
×
277
                        return sqlc.InsertInvoiceParams{},
×
278
                                errors.New("cannot use all-zeroes preimage")
×
279
                }
×
280
                params.Preimage = preimage[:]
×
281
        }
282

283
        // Some non MPP payments may have the default (invalid) value.
284
        if invoice.Terms.PaymentAddr != BlankPayAddr {
×
285
                params.PaymentAddr = invoice.Terms.PaymentAddr[:]
×
286
        }
×
287

288
        return params, nil
×
289
}
290

291
// AddInvoice inserts the targeted invoice into the database. If the invoice has
292
// *any* payment hashes which already exists within the database, then the
293
// insertion will be aborted and rejected due to the strict policy banning any
294
// duplicate payment hashes.
295
//
296
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
297
func (i *SQLStore) AddInvoice(ctx context.Context,
298
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
×
299

×
300
        // Make sure this is a valid invoice before trying to store it in our
×
301
        // DB.
×
302
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
×
303
                return 0, err
×
304
        }
×
305

306
        var (
×
307
                writeTxOpts SQLInvoiceQueriesTxOptions
×
308
                invoiceID   int64
×
309
        )
×
310

×
311
        insertInvoiceParams, err := makeInsertInvoiceParams(
×
312
                newInvoice, paymentHash,
×
313
        )
×
314
        if err != nil {
×
315
                return 0, err
×
316
        }
×
317

318
        err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
×
319
                var err error
×
320
                invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
×
321
                if err != nil {
×
322
                        return fmt.Errorf("unable to insert invoice: %w", err)
×
323
                }
×
324

325
                // TODO(positiveblue): if invocies do not have custom features
326
                // maybe just store the "invoice type" and populate the features
327
                // based on that.
328
                for feature := range newInvoice.Terms.Features.Features() {
×
329
                        params := sqlc.InsertInvoiceFeatureParams{
×
330
                                InvoiceID: invoiceID,
×
331
                                Feature:   int32(feature),
×
332
                        }
×
333

×
334
                        err := db.InsertInvoiceFeature(ctx, params)
×
335
                        if err != nil {
×
336
                                return fmt.Errorf("unable to insert invoice "+
×
337
                                        "feature(%v): %w", feature, err)
×
338
                        }
×
339
                }
340

341
                // Finally add a new event for this invoice.
342
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
×
343
                        AddedAt:   newInvoice.CreationDate.UTC(),
×
344
                        InvoiceID: invoiceID,
×
345
                })
×
346
        }, func() {})
×
347
        if err != nil {
×
348
                mappedSQLErr := sqldb.MapSQLError(err)
×
349
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
×
350
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
351
                        // Add context to unique constraint errors.
×
352
                        return 0, ErrDuplicateInvoice
×
353
                }
×
354

355
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
×
356
                        paymentHash, err)
×
357
        }
358

359
        newInvoice.AddIndex = uint64(invoiceID)
×
360

×
361
        return newInvoice.AddIndex, nil
×
362
}
363

364
// getInvoiceByRef fetches the invoice with the given reference. The reference
365
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
366
func getInvoiceByRef(ctx context.Context,
367
        db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
×
368

×
369
        // If the reference is empty, we can't look up the invoice.
×
370
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
×
371
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
372
        }
×
373

374
        // If the reference is a hash only, we can look up the invoice directly
375
        // by the payment hash which is faster.
376
        if ref.IsHashOnly() {
×
377
                invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
×
378
                if errors.Is(err, sql.ErrNoRows) {
×
379
                        return sqlc.Invoice{}, ErrInvoiceNotFound
×
380
                }
×
381

382
                return invoice, err
×
383
        }
384

385
        // Otherwise the reference may include more fields, so we'll need to
386
        // assemble the query parameters based on the fields that are set.
387
        var params sqlc.GetInvoiceParams
×
388

×
389
        if ref.PayHash() != nil {
×
390
                params.Hash = ref.PayHash()[:]
×
391
        }
×
392

393
        // Newer invoices (0.11 and up) are indexed by payment address in
394
        // addition to payment hash, but pre 0.8 invoices do not have one at
395
        // all. Only allow lookups for payment address if it is not a blank
396
        // payment address, which is a special-cased value for legacy keysend
397
        // invoices.
398
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
×
399
                params.PaymentAddr = ref.PayAddr()[:]
×
400
        }
×
401

402
        // If the reference has a set ID we'll fetch the invoice which has the
403
        // corresponding AMP sub invoice.
404
        if ref.SetID() != nil {
×
405
                params.SetID = ref.SetID()[:]
×
406
        }
×
407

408
        var (
×
409
                rows []sqlc.Invoice
×
410
                err  error
×
411
        )
×
412

×
413
        // We need to split the query based on how we intend to look up the
×
414
        // invoice. If only the set ID is given then we want to have an exact
×
415
        // match on the set ID. If other fields are given, we want to match on
×
416
        // those fields and the set ID but with a less strict join condition.
×
417
        if params.Hash == nil && params.PaymentAddr == nil &&
×
418
                params.SetID != nil {
×
419

×
420
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
×
421
        } else {
×
422
                rows, err = db.GetInvoice(ctx, params)
×
423
        }
×
424

425
        switch {
×
426
        case len(rows) == 0:
×
427
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
428

429
        case len(rows) > 1:
×
430
                // In case the reference is ambiguous, meaning it matches more
×
431
                // than        one invoice, we'll return an error.
×
432
                return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
×
433
                        "%s: %s", ref.String(), spew.Sdump(rows))
×
434

435
        case err != nil:
×
436
                return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
×
437
                        err)
×
438
        }
439

440
        return rows[0], nil
×
441
}
442

443
// fetchInvoice fetches the common invoice data and the AMP state for the
444
// invoice with the given reference.
445
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
446
        *Invoice, error) {
×
447

×
448
        // Fetch the invoice from the database.
×
449
        sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
×
450
        if err != nil {
×
451
                return nil, err
×
452
        }
×
453

454
        var (
×
455
                setID         *[32]byte
×
456
                fetchAmpHtlcs bool
×
457
        )
×
458

×
459
        // Now that we got the invoice itself, fetch the HTLCs as requested by
×
460
        // the modifier.
×
461
        switch ref.Modifier() {
×
462
        case DefaultModifier:
×
463
                // By default we'll fetch all AMP HTLCs.
×
464
                setID = nil
×
465
                fetchAmpHtlcs = true
×
466

467
        case HtlcSetOnlyModifier:
×
468
                // In this case we'll fetch all AMP HTLCs for the specified set
×
469
                // id.
×
470
                if ref.SetID() == nil {
×
471
                        return nil, fmt.Errorf("set ID is required to use " +
×
472
                                "the HTLC set only modifier")
×
473
                }
×
474

475
                setID = ref.SetID()
×
476
                fetchAmpHtlcs = true
×
477

478
        case HtlcSetBlankModifier:
×
479
                // No need to fetch any HTLCs.
×
480
                setID = nil
×
481
                fetchAmpHtlcs = false
×
482

483
        default:
×
484
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
×
485
                        ref.Modifier())
×
486
        }
487

488
        // Fetch the rest of the invoice data and fill the invoice struct.
489
        _, invoice, err := fetchInvoiceData(
×
490
                ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
×
491
        )
×
492
        if err != nil {
×
493
                return nil, err
×
494
        }
×
495

496
        return invoice, nil
×
497
}
498

499
// fetchAmpState fetches the AMP state for the invoice with the given ID.
500
// Optional setID can be provided to fetch the state for a specific AMP HTLC
501
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
502
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
503
// well.
504
//
505
//nolint:funlen
506
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
507
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
508
        HTLCSet, error) {
×
509

×
510
        var paramSetID []byte
×
511
        if setID != nil {
×
512
                paramSetID = setID[:]
×
513
        }
×
514

515
        // First fetch all the AMP sub invoices for this invoice or the one
516
        // matching the provided set ID.
517
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
×
518
                ctx, sqlc.FetchAMPSubInvoicesParams{
×
519
                        InvoiceID: invoiceID,
×
520
                        SetID:     paramSetID,
×
521
                },
×
522
        )
×
523
        if err != nil {
×
524
                return nil, nil, err
×
525
        }
×
526

527
        ampState := make(map[SetID]InvoiceStateAMP)
×
528
        for _, row := range ampInvoiceRows {
×
529
                var rowSetID [32]byte
×
530

×
531
                if len(row.SetID) != 32 {
×
532
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
×
533
                                len(row.SetID))
×
534
                }
×
535

536
                var settleDate time.Time
×
537
                if row.SettledAt.Valid {
×
538
                        settleDate = row.SettledAt.Time.Local()
×
539
                }
×
540

541
                copy(rowSetID[:], row.SetID)
×
542
                ampState[rowSetID] = InvoiceStateAMP{
×
543
                        State:       HtlcState(row.State),
×
544
                        SettleIndex: uint64(row.SettleIndex.Int64),
×
545
                        SettleDate:  settleDate,
×
546
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
×
547
                }
×
548
        }
549

550
        if !fetchHtlcs {
×
551
                return ampState, nil, nil
×
552
        }
×
553

554
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
555
        if err != nil {
×
556
                return nil, nil, fmt.Errorf("unable to get custom records for "+
×
557
                        "invoice HTLCs: %w", err)
×
558
        }
×
559

560
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
×
561
        for _, row := range customRecordRows {
×
562
                if _, ok := customRecords[row.HtlcID]; !ok {
×
563
                        customRecords[row.HtlcID] = make(record.CustomSet)
×
564
                }
×
565

566
                value := row.Value
×
567
                if value == nil {
×
568
                        value = []byte{}
×
569
                }
×
570

571
                customRecords[row.HtlcID][uint64(row.Key)] = value
×
572
        }
573

574
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
575
        // provided set ID.
576
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
×
577
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
×
578
                        InvoiceID: invoiceID,
×
579
                        SetID:     paramSetID,
×
580
                },
×
581
        )
×
582
        if err != nil {
×
583
                return nil, nil, err
×
584
        }
×
585

586
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
×
587
        for _, row := range ampHtlcRows {
×
588
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
589
                if err != nil {
×
590
                        return nil, nil, err
×
591
                }
×
592

593
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
594

×
595
                if row.HtlcID < 0 {
×
596
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
×
597
                                "value: %v", row.HtlcID)
×
598
                }
×
599

600
                htlcID := uint64(row.HtlcID)
×
601

×
602
                circuitKey := CircuitKey{
×
603
                        ChanID: chanID,
×
604
                        HtlcID: htlcID,
×
605
                }
×
606

×
607
                htlc := &InvoiceHTLC{
×
608
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
609
                        AcceptHeight: uint32(row.AcceptHeight),
×
610
                        AcceptTime:   row.AcceptTime.Local(),
×
611
                        Expiry:       uint32(row.ExpiryHeight),
×
612
                        State:        HtlcState(row.State),
×
613
                }
×
614

×
615
                if row.TotalMppMsat.Valid {
×
616
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
×
617
                                row.TotalMppMsat.Int64,
×
618
                        )
×
619
                }
×
620

621
                if row.ResolveTime.Valid {
×
622
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
×
623
                }
×
624

625
                var (
×
626
                        rootShare [32]byte
×
627
                        setID     [32]byte
×
628
                )
×
629

×
630
                if len(row.RootShare) != 32 {
×
631
                        return nil, nil, fmt.Errorf("invalid root share "+
×
632
                                "length: %d", len(row.RootShare))
×
633
                }
×
634
                copy(rootShare[:], row.RootShare)
×
635

×
636
                if len(row.SetID) != 32 {
×
637
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
638
                                len(row.SetID))
×
639
                }
×
640
                copy(setID[:], row.SetID)
×
641

×
642
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
×
643
                        return nil, nil, fmt.Errorf("invalid child index "+
×
644
                                "value: %v", row.ChildIndex)
×
645
                }
×
646

647
                ampRecord := record.NewAMP(
×
648
                        rootShare, setID, uint32(row.ChildIndex),
×
649
                )
×
650

×
651
                htlc.AMP = &InvoiceHtlcAMPData{
×
652
                        Record: *ampRecord,
×
653
                }
×
654

×
655
                if len(row.Hash) != 32 {
×
656
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
×
657
                                len(row.Hash))
×
658
                }
×
659
                copy(htlc.AMP.Hash[:], row.Hash)
×
660

×
661
                if row.Preimage != nil {
×
662
                        preimage, err := lntypes.MakePreimage(row.Preimage)
×
663
                        if err != nil {
×
664
                                return nil, nil, err
×
665
                        }
×
666

667
                        htlc.AMP.Preimage = &preimage
×
668
                }
669

670
                if _, ok := customRecords[row.ID]; ok {
×
671
                        htlc.CustomRecords = customRecords[row.ID]
×
672
                } else {
×
673
                        htlc.CustomRecords = make(record.CustomSet)
×
674
                }
×
675

676
                ampHtlcs[circuitKey] = htlc
×
677
        }
678

679
        if len(ampHtlcs) > 0 {
×
680
                for setID := range ampState {
×
681
                        var amtPaid lnwire.MilliSatoshi
×
682
                        invoiceKeys := make(
×
683
                                map[models.CircuitKey]struct{},
×
684
                        )
×
685

×
686
                        for key, htlc := range ampHtlcs {
×
687
                                if htlc.AMP.Record.SetID() != setID {
×
688
                                        continue
×
689
                                }
690

691
                                invoiceKeys[key] = struct{}{}
×
692

×
693
                                if htlc.State != HtlcStateCanceled {
×
694
                                        amtPaid += htlc.Amt
×
695
                                }
×
696
                        }
697

698
                        setState := ampState[setID]
×
699
                        setState.InvoiceKeys = invoiceKeys
×
700
                        setState.AmtPaid = amtPaid
×
701
                        ampState[setID] = setState
×
702
                }
703
        }
704

705
        return ampState, ampHtlcs, nil
×
706
}
707

708
// LookupInvoice attempts to look up an invoice corresponding the passed in
709
// reference. The reference may be a payment hash, a payment address, or a set
710
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
711
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
712
// error.
713
func (i *SQLStore) LookupInvoice(ctx context.Context,
714
        ref InvoiceRef) (Invoice, error) {
×
715

×
716
        var (
×
717
                invoice *Invoice
×
718
                err     error
×
719
        )
×
720

×
721
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
722
        txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
723
                invoice, err = fetchInvoice(ctx, db, ref)
×
724

×
725
                return err
×
726
        }, func() {})
×
727
        if txErr != nil {
×
728
                return Invoice{}, txErr
×
729
        }
×
730

731
        return *invoice, nil
×
732
}
733

734
// FetchPendingInvoices returns all the invoices that are currently in a
735
// "pending" state. An invoice is pending if it has been created but not yet
736
// settled or canceled.
737
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
738
        map[lntypes.Hash]Invoice, error) {
×
739

×
740
        var invoices map[lntypes.Hash]Invoice
×
741

×
742
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
743
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
744
                return queryWithLimit(func(offset int) (int, error) {
×
745
                        params := sqlc.FilterInvoicesParams{
×
746
                                PendingOnly: true,
×
747
                                NumOffset:   int32(offset),
×
748
                                NumLimit:    int32(i.opts.paginationLimit),
×
749
                                Reverse:     false,
×
750
                        }
×
751

×
752
                        rows, err := db.FilterInvoices(ctx, params)
×
753
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
754
                                return 0, fmt.Errorf("unable to get invoices "+
×
755
                                        "from db: %w", err)
×
756
                        }
×
757

758
                        // Load all the information for the invoices.
759
                        for _, row := range rows {
×
760
                                hash, invoice, err := fetchInvoiceData(
×
761
                                        ctx, db, row, nil, true,
×
762
                                )
×
763
                                if err != nil {
×
764
                                        return 0, err
×
765
                                }
×
766

767
                                invoices[*hash] = *invoice
×
768
                        }
769

770
                        return len(rows), nil
×
771
                }, i.opts.paginationLimit)
772
        }, func() {
×
773
                invoices = make(map[lntypes.Hash]Invoice)
×
774
        })
×
775
        if err != nil {
×
776
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
×
777
                        err)
×
778
        }
×
779

780
        return invoices, nil
×
781
}
782

783
// InvoicesSettledSince can be used by callers to catch up any settled invoices
784
// they missed within the settled invoice time series. We'll return all known
785
// settled invoice that have a settle index higher than the passed idx.
786
//
787
// NOTE: The index starts from 1. As a result we enforce that specifying a value
788
// below the starting index value is a noop.
789
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
790
        []Invoice, error) {
×
791

×
792
        var (
×
793
                invoices       []Invoice
×
794
                start          = time.Now()
×
795
                lastLogTime    = time.Now()
×
796
                processedCount int
×
797
        )
×
798

×
799
        if idx == 0 {
×
800
                return invoices, nil
×
801
        }
×
802

803
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
804
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
805
                err := queryWithLimit(func(offset int) (int, error) {
×
806
                        params := sqlc.FilterInvoicesParams{
×
807
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
808
                                NumOffset:      int32(offset),
×
809
                                NumLimit:       int32(i.opts.paginationLimit),
×
810
                                Reverse:        false,
×
811
                        }
×
812

×
813
                        rows, err := db.FilterInvoices(ctx, params)
×
814
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
815
                                return 0, fmt.Errorf("unable to get invoices "+
×
816
                                        "from db: %w", err)
×
817
                        }
×
818

819
                        // Load all the information for the invoices.
820
                        for _, row := range rows {
×
821
                                _, invoice, err := fetchInvoiceData(
×
822
                                        ctx, db, row, nil, true,
×
823
                                )
×
824
                                if err != nil {
×
825
                                        return 0, fmt.Errorf("unable to fetch "+
×
826
                                                "invoice(id=%d) from db: %w",
×
827
                                                row.ID, err)
×
828
                                }
×
829

830
                                invoices = append(invoices, *invoice)
×
831

×
832
                                processedCount++
×
833
                                if time.Since(lastLogTime) >=
×
834
                                        invoiceProgressLogInterval {
×
835

×
836
                                        log.Debugf("Processed %d settled "+
×
837
                                                "invoices which have a settle "+
×
838
                                                "index greater than %v",
×
839
                                                processedCount, idx)
×
840

×
841
                                        lastLogTime = time.Now()
×
842
                                }
×
843
                        }
844

845
                        return len(rows), nil
×
846
                }, i.opts.paginationLimit)
847
                if err != nil {
×
848
                        return err
×
849
                }
×
850

851
                // Now fetch all the AMP sub invoices that were settled since
852
                // the provided index.
853
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
×
854
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
×
855
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
856
                        },
×
857
                )
×
858
                if err != nil {
×
859
                        return err
×
860
                }
×
861

862
                for _, ampInvoice := range ampInvoices {
×
863
                        // Convert the row to a sqlc.Invoice so we can use the
×
864
                        // existing fetchInvoiceData function.
×
865
                        sqlInvoice := sqlc.Invoice{
×
866
                                ID:             ampInvoice.ID,
×
867
                                Hash:           ampInvoice.Hash,
×
868
                                Preimage:       ampInvoice.Preimage,
×
869
                                SettleIndex:    ampInvoice.AmpSettleIndex,
×
870
                                SettledAt:      ampInvoice.AmpSettledAt,
×
871
                                Memo:           ampInvoice.Memo,
×
872
                                AmountMsat:     ampInvoice.AmountMsat,
×
873
                                CltvDelta:      ampInvoice.CltvDelta,
×
874
                                Expiry:         ampInvoice.Expiry,
×
875
                                PaymentAddr:    ampInvoice.PaymentAddr,
×
876
                                PaymentRequest: ampInvoice.PaymentRequest,
×
877
                                State:          ampInvoice.State,
×
878
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
×
879
                                IsAmp:          ampInvoice.IsAmp,
×
880
                                IsHodl:         ampInvoice.IsHodl,
×
881
                                IsKeysend:      ampInvoice.IsKeysend,
×
882
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
×
883
                        }
×
884

×
885
                        // Fetch the state and HTLCs for this AMP sub invoice.
×
886
                        _, invoice, err := fetchInvoiceData(
×
887
                                ctx, db, sqlInvoice,
×
888
                                (*[32]byte)(ampInvoice.SetID), true,
×
889
                        )
×
890
                        if err != nil {
×
891
                                return fmt.Errorf("unable to fetch "+
×
892
                                        "AMP invoice(id=%d) from db: %w",
×
893
                                        ampInvoice.ID, err)
×
894
                        }
×
895

896
                        invoices = append(invoices, *invoice)
×
897

×
898
                        processedCount++
×
899
                        if time.Since(lastLogTime) >=
×
900
                                invoiceProgressLogInterval {
×
901

×
902
                                log.Debugf("Processed %d settled invoices "+
×
903
                                        "including AMP sub invoices which "+
×
904
                                        "have a settle index greater than %v",
×
905
                                        processedCount, idx)
×
906

×
907
                                lastLogTime = time.Now()
×
908
                        }
×
909
                }
910

911
                return nil
×
912
        }, func() {
×
913
                invoices = nil
×
914
        })
×
915
        if err != nil {
×
916
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
917
                        "index (excluding) %d: %w", idx, err)
×
918
        }
×
919

920
        elapsed := time.Since(start)
×
921
        log.Debugf("Completed scanning for settled invoices starting at "+
×
922
                "index %v: total_processed=%d, found_invoices=%d, elapsed=%v",
×
923
                idx, processedCount, len(invoices),
×
924
                elapsed.Round(time.Millisecond))
×
925

×
926
        return invoices, nil
×
927
}
928

929
// InvoicesAddedSince can be used by callers to seek into the event time series
930
// of all the invoices added in the database. This method will return all
931
// invoices with an add index greater than the specified idx.
932
//
933
// NOTE: The index starts from 1. As a result we enforce that specifying a value
934
// below the starting index value is a noop.
935
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
936
        []Invoice, error) {
×
937

×
938
        var (
×
939
                result         []Invoice
×
940
                start          = time.Now()
×
941
                lastLogTime    = time.Now()
×
942
                processedCount int
×
943
        )
×
944

×
945
        if idx == 0 {
×
946
                return result, nil
×
947
        }
×
948

949
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
950
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
951
                return queryWithLimit(func(offset int) (int, error) {
×
952
                        params := sqlc.FilterInvoicesParams{
×
953
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
×
954
                                NumOffset:   int32(offset),
×
955
                                NumLimit:    int32(i.opts.paginationLimit),
×
956
                                Reverse:     false,
×
957
                        }
×
958

×
959
                        rows, err := db.FilterInvoices(ctx, params)
×
960
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
961
                                return 0, fmt.Errorf("unable to get invoices "+
×
962
                                        "from db: %w", err)
×
963
                        }
×
964

965
                        // Load all the information for the invoices.
966
                        for _, row := range rows {
×
967
                                _, invoice, err := fetchInvoiceData(
×
968
                                        ctx, db, row, nil, true,
×
969
                                )
×
970
                                if err != nil {
×
971
                                        return 0, err
×
972
                                }
×
973

974
                                result = append(result, *invoice)
×
975

×
976
                                processedCount++
×
977
                                if time.Since(lastLogTime) >=
×
978
                                        invoiceProgressLogInterval {
×
979

×
980
                                        log.Debugf("Processed %d invoices "+
×
981
                                                "which were added since add "+
×
982
                                                "index %v", processedCount, idx)
×
983

×
984
                                        lastLogTime = time.Now()
×
985
                                }
×
986
                        }
987

988
                        return len(rows), nil
×
989
                }, i.opts.paginationLimit)
990
        }, func() {
×
991
                result = nil
×
992
        })
×
993

994
        if err != nil {
×
995
                return nil, fmt.Errorf("unable to get invoices added since "+
×
996
                        "index %d: %w", idx, err)
×
997
        }
×
998

999
        elapsed := time.Since(start)
×
1000
        log.Debugf("Completed scanning for invoices added since index %v: "+
×
1001
                "total_processed=%d, found_invoices=%d, elapsed=%v",
×
1002
                idx, processedCount, len(result),
×
1003
                elapsed.Round(time.Millisecond))
×
1004

×
1005
        return result, nil
×
1006
}
1007

1008
// QueryInvoices allows a caller to query the invoice database for invoices
1009
// within the specified add index range.
1010
func (i *SQLStore) QueryInvoices(ctx context.Context,
1011
        q InvoiceQuery) (InvoiceSlice, error) {
×
1012

×
1013
        var invoices []Invoice
×
1014

×
1015
        if q.NumMaxInvoices == 0 {
×
1016
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
1017
                        "be non-zero")
×
1018
        }
×
1019

1020
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
1021
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
1022
                return queryWithLimit(func(offset int) (int, error) {
×
1023
                        params := sqlc.FilterInvoicesParams{
×
1024
                                NumOffset:   int32(offset),
×
1025
                                NumLimit:    int32(i.opts.paginationLimit),
×
1026
                                PendingOnly: q.PendingOnly,
×
1027
                                Reverse:     q.Reversed,
×
1028
                        }
×
1029

×
1030
                        if q.Reversed {
×
1031
                                // If the index offset was not set, we want to
×
1032
                                // fetch from the lastest invoice.
×
1033
                                if q.IndexOffset == 0 {
×
1034
                                        params.AddIndexLet = sqldb.SQLInt64(
×
1035
                                                int64(math.MaxInt64),
×
1036
                                        )
×
1037
                                } else {
×
1038
                                        // The invoice with index offset id must
×
1039
                                        // not be included in the results.
×
1040
                                        params.AddIndexLet = sqldb.SQLInt64(
×
1041
                                                q.IndexOffset - 1,
×
1042
                                        )
×
1043
                                }
×
1044
                        } else {
×
1045
                                // The invoice with index offset id must not be
×
1046
                                // included in the results.
×
1047
                                params.AddIndexGet = sqldb.SQLInt64(
×
1048
                                        q.IndexOffset + 1,
×
1049
                                )
×
1050
                        }
×
1051

1052
                        if q.CreationDateStart != 0 {
×
1053
                                params.CreatedAfter = sqldb.SQLTime(
×
1054
                                        time.Unix(q.CreationDateStart, 0).UTC(),
×
1055
                                )
×
1056
                        }
×
1057

1058
                        if q.CreationDateEnd != 0 {
×
1059
                                // We need to add 1 to the end date as we're
×
1060
                                // checking less than the end date in SQL.
×
1061
                                params.CreatedBefore = sqldb.SQLTime(
×
1062
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
×
1063
                                )
×
1064
                        }
×
1065

1066
                        rows, err := db.FilterInvoices(ctx, params)
×
1067
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1068
                                return 0, fmt.Errorf("unable to get invoices "+
×
1069
                                        "from db: %w", err)
×
1070
                        }
×
1071

1072
                        // Load all the information for the invoices.
1073
                        for _, row := range rows {
×
1074
                                _, invoice, err := fetchInvoiceData(
×
1075
                                        ctx, db, row, nil, true,
×
1076
                                )
×
1077
                                if err != nil {
×
1078
                                        return 0, err
×
1079
                                }
×
1080

1081
                                invoices = append(invoices, *invoice)
×
1082

×
1083
                                if len(invoices) == int(q.NumMaxInvoices) {
×
1084
                                        return 0, nil
×
1085
                                }
×
1086
                        }
1087

1088
                        return len(rows), nil
×
1089
                }, i.opts.paginationLimit)
1090
        }, func() {
×
1091
                invoices = nil
×
1092
        })
×
1093
        if err != nil {
×
1094
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
1095
                        "invoices: %w", err)
×
1096
        }
×
1097

1098
        if len(invoices) == 0 {
×
1099
                return InvoiceSlice{
×
1100
                        InvoiceQuery: q,
×
1101
                }, nil
×
1102
        }
×
1103

1104
        // If we iterated through the add index in reverse order, then
1105
        // we'll need to reverse the slice of invoices to return them in
1106
        // forward order.
1107
        if q.Reversed {
×
1108
                numInvoices := len(invoices)
×
1109
                for i := 0; i < numInvoices/2; i++ {
×
1110
                        reverse := numInvoices - i - 1
×
1111
                        invoices[i], invoices[reverse] =
×
1112
                                invoices[reverse], invoices[i]
×
1113
                }
×
1114
        }
1115

1116
        res := InvoiceSlice{
×
1117
                InvoiceQuery:     q,
×
1118
                Invoices:         invoices,
×
1119
                FirstIndexOffset: invoices[0].AddIndex,
×
1120
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
×
1121
        }
×
1122

×
1123
        return res, nil
×
1124
}
1125

1126
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
1127
// a SQL database as the backend.
1128
type sqlInvoiceUpdater struct {
1129
        db         SQLInvoiceQueries
1130
        ctx        context.Context //nolint:containedctx
1131
        invoice    *Invoice
1132
        updateTime time.Time
1133
}
1134

1135
// AddHtlc adds a new htlc to the invoice.
1136
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
1137
        newHtlc *InvoiceHTLC) error {
×
1138

×
1139
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
×
1140
                s.ctx, sqlc.InsertInvoiceHTLCParams{
×
1141
                        HtlcID: int64(circuitKey.HtlcID),
×
1142
                        ChanID: strconv.FormatUint(
×
1143
                                circuitKey.ChanID.ToUint64(), 10,
×
1144
                        ),
×
1145
                        AmountMsat: int64(newHtlc.Amt),
×
1146
                        TotalMppMsat: sql.NullInt64{
×
1147
                                Int64: int64(newHtlc.MppTotalAmt),
×
1148
                                Valid: newHtlc.MppTotalAmt != 0,
×
1149
                        },
×
1150
                        AcceptHeight: int32(newHtlc.AcceptHeight),
×
1151
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
×
1152
                        ExpiryHeight: int32(newHtlc.Expiry),
×
1153
                        State:        int16(newHtlc.State),
×
1154
                        InvoiceID:    int64(s.invoice.AddIndex),
×
1155
                },
×
1156
        )
×
1157
        if err != nil {
×
1158
                return err
×
1159
        }
×
1160

1161
        for key, value := range newHtlc.CustomRecords {
×
1162
                err = s.db.InsertInvoiceHTLCCustomRecord(
×
1163
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
×
1164
                                // TODO(bhandras): schema might be wrong here
×
1165
                                // as the custom record key is an uint64.
×
1166
                                Key:    int64(key),
×
1167
                                Value:  value,
×
1168
                                HtlcID: htlcPrimaryKeyID,
×
1169
                        },
×
1170
                )
×
1171
                if err != nil {
×
1172
                        return err
×
1173
                }
×
1174
        }
1175

1176
        if newHtlc.AMP != nil {
×
1177
                setID := newHtlc.AMP.Record.SetID()
×
1178

×
1179
                upsertResult, err := s.db.UpsertAMPSubInvoice(
×
1180
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
×
1181
                                SetID:     setID[:],
×
1182
                                CreatedAt: s.updateTime.UTC(),
×
1183
                                InvoiceID: int64(s.invoice.AddIndex),
×
1184
                        },
×
1185
                )
×
1186
                if err != nil {
×
1187
                        mappedSQLErr := sqldb.MapSQLError(err)
×
1188
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
×
1189
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
1190
                                return ErrDuplicateSetID{
×
1191
                                        SetID: setID,
×
1192
                                }
×
1193
                        }
×
1194

1195
                        return err
×
1196
                }
1197

1198
                // If we're just inserting the AMP invoice, we'll get a non
1199
                // zero rows affected count.
1200
                rowsAffected, err := upsertResult.RowsAffected()
×
1201
                if err != nil {
×
1202
                        return err
×
1203
                }
×
1204
                if rowsAffected != 0 {
×
1205
                        // If we're inserting a new AMP invoice, we'll also
×
1206
                        // insert a new invoice event.
×
1207
                        err = s.db.OnAMPSubInvoiceCreated(
×
1208
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
×
1209
                                        AddedAt:   s.updateTime.UTC(),
×
1210
                                        InvoiceID: int64(s.invoice.AddIndex),
×
1211
                                        SetID:     setID[:],
×
1212
                                },
×
1213
                        )
×
1214
                        if err != nil {
×
1215
                                return err
×
1216
                        }
×
1217
                }
1218

1219
                rootShare := newHtlc.AMP.Record.RootShare()
×
1220

×
1221
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
×
1222
                        InvoiceID: int64(s.invoice.AddIndex),
×
1223
                        SetID:     setID[:],
×
1224
                        HtlcID:    htlcPrimaryKeyID,
×
1225
                        RootShare: rootShare[:],
×
1226
                        ChildIndex: int64(
×
1227
                                newHtlc.AMP.Record.ChildIndex(),
×
1228
                        ),
×
1229
                        Hash: newHtlc.AMP.Hash[:],
×
1230
                }
×
1231

×
1232
                if newHtlc.AMP.Preimage != nil {
×
1233
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
×
1234
                }
×
1235

1236
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
×
1237
                if err != nil {
×
1238
                        return err
×
1239
                }
×
1240
        }
1241

1242
        return nil
×
1243
}
1244

1245
// ResolveHtlc marks an htlc as resolved with the given state.
1246
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
1247
        state HtlcState, resolveTime time.Time) error {
×
1248

×
1249
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
×
1250
                HtlcID: int64(circuitKey.HtlcID),
×
1251
                ChanID: strconv.FormatUint(
×
1252
                        circuitKey.ChanID.ToUint64(), 10,
×
1253
                ),
×
1254
                InvoiceID:   int64(s.invoice.AddIndex),
×
1255
                State:       int16(state),
×
1256
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
×
1257
        })
×
1258
}
×
1259

1260
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1261
// identified by the setID.
1262
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
1263
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
×
1264

×
1265
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
×
1266
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
×
1267
                        InvoiceID: int64(s.invoice.AddIndex),
×
1268
                        SetID:     setID[:],
×
1269
                        HtlcID:    int64(circuitKey.HtlcID),
×
1270
                        Preimage:  preimage[:],
×
1271
                        ChanID: strconv.FormatUint(
×
1272
                                circuitKey.ChanID.ToUint64(), 10,
×
1273
                        ),
×
1274
                },
×
1275
        )
×
1276
        if err != nil {
×
1277
                return err
×
1278
        }
×
1279

1280
        rowsAffected, err := result.RowsAffected()
×
1281
        if err != nil {
×
1282
                return err
×
1283
        }
×
1284
        if rowsAffected == 0 {
×
1285
                return ErrInvoiceNotFound
×
1286
        }
×
1287

1288
        return nil
×
1289
}
1290

1291
// UpdateInvoiceState updates the invoice state to the new state.
1292
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
1293
        newState ContractState, preimage *lntypes.Preimage) error {
×
1294

×
1295
        var (
×
1296
                settleIndex sql.NullInt64
×
1297
                settledAt   sql.NullTime
×
1298
        )
×
1299

×
1300
        switch newState {
×
1301
        case ContractSettled:
×
1302
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
1303
                if err != nil {
×
1304
                        return err
×
1305
                }
×
1306

1307
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
1308

×
1309
                // If the invoice is settled, we'll also update the settle time.
×
1310
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
1311

×
1312
                err = s.db.OnInvoiceSettled(
×
1313
                        s.ctx, sqlc.OnInvoiceSettledParams{
×
1314
                                AddedAt:   s.updateTime.UTC(),
×
1315
                                InvoiceID: int64(s.invoice.AddIndex),
×
1316
                        },
×
1317
                )
×
1318
                if err != nil {
×
1319
                        return err
×
1320
                }
×
1321

1322
        case ContractCanceled:
×
1323
                err := s.db.OnInvoiceCanceled(
×
1324
                        s.ctx, sqlc.OnInvoiceCanceledParams{
×
1325
                                AddedAt:   s.updateTime.UTC(),
×
1326
                                InvoiceID: int64(s.invoice.AddIndex),
×
1327
                        },
×
1328
                )
×
1329
                if err != nil {
×
1330
                        return err
×
1331
                }
×
1332
        }
1333

1334
        params := sqlc.UpdateInvoiceStateParams{
×
1335
                ID:          int64(s.invoice.AddIndex),
×
1336
                State:       int16(newState),
×
1337
                SettleIndex: settleIndex,
×
1338
                SettledAt:   settledAt,
×
1339
        }
×
1340

×
1341
        if preimage != nil {
×
1342
                params.Preimage = preimage[:]
×
1343
        }
×
1344

1345
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
×
1346
        if err != nil {
×
1347
                return err
×
1348
        }
×
1349
        rowsAffected, err := result.RowsAffected()
×
1350
        if err != nil {
×
1351
                return err
×
1352
        }
×
1353

1354
        if rowsAffected == 0 {
×
1355
                return ErrInvoiceNotFound
×
1356
        }
×
1357

1358
        if settleIndex.Valid {
×
1359
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
×
1360
                s.invoice.SettleDate = s.updateTime
×
1361
        }
×
1362

1363
        return nil
×
1364
}
1365

1366
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
1367
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
1368
        amtPaid lnwire.MilliSatoshi) error {
×
1369

×
1370
        _, err := s.db.UpdateInvoiceAmountPaid(
×
1371
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
×
1372
                        ID:             int64(s.invoice.AddIndex),
×
1373
                        AmountPaidMsat: int64(amtPaid),
×
1374
                },
×
1375
        )
×
1376

×
1377
        return err
×
1378
}
×
1379

1380
// UpdateAmpState updates the state of the AMP sub invoice identified by the
1381
// setID.
1382
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
1383
        newState InvoiceStateAMP, _ models.CircuitKey) error {
×
1384

×
1385
        var (
×
1386
                settleIndex sql.NullInt64
×
1387
                settledAt   sql.NullTime
×
1388
        )
×
1389

×
1390
        switch newState.State {
×
1391
        case HtlcStateSettled:
×
1392
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
1393
                if err != nil {
×
1394
                        return err
×
1395
                }
×
1396

1397
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
1398

×
1399
                // If the invoice is settled, we'll also update the settle time.
×
1400
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
1401

×
1402
                err = s.db.OnAMPSubInvoiceSettled(
×
1403
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
×
1404
                                AddedAt:   s.updateTime.UTC(),
×
1405
                                InvoiceID: int64(s.invoice.AddIndex),
×
1406
                                SetID:     setID[:],
×
1407
                        },
×
1408
                )
×
1409
                if err != nil {
×
1410
                        return err
×
1411
                }
×
1412

1413
        case HtlcStateCanceled:
×
1414
                err := s.db.OnAMPSubInvoiceCanceled(
×
1415
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
×
1416
                                AddedAt:   s.updateTime.UTC(),
×
1417
                                InvoiceID: int64(s.invoice.AddIndex),
×
1418
                                SetID:     setID[:],
×
1419
                        },
×
1420
                )
×
1421
                if err != nil {
×
1422
                        return err
×
1423
                }
×
1424
        }
1425

1426
        err := s.db.UpdateAMPSubInvoiceState(
×
1427
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
×
1428
                        SetID:       setID[:],
×
1429
                        State:       int16(newState.State),
×
1430
                        SettleIndex: settleIndex,
×
1431
                        SettledAt:   settledAt,
×
1432
                },
×
1433
        )
×
1434
        if err != nil {
×
1435
                return err
×
1436
        }
×
1437

1438
        if settleIndex.Valid {
×
1439
                updatedState := s.invoice.AMPState[setID]
×
1440
                updatedState.SettleIndex = uint64(settleIndex.Int64)
×
1441
                updatedState.SettleDate = s.updateTime.UTC()
×
1442
                s.invoice.AMPState[setID] = updatedState
×
1443
        }
×
1444

1445
        return nil
×
1446
}
1447

1448
// Finalize finalizes the update before it is written to the database. Note that
1449
// we don't use this directly in the SQL implementation, so the function is just
1450
// a stub.
1451
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
×
1452
        return nil
×
1453
}
×
1454

1455
// UpdateInvoice attempts to update an invoice corresponding to the passed
1456
// reference. If an invoice matching the passed reference doesn't exist within
1457
// the database, then the action will fail with  ErrInvoiceNotFound error.
1458
//
1459
// The update is performed inside the same database transaction that fetches the
1460
// invoice and is therefore atomic. The fields to update are controlled by the
1461
// supplied callback.
1462
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
1463
        setID *SetID, callback InvoiceUpdateCallback) (
1464
        *Invoice, error) {
×
1465

×
1466
        var updatedInvoice *Invoice
×
1467

×
1468
        txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
×
1469
        txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
×
1470
                switch {
×
1471
                // For the default case we fetch all HTLCs.
1472
                case setID == nil:
×
1473
                        ref.refModifier = DefaultModifier
×
1474

1475
                // If the setID is the blank but NOT nil, we set the
1476
                // refModifier to HtlcSetBlankModifier to fetch no HTLC for the
1477
                // AMP invoice.
1478
                case *setID == BlankPayAddr:
×
1479
                        ref.refModifier = HtlcSetBlankModifier
×
1480

1481
                // A setID is provided, we use the refModifier to fetch only
1482
                // the HTLCs for the given setID and also make sure we add the
1483
                // setID to the ref.
1484
                default:
×
1485
                        var setIDBytes [32]byte
×
1486
                        copy(setIDBytes[:], setID[:])
×
1487
                        ref.setID = &setIDBytes
×
1488

×
1489
                        // We only fetch the HTLCs for the given setID.
×
1490
                        ref.refModifier = HtlcSetOnlyModifier
×
1491
                }
1492

1493
                invoice, err := fetchInvoice(ctx, db, ref)
×
1494
                if err != nil {
×
1495
                        return err
×
1496
                }
×
1497

1498
                updateTime := i.clock.Now()
×
1499
                updater := &sqlInvoiceUpdater{
×
1500
                        db:         db,
×
1501
                        ctx:        ctx,
×
1502
                        invoice:    invoice,
×
1503
                        updateTime: updateTime,
×
1504
                }
×
1505

×
1506
                payHash := ref.PayHash()
×
1507
                updatedInvoice, err = UpdateInvoice(
×
1508
                        payHash, invoice, updateTime, callback, updater,
×
1509
                )
×
1510

×
1511
                return err
×
1512
        }, func() {})
×
1513
        if txErr != nil {
×
1514
                // If the invoice is already settled, we'll return the
×
1515
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
×
1516
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
×
1517
                        return updatedInvoice, txErr
×
1518
                }
×
1519

1520
                return nil, txErr
×
1521
        }
1522

1523
        return updatedInvoice, nil
×
1524
}
1525

1526
// DeleteInvoice attempts to delete the passed invoices and all their related
1527
// data from the database in one transaction.
1528
func (i *SQLStore) DeleteInvoice(ctx context.Context,
1529
        invoicesToDelete []InvoiceDeleteRef) error {
×
1530

×
1531
        // All the InvoiceDeleteRef instances include the add index of the
×
1532
        // invoice. The rest was added to ensure that the invoices were deleted
×
1533
        // properly in the kv database. When we have fully migrated we can
×
1534
        // remove the rest of the fields.
×
1535
        for _, ref := range invoicesToDelete {
×
1536
                if ref.AddIndex == 0 {
×
1537
                        return fmt.Errorf("unable to delete invoice using a "+
×
1538
                                "ref without AddIndex set: %v", ref)
×
1539
                }
×
1540
        }
1541

1542
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
1543
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
1544
                for _, ref := range invoicesToDelete {
×
1545
                        params := sqlc.DeleteInvoiceParams{
×
1546
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
×
1547
                        }
×
1548

×
1549
                        if ref.SettleIndex != 0 {
×
1550
                                params.SettleIndex = sqldb.SQLInt64(
×
1551
                                        ref.SettleIndex,
×
1552
                                )
×
1553
                        }
×
1554

1555
                        if ref.PayHash != lntypes.ZeroHash {
×
1556
                                params.Hash = ref.PayHash[:]
×
1557
                        }
×
1558

1559
                        result, err := db.DeleteInvoice(ctx, params)
×
1560
                        if err != nil {
×
1561
                                return fmt.Errorf("unable to delete "+
×
1562
                                        "invoice(%v): %w", ref.AddIndex, err)
×
1563
                        }
×
1564
                        rowsAffected, err := result.RowsAffected()
×
1565
                        if err != nil {
×
1566
                                return fmt.Errorf("unable to get rows "+
×
1567
                                        "affected: %w", err)
×
1568
                        }
×
1569
                        if rowsAffected == 0 {
×
1570
                                return fmt.Errorf("%w: %v",
×
1571
                                        ErrInvoiceNotFound, ref.AddIndex)
×
1572
                        }
×
1573
                }
1574

1575
                return nil
×
1576
        }, func() {})
×
1577

1578
        if err != nil {
×
1579
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1580
        }
×
1581

1582
        return nil
×
1583
}
1584

1585
// DeleteCanceledInvoices removes all canceled invoices from the database.
1586
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
×
1587
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
1588
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
1589
                _, err := db.DeleteCanceledInvoices(ctx)
×
1590
                if err != nil {
×
1591
                        return fmt.Errorf("unable to delete canceled "+
×
1592
                                "invoices: %w", err)
×
1593
                }
×
1594

1595
                return nil
×
1596
        }, func() {})
×
1597
        if err != nil {
×
1598
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1599
        }
×
1600

1601
        return nil
×
1602
}
1603

1604
// fetchInvoiceData fetches additional data for the given invoice. If the
1605
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
1606
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
1607
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
1608
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
1609
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
1610
        *Invoice, error) {
×
1611

×
1612
        // Unmarshal the common data.
×
1613
        hash, invoice, err := unmarshalInvoice(row)
×
1614
        if err != nil {
×
1615
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1616
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1617
        }
×
1618

1619
        // Fetch the invoice features.
1620
        features, err := getInvoiceFeatures(ctx, db, row.ID)
×
1621
        if err != nil {
×
1622
                return nil, nil, err
×
1623
        }
×
1624

1625
        invoice.Terms.Features = features
×
1626

×
1627
        // If this is an AMP invoice, we'll need fetch the AMP state along
×
1628
        // with the HTLCs (if requested).
×
1629
        if invoice.IsAMP() {
×
1630
                invoiceID := int64(invoice.AddIndex)
×
1631
                ampState, ampHtlcs, err := fetchAmpState(
×
1632
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
×
1633
                )
×
1634
                if err != nil {
×
1635
                        return nil, nil, err
×
1636
                }
×
1637

1638
                invoice.AMPState = ampState
×
1639
                invoice.Htlcs = ampHtlcs
×
1640

×
1641
                return hash, invoice, nil
×
1642
        }
1643

1644
        // Otherwise simply fetch the invoice HTLCs.
1645
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
×
1646
        if err != nil {
×
1647
                return nil, nil, err
×
1648
        }
×
1649

1650
        if len(htlcs) > 0 {
×
1651
                invoice.Htlcs = htlcs
×
1652
        }
×
1653

1654
        return hash, invoice, nil
×
1655
}
1656

1657
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1658
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
1659
        invoiceID int64) (*lnwire.FeatureVector, error) {
×
1660

×
1661
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
×
1662
        if err != nil {
×
1663
                return nil, fmt.Errorf("unable to get invoice features: %w",
×
1664
                        err)
×
1665
        }
×
1666

1667
        features := lnwire.EmptyFeatureVector()
×
1668
        for _, feature := range rows {
×
1669
                features.Set(lnwire.FeatureBit(feature.Feature))
×
1670
        }
×
1671

1672
        return features, nil
×
1673
}
1674

1675
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
1676
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
1677
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
×
1678

×
1679
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
×
1680
        if err != nil {
×
1681
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
×
1682
        }
×
1683

1684
        // We have no htlcs to unmarshal.
1685
        if len(htlcRows) == 0 {
×
1686
                return nil, nil
×
1687
        }
×
1688

1689
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
1690
        if err != nil {
×
1691
                return nil, fmt.Errorf("unable to get custom records for "+
×
1692
                        "invoice htlcs: %w", err)
×
1693
        }
×
1694

1695
        cr := make(map[int64]record.CustomSet, len(crRows))
×
1696
        for _, row := range crRows {
×
1697
                if _, ok := cr[row.HtlcID]; !ok {
×
1698
                        cr[row.HtlcID] = make(record.CustomSet)
×
1699
                }
×
1700

1701
                value := row.Value
×
1702
                if value == nil {
×
1703
                        value = []byte{}
×
1704
                }
×
1705
                cr[row.HtlcID][uint64(row.Key)] = value
×
1706
        }
1707

1708
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
×
1709

×
1710
        for _, row := range htlcRows {
×
1711
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
×
1712
                if err != nil {
×
1713
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1714
                                "htlc(%d): %w", row.ID, err)
×
1715
                }
×
1716

1717
                if customRecords, ok := cr[row.ID]; ok {
×
1718
                        htlc.CustomRecords = customRecords
×
1719
                } else {
×
1720
                        htlc.CustomRecords = make(record.CustomSet)
×
1721
                }
×
1722

1723
                htlcs[circuiteKey] = htlc
×
1724
        }
1725

1726
        return htlcs, nil
×
1727
}
1728

1729
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1730
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
1731
        error) {
×
1732

×
1733
        var (
×
1734
                settleIndex    int64
×
1735
                settledAt      time.Time
×
1736
                memo           []byte
×
1737
                paymentRequest []byte
×
1738
                preimage       *lntypes.Preimage
×
1739
                paymentAddr    [32]byte
×
1740
        )
×
1741

×
1742
        hash, err := lntypes.MakeHash(row.Hash)
×
1743
        if err != nil {
×
1744
                return nil, nil, err
×
1745
        }
×
1746

1747
        if row.SettleIndex.Valid {
×
1748
                settleIndex = row.SettleIndex.Int64
×
1749
        }
×
1750

1751
        if row.SettledAt.Valid {
×
1752
                settledAt = row.SettledAt.Time.Local()
×
1753
        }
×
1754

1755
        if row.Memo.Valid {
×
1756
                memo = []byte(row.Memo.String)
×
1757
        }
×
1758

1759
        // Keysend payments will have this field empty.
1760
        if row.PaymentRequest.Valid {
×
1761
                paymentRequest = []byte(row.PaymentRequest.String)
×
1762
        } else {
×
1763
                paymentRequest = []byte{}
×
1764
        }
×
1765

1766
        // We may not have the preimage if this a hodl invoice.
1767
        if row.Preimage != nil {
×
1768
                preimage = &lntypes.Preimage{}
×
1769
                copy(preimage[:], row.Preimage)
×
1770
        }
×
1771

1772
        copy(paymentAddr[:], row.PaymentAddr)
×
1773

×
1774
        var cltvDelta int32
×
1775
        if row.CltvDelta.Valid {
×
1776
                cltvDelta = row.CltvDelta.Int32
×
1777
        }
×
1778

1779
        expiry := time.Duration(row.Expiry) * time.Second
×
1780

×
1781
        invoice := &Invoice{
×
1782
                SettleIndex:    uint64(settleIndex),
×
1783
                SettleDate:     settledAt,
×
1784
                Memo:           memo,
×
1785
                PaymentRequest: paymentRequest,
×
1786
                CreationDate:   row.CreatedAt.Local(),
×
1787
                Terms: ContractTerm{
×
1788
                        FinalCltvDelta:  cltvDelta,
×
1789
                        Expiry:          expiry,
×
1790
                        PaymentPreimage: preimage,
×
1791
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
×
1792
                        PaymentAddr:     paymentAddr,
×
1793
                },
×
1794
                AddIndex:    uint64(row.ID),
×
1795
                State:       ContractState(row.State),
×
1796
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
×
1797
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
×
1798
                AMPState:    AMPInvoiceState{},
×
1799
                HodlInvoice: row.IsHodl,
×
1800
        }
×
1801

×
1802
        return &hash, invoice, nil
×
1803
}
1804

1805
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
1806
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
1807
        *InvoiceHTLC, error) {
×
1808

×
1809
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
1810
        if err != nil {
×
1811
                return CircuitKey{}, nil, err
×
1812
        }
×
1813

1814
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
1815

×
1816
        if row.HtlcID < 0 {
×
1817
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
×
1818
                        "value: %v", row.HtlcID)
×
1819
        }
×
1820

1821
        htlcID := uint64(row.HtlcID)
×
1822

×
1823
        circuitKey := CircuitKey{
×
1824
                ChanID: chanID,
×
1825
                HtlcID: htlcID,
×
1826
        }
×
1827

×
1828
        htlc := &InvoiceHTLC{
×
1829
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
1830
                AcceptHeight: uint32(row.AcceptHeight),
×
1831
                AcceptTime:   row.AcceptTime.Local(),
×
1832
                Expiry:       uint32(row.ExpiryHeight),
×
1833
                State:        HtlcState(row.State),
×
1834
        }
×
1835

×
1836
        if row.TotalMppMsat.Valid {
×
1837
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
×
1838
        }
×
1839

1840
        if row.ResolveTime.Valid {
×
1841
                htlc.ResolveTime = row.ResolveTime.Time.Local()
×
1842
        }
×
1843

1844
        return circuitKey, htlc, nil
×
1845
}
1846

1847
// queryWithLimit is a helper that pages through a query using a limit and an
// offset.
//
// How paging works:
//   - The query function is called with the current offset. It returns the
//     number of rows it got back, or an error.
//   - Paging starts at offset 0 and advances by limit after every full page.
//   - It stops when a page returns fewer than limit rows.
//   - It also stops when query fails; that error is returned unchanged.
//
// The limit must be positive. With a zero or negative limit, the offset would
// never move forward. The "short page" check could then never be satisfied by
// a non-negative row count, so query would be repeated indefinitely. Such
// limits are rejected with an error before query is ever called.
func queryWithLimit(query func(int) (int, error), limit int) error {
	if limit <= 0 {
		return fmt.Errorf("invalid query limit %d: must be positive",
			limit)
	}

	for offset := 0; ; offset += limit {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// A short page means there is nothing left to fetch.
		if rows < limit {
			return nil
		}
	}
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc