• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15747152963

19 Jun 2025 01:22AM UTC coverage: 58.151% (-10.1%) from 68.248%
15747152963

push

github

web-flow
Merge pull request #9528 from Roasbeef/res-opt

fn: implement ResultOpt type for operations with optional values

97778 of 168145 relevant lines covered (58.15%)

1.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
const (
	// invoiceProgressLogInterval is used to rate-limit the progress log
	// output produced while processing (potentially many) invoices.
	invoiceProgressLogInterval = 30 * time.Second

	// defaultQueryPaginationLimit is the default page size, i.e. the value
	// used in the LIMIT clause of the paginated SQL queries to cap the
	// number of rows returned per query.
	defaultQueryPaginationLimit = 100
)
32

33
// SQLInvoiceQueries is an interface that defines the set of operations that can
34
// be executed against the invoice SQL database.
35
type SQLInvoiceQueries interface { //nolint:interfacebloat
36
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
37
                error)
38

39
        // TODO(bhandras): remove this once migrations have been separated out.
40
        InsertMigratedInvoice(ctx context.Context,
41
                arg sqlc.InsertMigratedInvoiceParams) (int64, error)
42

43
        InsertInvoiceFeature(ctx context.Context,
44
                arg sqlc.InsertInvoiceFeatureParams) error
45

46
        InsertInvoiceHTLC(ctx context.Context,
47
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
48

49
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
50
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
51

52
        FilterInvoices(ctx context.Context,
53
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
54

55
        GetInvoice(ctx context.Context,
56
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
57

58
        GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
59
                error)
60

61
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
62
                error)
63

64
        GetInvoiceFeatures(ctx context.Context,
65
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
66

67
        GetInvoiceHTLCCustomRecords(ctx context.Context,
68
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
69

70
        GetInvoiceHTLCs(ctx context.Context,
71
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
72

73
        UpdateInvoiceState(ctx context.Context,
74
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
75

76
        UpdateInvoiceAmountPaid(ctx context.Context,
77
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
78

79
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
80

81
        UpdateInvoiceHTLC(ctx context.Context,
82
                arg sqlc.UpdateInvoiceHTLCParams) error
83

84
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
85
                sql.Result, error)
86

87
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
88

89
        // AMP sub invoice specific methods.
90
        UpsertAMPSubInvoice(ctx context.Context,
91
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
92

93
        // TODO(bhandras): remove this once migrations have been separated out.
94
        InsertAMPSubInvoice(ctx context.Context,
95
                arg sqlc.InsertAMPSubInvoiceParams) error
96

97
        UpdateAMPSubInvoiceState(ctx context.Context,
98
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
99

100
        InsertAMPSubInvoiceHTLC(ctx context.Context,
101
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
102

103
        FetchAMPSubInvoices(ctx context.Context,
104
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
105
                error)
106

107
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
108
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
109
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
110

111
        FetchSettledAMPSubInvoices(ctx context.Context,
112
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
113
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
114

115
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
116
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
117
                error)
118

119
        // Invoice events specific methods.
120
        OnInvoiceCreated(ctx context.Context,
121
                arg sqlc.OnInvoiceCreatedParams) error
122

123
        OnInvoiceCanceled(ctx context.Context,
124
                arg sqlc.OnInvoiceCanceledParams) error
125

126
        OnInvoiceSettled(ctx context.Context,
127
                arg sqlc.OnInvoiceSettledParams) error
128

129
        OnAMPSubInvoiceCreated(ctx context.Context,
130
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
131

132
        OnAMPSubInvoiceCanceled(ctx context.Context,
133
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
134

135
        OnAMPSubInvoiceSettled(ctx context.Context,
136
                arg sqlc.OnAMPSubInvoiceSettledParams) error
137

138
        // Migration specific methods.
139
        // TODO(bhandras): remove this once migrations have been separated out.
140
        InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
141
                arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
142

143
        SetKVInvoicePaymentHash(ctx context.Context,
144
                arg sqlc.SetKVInvoicePaymentHashParams) error
145

146
        GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
147
                []byte, error)
148

149
        ClearKVInvoiceHashIndex(ctx context.Context) error
150
}
151

152
var _ InvoiceDB = (*SQLStore)(nil)
153

154
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
155
// of batched database operations.
156
type BatchedSQLInvoiceQueries interface {
157
        SQLInvoiceQueries
158

159
        sqldb.BatchedTx[SQLInvoiceQueries]
160
}
161

162
// SQLStore represents a storage backend.
163
type SQLStore struct {
164
        db    BatchedSQLInvoiceQueries
165
        clock clock.Clock
166
        opts  SQLStoreOptions
167
}
168

169
// SQLStoreOptions holds the options for the SQL store.
170
type SQLStoreOptions struct {
171
        paginationLimit int
172
}
173

174
// defaultSQLStoreOptions returns the default options for the SQL store.
175
func defaultSQLStoreOptions() SQLStoreOptions {
×
176
        return SQLStoreOptions{
×
177
                paginationLimit: defaultQueryPaginationLimit,
×
178
        }
×
179
}
×
180

181
// SQLStoreOption is a functional option that can be used to optionally modify
182
// the behavior of the SQL store.
183
type SQLStoreOption func(*SQLStoreOptions)
184

185
// WithPaginationLimit sets the pagination limit for the SQL store queries that
186
// paginate results.
187
func WithPaginationLimit(limit int) SQLStoreOption {
×
188
        return func(o *SQLStoreOptions) {
×
189
                o.paginationLimit = limit
×
190
        }
×
191
}
192

193
// NewSQLStore creates a new SQLStore instance given a open
194
// BatchedSQLInvoiceQueries storage backend.
195
func NewSQLStore(db BatchedSQLInvoiceQueries,
196
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
×
197

×
198
        opts := defaultSQLStoreOptions()
×
199
        for _, applyOption := range options {
×
200
                applyOption(&opts)
×
201
        }
×
202

203
        return &SQLStore{
×
204
                db:    db,
×
205
                clock: clock,
×
206
                opts:  opts,
×
207
        }
×
208
}
209

210
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
211
        sqlc.InsertInvoiceParams, error) {
×
212

×
213
        // Precompute the payment request hash so we can use it in the query.
×
214
        var paymentRequestHash []byte
×
215
        if len(invoice.PaymentRequest) > 0 {
×
216
                h := sha256.New()
×
217
                h.Write(invoice.PaymentRequest)
×
218
                paymentRequestHash = h.Sum(nil)
×
219
        }
×
220

221
        params := sqlc.InsertInvoiceParams{
×
222
                Hash:       paymentHash[:],
×
223
                AmountMsat: int64(invoice.Terms.Value),
×
224
                CltvDelta: sqldb.SQLInt32(
×
225
                        invoice.Terms.FinalCltvDelta,
×
226
                ),
×
227
                Expiry: int32(invoice.Terms.Expiry.Seconds()),
×
228
                // Note: keysend invoices don't have a payment request.
×
229
                PaymentRequest: sqldb.SQLStr(string(
×
230
                        invoice.PaymentRequest),
×
231
                ),
×
232
                PaymentRequestHash: paymentRequestHash,
×
233
                State:              int16(invoice.State),
×
234
                AmountPaidMsat:     int64(invoice.AmtPaid),
×
235
                IsAmp:              invoice.IsAMP(),
×
236
                IsHodl:             invoice.HodlInvoice,
×
237
                IsKeysend:          invoice.IsKeysend(),
×
238
                CreatedAt:          invoice.CreationDate.UTC(),
×
239
        }
×
240

×
241
        if invoice.Memo != nil {
×
242
                // Store the memo as a nullable string in the database. Note
×
243
                // that for compatibility reasons, we store the value as a valid
×
244
                // string even if it's empty.
×
245
                params.Memo = sql.NullString{
×
246
                        String: string(invoice.Memo),
×
247
                        Valid:  true,
×
248
                }
×
249
        }
×
250

251
        // Some invoices may not have a preimage, like in the case of HODL
252
        // invoices.
253
        if invoice.Terms.PaymentPreimage != nil {
×
254
                preimage := *invoice.Terms.PaymentPreimage
×
255
                if preimage == UnknownPreimage {
×
256
                        return sqlc.InsertInvoiceParams{},
×
257
                                errors.New("cannot use all-zeroes preimage")
×
258
                }
×
259
                params.Preimage = preimage[:]
×
260
        }
261

262
        // Some non MPP payments may have the default (invalid) value.
263
        if invoice.Terms.PaymentAddr != BlankPayAddr {
×
264
                params.PaymentAddr = invoice.Terms.PaymentAddr[:]
×
265
        }
×
266

267
        return params, nil
×
268
}
269

270
// AddInvoice inserts the targeted invoice into the database. If the invoice has
271
// *any* payment hashes which already exists within the database, then the
272
// insertion will be aborted and rejected due to the strict policy banning any
273
// duplicate payment hashes.
274
//
275
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
276
func (i *SQLStore) AddInvoice(ctx context.Context,
277
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
×
278

×
279
        // Make sure this is a valid invoice before trying to store it in our
×
280
        // DB.
×
281
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
×
282
                return 0, err
×
283
        }
×
284

285
        var (
×
286
                writeTxOpts = sqldb.WriteTxOpt()
×
287
                invoiceID   int64
×
288
        )
×
289

×
290
        insertInvoiceParams, err := makeInsertInvoiceParams(
×
291
                newInvoice, paymentHash,
×
292
        )
×
293
        if err != nil {
×
294
                return 0, err
×
295
        }
×
296

297
        err = i.db.ExecTx(ctx, writeTxOpts, func(db SQLInvoiceQueries) error {
×
298
                var err error
×
299
                invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
×
300
                if err != nil {
×
301
                        return fmt.Errorf("unable to insert invoice: %w", err)
×
302
                }
×
303

304
                // TODO(positiveblue): if invocies do not have custom features
305
                // maybe just store the "invoice type" and populate the features
306
                // based on that.
307
                for feature := range newInvoice.Terms.Features.Features() {
×
308
                        params := sqlc.InsertInvoiceFeatureParams{
×
309
                                InvoiceID: invoiceID,
×
310
                                Feature:   int32(feature),
×
311
                        }
×
312

×
313
                        err := db.InsertInvoiceFeature(ctx, params)
×
314
                        if err != nil {
×
315
                                return fmt.Errorf("unable to insert invoice "+
×
316
                                        "feature(%v): %w", feature, err)
×
317
                        }
×
318
                }
319

320
                // Finally add a new event for this invoice.
321
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
×
322
                        AddedAt:   newInvoice.CreationDate.UTC(),
×
323
                        InvoiceID: invoiceID,
×
324
                })
×
325
        }, sqldb.NoOpReset)
326
        if err != nil {
×
327
                mappedSQLErr := sqldb.MapSQLError(err)
×
328
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
×
329
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
330
                        // Add context to unique constraint errors.
×
331
                        return 0, ErrDuplicateInvoice
×
332
                }
×
333

334
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
×
335
                        paymentHash, err)
×
336
        }
337

338
        newInvoice.AddIndex = uint64(invoiceID)
×
339

×
340
        return newInvoice.AddIndex, nil
×
341
}
342

343
// getInvoiceByRef fetches the invoice with the given reference. The reference
344
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
345
func getInvoiceByRef(ctx context.Context,
346
        db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
×
347

×
348
        // If the reference is empty, we can't look up the invoice.
×
349
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
×
350
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
351
        }
×
352

353
        // If the reference is a hash only, we can look up the invoice directly
354
        // by the payment hash which is faster.
355
        if ref.IsHashOnly() {
×
356
                invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
×
357
                if errors.Is(err, sql.ErrNoRows) {
×
358
                        return sqlc.Invoice{}, ErrInvoiceNotFound
×
359
                }
×
360

361
                return invoice, err
×
362
        }
363

364
        // Otherwise the reference may include more fields, so we'll need to
365
        // assemble the query parameters based on the fields that are set.
366
        var params sqlc.GetInvoiceParams
×
367

×
368
        if ref.PayHash() != nil {
×
369
                params.Hash = ref.PayHash()[:]
×
370
        }
×
371

372
        // Newer invoices (0.11 and up) are indexed by payment address in
373
        // addition to payment hash, but pre 0.8 invoices do not have one at
374
        // all. Only allow lookups for payment address if it is not a blank
375
        // payment address, which is a special-cased value for legacy keysend
376
        // invoices.
377
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
×
378
                params.PaymentAddr = ref.PayAddr()[:]
×
379
        }
×
380

381
        // If the reference has a set ID we'll fetch the invoice which has the
382
        // corresponding AMP sub invoice.
383
        if ref.SetID() != nil {
×
384
                params.SetID = ref.SetID()[:]
×
385
        }
×
386

387
        var (
×
388
                rows []sqlc.Invoice
×
389
                err  error
×
390
        )
×
391

×
392
        // We need to split the query based on how we intend to look up the
×
393
        // invoice. If only the set ID is given then we want to have an exact
×
394
        // match on the set ID. If other fields are given, we want to match on
×
395
        // those fields and the set ID but with a less strict join condition.
×
396
        if params.Hash == nil && params.PaymentAddr == nil &&
×
397
                params.SetID != nil {
×
398

×
399
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
×
400
        } else {
×
401
                rows, err = db.GetInvoice(ctx, params)
×
402
        }
×
403

404
        switch {
×
405
        case len(rows) == 0:
×
406
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
407

408
        case len(rows) > 1:
×
409
                // In case the reference is ambiguous, meaning it matches more
×
410
                // than        one invoice, we'll return an error.
×
411
                return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
×
412
                        "%s: %s", ref.String(), spew.Sdump(rows))
×
413

414
        case err != nil:
×
415
                return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
×
416
                        err)
×
417
        }
418

419
        return rows[0], nil
×
420
}
421

422
// fetchInvoice fetches the common invoice data and the AMP state for the
423
// invoice with the given reference.
424
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
425
        *Invoice, error) {
×
426

×
427
        // Fetch the invoice from the database.
×
428
        sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
×
429
        if err != nil {
×
430
                return nil, err
×
431
        }
×
432

433
        var (
×
434
                setID         *[32]byte
×
435
                fetchAmpHtlcs bool
×
436
        )
×
437

×
438
        // Now that we got the invoice itself, fetch the HTLCs as requested by
×
439
        // the modifier.
×
440
        switch ref.Modifier() {
×
441
        case DefaultModifier:
×
442
                // By default we'll fetch all AMP HTLCs.
×
443
                setID = nil
×
444
                fetchAmpHtlcs = true
×
445

446
        case HtlcSetOnlyModifier:
×
447
                // In this case we'll fetch all AMP HTLCs for the specified set
×
448
                // id.
×
449
                if ref.SetID() == nil {
×
450
                        return nil, fmt.Errorf("set ID is required to use " +
×
451
                                "the HTLC set only modifier")
×
452
                }
×
453

454
                setID = ref.SetID()
×
455
                fetchAmpHtlcs = true
×
456

457
        case HtlcSetBlankModifier:
×
458
                // No need to fetch any HTLCs.
×
459
                setID = nil
×
460
                fetchAmpHtlcs = false
×
461

462
        default:
×
463
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
×
464
                        ref.Modifier())
×
465
        }
466

467
        // Fetch the rest of the invoice data and fill the invoice struct.
468
        _, invoice, err := fetchInvoiceData(
×
469
                ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
×
470
        )
×
471
        if err != nil {
×
472
                return nil, err
×
473
        }
×
474

475
        return invoice, nil
×
476
}
477

478
// fetchAmpState fetches the AMP state for the invoice with the given ID.
479
// Optional setID can be provided to fetch the state for a specific AMP HTLC
480
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
481
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
482
// well.
483
//
484
//nolint:funlen
485
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
486
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
487
        HTLCSet, error) {
×
488

×
489
        var paramSetID []byte
×
490
        if setID != nil {
×
491
                paramSetID = setID[:]
×
492
        }
×
493

494
        // First fetch all the AMP sub invoices for this invoice or the one
495
        // matching the provided set ID.
496
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
×
497
                ctx, sqlc.FetchAMPSubInvoicesParams{
×
498
                        InvoiceID: invoiceID,
×
499
                        SetID:     paramSetID,
×
500
                },
×
501
        )
×
502
        if err != nil {
×
503
                return nil, nil, err
×
504
        }
×
505

506
        ampState := make(map[SetID]InvoiceStateAMP)
×
507
        for _, row := range ampInvoiceRows {
×
508
                var rowSetID [32]byte
×
509

×
510
                if len(row.SetID) != 32 {
×
511
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
×
512
                                len(row.SetID))
×
513
                }
×
514

515
                var settleDate time.Time
×
516
                if row.SettledAt.Valid {
×
517
                        settleDate = row.SettledAt.Time.Local()
×
518
                }
×
519

520
                copy(rowSetID[:], row.SetID)
×
521
                ampState[rowSetID] = InvoiceStateAMP{
×
522
                        State:       HtlcState(row.State),
×
523
                        SettleIndex: uint64(row.SettleIndex.Int64),
×
524
                        SettleDate:  settleDate,
×
525
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
×
526
                }
×
527
        }
528

529
        if !fetchHtlcs {
×
530
                return ampState, nil, nil
×
531
        }
×
532

533
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
534
        if err != nil {
×
535
                return nil, nil, fmt.Errorf("unable to get custom records for "+
×
536
                        "invoice HTLCs: %w", err)
×
537
        }
×
538

539
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
×
540
        for _, row := range customRecordRows {
×
541
                if _, ok := customRecords[row.HtlcID]; !ok {
×
542
                        customRecords[row.HtlcID] = make(record.CustomSet)
×
543
                }
×
544

545
                value := row.Value
×
546
                if value == nil {
×
547
                        value = []byte{}
×
548
                }
×
549

550
                customRecords[row.HtlcID][uint64(row.Key)] = value
×
551
        }
552

553
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
554
        // provided set ID.
555
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
×
556
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
×
557
                        InvoiceID: invoiceID,
×
558
                        SetID:     paramSetID,
×
559
                },
×
560
        )
×
561
        if err != nil {
×
562
                return nil, nil, err
×
563
        }
×
564

565
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
×
566
        for _, row := range ampHtlcRows {
×
567
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
568
                if err != nil {
×
569
                        return nil, nil, err
×
570
                }
×
571

572
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
573

×
574
                if row.HtlcID < 0 {
×
575
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
×
576
                                "value: %v", row.HtlcID)
×
577
                }
×
578

579
                htlcID := uint64(row.HtlcID)
×
580

×
581
                circuitKey := CircuitKey{
×
582
                        ChanID: chanID,
×
583
                        HtlcID: htlcID,
×
584
                }
×
585

×
586
                htlc := &InvoiceHTLC{
×
587
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
588
                        AcceptHeight: uint32(row.AcceptHeight),
×
589
                        AcceptTime:   row.AcceptTime.Local(),
×
590
                        Expiry:       uint32(row.ExpiryHeight),
×
591
                        State:        HtlcState(row.State),
×
592
                }
×
593

×
594
                if row.TotalMppMsat.Valid {
×
595
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
×
596
                                row.TotalMppMsat.Int64,
×
597
                        )
×
598
                }
×
599

600
                if row.ResolveTime.Valid {
×
601
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
×
602
                }
×
603

604
                var (
×
605
                        rootShare [32]byte
×
606
                        setID     [32]byte
×
607
                )
×
608

×
609
                if len(row.RootShare) != 32 {
×
610
                        return nil, nil, fmt.Errorf("invalid root share "+
×
611
                                "length: %d", len(row.RootShare))
×
612
                }
×
613
                copy(rootShare[:], row.RootShare)
×
614

×
615
                if len(row.SetID) != 32 {
×
616
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
617
                                len(row.SetID))
×
618
                }
×
619
                copy(setID[:], row.SetID)
×
620

×
621
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
×
622
                        return nil, nil, fmt.Errorf("invalid child index "+
×
623
                                "value: %v", row.ChildIndex)
×
624
                }
×
625

626
                ampRecord := record.NewAMP(
×
627
                        rootShare, setID, uint32(row.ChildIndex),
×
628
                )
×
629

×
630
                htlc.AMP = &InvoiceHtlcAMPData{
×
631
                        Record: *ampRecord,
×
632
                }
×
633

×
634
                if len(row.Hash) != 32 {
×
635
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
×
636
                                len(row.Hash))
×
637
                }
×
638
                copy(htlc.AMP.Hash[:], row.Hash)
×
639

×
640
                if row.Preimage != nil {
×
641
                        preimage, err := lntypes.MakePreimage(row.Preimage)
×
642
                        if err != nil {
×
643
                                return nil, nil, err
×
644
                        }
×
645

646
                        htlc.AMP.Preimage = &preimage
×
647
                }
648

649
                if _, ok := customRecords[row.ID]; ok {
×
650
                        htlc.CustomRecords = customRecords[row.ID]
×
651
                } else {
×
652
                        htlc.CustomRecords = make(record.CustomSet)
×
653
                }
×
654

655
                ampHtlcs[circuitKey] = htlc
×
656
        }
657

658
        if len(ampHtlcs) > 0 {
×
659
                for setID := range ampState {
×
660
                        var amtPaid lnwire.MilliSatoshi
×
661
                        invoiceKeys := make(
×
662
                                map[models.CircuitKey]struct{},
×
663
                        )
×
664

×
665
                        for key, htlc := range ampHtlcs {
×
666
                                if htlc.AMP.Record.SetID() != setID {
×
667
                                        continue
×
668
                                }
669

670
                                invoiceKeys[key] = struct{}{}
×
671

×
672
                                if htlc.State != HtlcStateCanceled {
×
673
                                        amtPaid += htlc.Amt
×
674
                                }
×
675
                        }
676

677
                        setState := ampState[setID]
×
678
                        setState.InvoiceKeys = invoiceKeys
×
679
                        setState.AmtPaid = amtPaid
×
680
                        ampState[setID] = setState
×
681
                }
682
        }
683

684
        return ampState, ampHtlcs, nil
×
685
}
686

687
// LookupInvoice attempts to look up an invoice corresponding the passed in
688
// reference. The reference may be a payment hash, a payment address, or a set
689
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
690
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
691
// error.
692
func (i *SQLStore) LookupInvoice(ctx context.Context,
693
        ref InvoiceRef) (Invoice, error) {
×
694

×
695
        var (
×
696
                invoice *Invoice
×
697
                err     error
×
698
        )
×
699

×
700
        readTxOpt := sqldb.ReadTxOpt()
×
701
        txErr := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
702
                invoice, err = fetchInvoice(ctx, db, ref)
×
703

×
704
                return err
×
705
        }, sqldb.NoOpReset)
×
706
        if txErr != nil {
×
707
                return Invoice{}, txErr
×
708
        }
×
709

710
        return *invoice, nil
×
711
}
712

713
// FetchPendingInvoices returns all the invoices that are currently in a
714
// "pending" state. An invoice is pending if it has been created but not yet
715
// settled or canceled.
716
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
717
        map[lntypes.Hash]Invoice, error) {
×
718

×
719
        var invoices map[lntypes.Hash]Invoice
×
720

×
721
        readTxOpt := sqldb.ReadTxOpt()
×
722
        err := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
723
                return queryWithLimit(func(offset int) (int, error) {
×
724
                        params := sqlc.FilterInvoicesParams{
×
725
                                PendingOnly: true,
×
726
                                NumOffset:   int32(offset),
×
727
                                NumLimit:    int32(i.opts.paginationLimit),
×
728
                                Reverse:     false,
×
729
                        }
×
730

×
731
                        rows, err := db.FilterInvoices(ctx, params)
×
732
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
733
                                return 0, fmt.Errorf("unable to get invoices "+
×
734
                                        "from db: %w", err)
×
735
                        }
×
736

737
                        // Load all the information for the invoices.
738
                        for _, row := range rows {
×
739
                                hash, invoice, err := fetchInvoiceData(
×
740
                                        ctx, db, row, nil, true,
×
741
                                )
×
742
                                if err != nil {
×
743
                                        return 0, err
×
744
                                }
×
745

746
                                invoices[*hash] = *invoice
×
747
                        }
748

749
                        return len(rows), nil
×
750
                }, i.opts.paginationLimit)
751
        }, func() {
×
752
                invoices = make(map[lntypes.Hash]Invoice)
×
753
        })
×
754
        if err != nil {
×
755
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
×
756
                        err)
×
757
        }
×
758

759
        return invoices, nil
×
760
}
761

762
// InvoicesSettledSince can be used by callers to catch up any settled invoices
763
// they missed within the settled invoice time series. We'll return all known
764
// settled invoice that have a settle index higher than the passed idx.
765
//
766
// NOTE: The index starts from 1. As a result we enforce that specifying a value
767
// below the starting index value is a noop.
768
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
769
        []Invoice, error) {
×
770

×
771
        var (
×
772
                invoices       []Invoice
×
773
                start          = time.Now()
×
774
                lastLogTime    = time.Now()
×
775
                processedCount int
×
776
        )
×
777

×
778
        if idx == 0 {
×
779
                return invoices, nil
×
780
        }
×
781

782
        readTxOpt := sqldb.ReadTxOpt()
×
783
        err := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
784
                err := queryWithLimit(func(offset int) (int, error) {
×
785
                        params := sqlc.FilterInvoicesParams{
×
786
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
787
                                NumOffset:      int32(offset),
×
788
                                NumLimit:       int32(i.opts.paginationLimit),
×
789
                                Reverse:        false,
×
790
                        }
×
791

×
792
                        rows, err := db.FilterInvoices(ctx, params)
×
793
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
794
                                return 0, fmt.Errorf("unable to get invoices "+
×
795
                                        "from db: %w", err)
×
796
                        }
×
797

798
                        // Load all the information for the invoices.
799
                        for _, row := range rows {
×
800
                                _, invoice, err := fetchInvoiceData(
×
801
                                        ctx, db, row, nil, true,
×
802
                                )
×
803
                                if err != nil {
×
804
                                        return 0, fmt.Errorf("unable to fetch "+
×
805
                                                "invoice(id=%d) from db: %w",
×
806
                                                row.ID, err)
×
807
                                }
×
808

809
                                invoices = append(invoices, *invoice)
×
810

×
811
                                processedCount++
×
812
                                if time.Since(lastLogTime) >=
×
813
                                        invoiceProgressLogInterval {
×
814

×
815
                                        log.Debugf("Processed %d settled "+
×
816
                                                "invoices which have a settle "+
×
817
                                                "index greater than %v",
×
818
                                                processedCount, idx)
×
819

×
820
                                        lastLogTime = time.Now()
×
821
                                }
×
822
                        }
823

824
                        return len(rows), nil
×
825
                }, i.opts.paginationLimit)
826
                if err != nil {
×
827
                        return err
×
828
                }
×
829

830
                // Now fetch all the AMP sub invoices that were settled since
831
                // the provided index.
832
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
×
833
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
×
834
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
835
                        },
×
836
                )
×
837
                if err != nil {
×
838
                        return err
×
839
                }
×
840

841
                for _, ampInvoice := range ampInvoices {
×
842
                        // Convert the row to a sqlc.Invoice so we can use the
×
843
                        // existing fetchInvoiceData function.
×
844
                        sqlInvoice := sqlc.Invoice{
×
845
                                ID:             ampInvoice.ID,
×
846
                                Hash:           ampInvoice.Hash,
×
847
                                Preimage:       ampInvoice.Preimage,
×
848
                                SettleIndex:    ampInvoice.AmpSettleIndex,
×
849
                                SettledAt:      ampInvoice.AmpSettledAt,
×
850
                                Memo:           ampInvoice.Memo,
×
851
                                AmountMsat:     ampInvoice.AmountMsat,
×
852
                                CltvDelta:      ampInvoice.CltvDelta,
×
853
                                Expiry:         ampInvoice.Expiry,
×
854
                                PaymentAddr:    ampInvoice.PaymentAddr,
×
855
                                PaymentRequest: ampInvoice.PaymentRequest,
×
856
                                State:          ampInvoice.State,
×
857
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
×
858
                                IsAmp:          ampInvoice.IsAmp,
×
859
                                IsHodl:         ampInvoice.IsHodl,
×
860
                                IsKeysend:      ampInvoice.IsKeysend,
×
861
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
×
862
                        }
×
863

×
864
                        // Fetch the state and HTLCs for this AMP sub invoice.
×
865
                        _, invoice, err := fetchInvoiceData(
×
866
                                ctx, db, sqlInvoice,
×
867
                                (*[32]byte)(ampInvoice.SetID), true,
×
868
                        )
×
869
                        if err != nil {
×
870
                                return fmt.Errorf("unable to fetch "+
×
871
                                        "AMP invoice(id=%d) from db: %w",
×
872
                                        ampInvoice.ID, err)
×
873
                        }
×
874

875
                        invoices = append(invoices, *invoice)
×
876

×
877
                        processedCount++
×
878
                        if time.Since(lastLogTime) >=
×
879
                                invoiceProgressLogInterval {
×
880

×
881
                                log.Debugf("Processed %d settled invoices "+
×
882
                                        "including AMP sub invoices which "+
×
883
                                        "have a settle index greater than %v",
×
884
                                        processedCount, idx)
×
885

×
886
                                lastLogTime = time.Now()
×
887
                        }
×
888
                }
889

890
                return nil
×
891
        }, func() {
×
892
                invoices = nil
×
893
        })
×
894
        if err != nil {
×
895
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
896
                        "index (excluding) %d: %w", idx, err)
×
897
        }
×
898

899
        elapsed := time.Since(start)
×
900
        log.Debugf("Completed scanning for settled invoices starting at "+
×
901
                "index %v: total_processed=%d, found_invoices=%d, elapsed=%v",
×
902
                idx, processedCount, len(invoices),
×
903
                elapsed.Round(time.Millisecond))
×
904

×
905
        return invoices, nil
×
906
}
907

908
// InvoicesAddedSince can be used by callers to seek into the event time series
909
// of all the invoices added in the database. This method will return all
910
// invoices with an add index greater than the specified idx.
911
//
912
// NOTE: The index starts from 1. As a result we enforce that specifying a value
913
// below the starting index value is a noop.
914
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
915
        []Invoice, error) {
×
916

×
917
        var (
×
918
                result         []Invoice
×
919
                start          = time.Now()
×
920
                lastLogTime    = time.Now()
×
921
                processedCount int
×
922
        )
×
923

×
924
        if idx == 0 {
×
925
                return result, nil
×
926
        }
×
927

928
        readTxOpt := sqldb.ReadTxOpt()
×
929
        err := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
930
                return queryWithLimit(func(offset int) (int, error) {
×
931
                        params := sqlc.FilterInvoicesParams{
×
932
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
×
933
                                NumOffset:   int32(offset),
×
934
                                NumLimit:    int32(i.opts.paginationLimit),
×
935
                                Reverse:     false,
×
936
                        }
×
937

×
938
                        rows, err := db.FilterInvoices(ctx, params)
×
939
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
940
                                return 0, fmt.Errorf("unable to get invoices "+
×
941
                                        "from db: %w", err)
×
942
                        }
×
943

944
                        // Load all the information for the invoices.
945
                        for _, row := range rows {
×
946
                                _, invoice, err := fetchInvoiceData(
×
947
                                        ctx, db, row, nil, true,
×
948
                                )
×
949
                                if err != nil {
×
950
                                        return 0, err
×
951
                                }
×
952

953
                                result = append(result, *invoice)
×
954

×
955
                                processedCount++
×
956
                                if time.Since(lastLogTime) >=
×
957
                                        invoiceProgressLogInterval {
×
958

×
959
                                        log.Debugf("Processed %d invoices "+
×
960
                                                "which were added since add "+
×
961
                                                "index %v", processedCount, idx)
×
962

×
963
                                        lastLogTime = time.Now()
×
964
                                }
×
965
                        }
966

967
                        return len(rows), nil
×
968
                }, i.opts.paginationLimit)
969
        }, func() {
×
970
                result = nil
×
971
        })
×
972

973
        if err != nil {
×
974
                return nil, fmt.Errorf("unable to get invoices added since "+
×
975
                        "index %d: %w", idx, err)
×
976
        }
×
977

978
        elapsed := time.Since(start)
×
979
        log.Debugf("Completed scanning for invoices added since index %v: "+
×
980
                "total_processed=%d, found_invoices=%d, elapsed=%v",
×
981
                idx, processedCount, len(result),
×
982
                elapsed.Round(time.Millisecond))
×
983

×
984
        return result, nil
×
985
}
986

987
// QueryInvoices allows a caller to query the invoice database for invoices
988
// within the specified add index range.
989
func (i *SQLStore) QueryInvoices(ctx context.Context,
990
        q InvoiceQuery) (InvoiceSlice, error) {
×
991

×
992
        var invoices []Invoice
×
993

×
994
        if q.NumMaxInvoices == 0 {
×
995
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
996
                        "be non-zero")
×
997
        }
×
998

999
        readTxOpt := sqldb.ReadTxOpt()
×
1000
        err := i.db.ExecTx(ctx, readTxOpt, func(db SQLInvoiceQueries) error {
×
1001
                return queryWithLimit(func(offset int) (int, error) {
×
1002
                        params := sqlc.FilterInvoicesParams{
×
1003
                                NumOffset:   int32(offset),
×
1004
                                NumLimit:    int32(i.opts.paginationLimit),
×
1005
                                PendingOnly: q.PendingOnly,
×
1006
                                Reverse:     q.Reversed,
×
1007
                        }
×
1008

×
1009
                        if q.Reversed {
×
1010
                                // If the index offset was not set, we want to
×
1011
                                // fetch from the lastest invoice.
×
1012
                                if q.IndexOffset == 0 {
×
1013
                                        params.AddIndexLet = sqldb.SQLInt64(
×
1014
                                                int64(math.MaxInt64),
×
1015
                                        )
×
1016
                                } else {
×
1017
                                        // The invoice with index offset id must
×
1018
                                        // not be included in the results.
×
1019
                                        params.AddIndexLet = sqldb.SQLInt64(
×
1020
                                                q.IndexOffset - 1,
×
1021
                                        )
×
1022
                                }
×
1023
                        } else {
×
1024
                                // The invoice with index offset id must not be
×
1025
                                // included in the results.
×
1026
                                params.AddIndexGet = sqldb.SQLInt64(
×
1027
                                        q.IndexOffset + 1,
×
1028
                                )
×
1029
                        }
×
1030

1031
                        if q.CreationDateStart != 0 {
×
1032
                                params.CreatedAfter = sqldb.SQLTime(
×
1033
                                        time.Unix(q.CreationDateStart, 0).UTC(),
×
1034
                                )
×
1035
                        }
×
1036

1037
                        if q.CreationDateEnd != 0 {
×
1038
                                // We need to add 1 to the end date as we're
×
1039
                                // checking less than the end date in SQL.
×
1040
                                params.CreatedBefore = sqldb.SQLTime(
×
1041
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
×
1042
                                )
×
1043
                        }
×
1044

1045
                        rows, err := db.FilterInvoices(ctx, params)
×
1046
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1047
                                return 0, fmt.Errorf("unable to get invoices "+
×
1048
                                        "from db: %w", err)
×
1049
                        }
×
1050

1051
                        // Load all the information for the invoices.
1052
                        for _, row := range rows {
×
1053
                                _, invoice, err := fetchInvoiceData(
×
1054
                                        ctx, db, row, nil, true,
×
1055
                                )
×
1056
                                if err != nil {
×
1057
                                        return 0, err
×
1058
                                }
×
1059

1060
                                invoices = append(invoices, *invoice)
×
1061

×
1062
                                if len(invoices) == int(q.NumMaxInvoices) {
×
1063
                                        return 0, nil
×
1064
                                }
×
1065
                        }
1066

1067
                        return len(rows), nil
×
1068
                }, i.opts.paginationLimit)
1069
        }, func() {
×
1070
                invoices = nil
×
1071
        })
×
1072
        if err != nil {
×
1073
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
1074
                        "invoices: %w", err)
×
1075
        }
×
1076

1077
        if len(invoices) == 0 {
×
1078
                return InvoiceSlice{
×
1079
                        InvoiceQuery: q,
×
1080
                }, nil
×
1081
        }
×
1082

1083
        // If we iterated through the add index in reverse order, then
1084
        // we'll need to reverse the slice of invoices to return them in
1085
        // forward order.
1086
        if q.Reversed {
×
1087
                numInvoices := len(invoices)
×
1088
                for i := 0; i < numInvoices/2; i++ {
×
1089
                        reverse := numInvoices - i - 1
×
1090
                        invoices[i], invoices[reverse] =
×
1091
                                invoices[reverse], invoices[i]
×
1092
                }
×
1093
        }
1094

1095
        res := InvoiceSlice{
×
1096
                InvoiceQuery:     q,
×
1097
                Invoices:         invoices,
×
1098
                FirstIndexOffset: invoices[0].AddIndex,
×
1099
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
×
1100
        }
×
1101

×
1102
        return res, nil
×
1103
}
1104

1105
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
// a SQL database as the backend. Each method issues its queries immediately
// through db; there is no buffering of changes.
type sqlInvoiceUpdater struct {
	// db is the query handle every update is issued against. When built
	// by SQLStore.UpdateInvoice this is the transaction-scoped handle
	// passed to the ExecTx callback.
	db SQLInvoiceQueries

	// ctx is the context passed to every query issued by the updater.
	ctx context.Context //nolint:containedctx

	// invoice is the in-memory invoice being updated. Its AddIndex is
	// used as the invoice's primary key in queries, and some methods
	// (UpdateInvoiceState, UpdateAmpState) mirror new settle data into
	// it.
	invoice *Invoice

	// updateTime is the timestamp recorded for events, AMP sub-invoice
	// creation and settlement performed by this updater.
	updateTime time.Time
}
1113

1114
// AddHtlc adds a new htlc to the invoice.
//
// It inserts the htlc row, then one row per entry in newHtlc.CustomRecords.
// If the htlc carries an AMP record, it additionally upserts the AMP
// sub-invoice for the htlc's set ID, emits a sub-invoice creation event when
// the upsert reports affected rows, and links the htlc to the sub-invoice
// together with its root share, child index, hash and (optional) preimage.
//
// A unique constraint violation while upserting the AMP sub-invoice is
// reported as ErrDuplicateSetID; all other errors are returned unchanged.
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
	newHtlc *InvoiceHTLC) error {

	// Insert the htlc itself. The channel ID is stored as its decimal
	// string form, and a zero MPP total is stored as NULL.
	htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
		s.ctx, sqlc.InsertInvoiceHTLCParams{
			HtlcID: int64(circuitKey.HtlcID),
			ChanID: strconv.FormatUint(
				circuitKey.ChanID.ToUint64(), 10,
			),
			AmountMsat: int64(newHtlc.Amt),
			TotalMppMsat: sql.NullInt64{
				Int64: int64(newHtlc.MppTotalAmt),
				Valid: newHtlc.MppTotalAmt != 0,
			},
			AcceptHeight: int32(newHtlc.AcceptHeight),
			AcceptTime:   newHtlc.AcceptTime.UTC(),
			ExpiryHeight: int32(newHtlc.Expiry),
			State:        int16(newHtlc.State),
			InvoiceID:    int64(s.invoice.AddIndex),
		},
	)
	if err != nil {
		return err
	}

	// Store each custom record keyed by the htlc's primary key returned
	// above.
	for key, value := range newHtlc.CustomRecords {
		err = s.db.InsertInvoiceHTLCCustomRecord(
			s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
				// TODO(bhandras): schema might be wrong here
				// as the custom record key is an uint64.
				Key:    int64(key),
				Value:  value,
				HtlcID: htlcPrimaryKeyID,
			},
		)
		if err != nil {
			return err
		}
	}

	if newHtlc.AMP != nil {
		setID := newHtlc.AMP.Record.SetID()

		upsertResult, err := s.db.UpsertAMPSubInvoice(
			s.ctx, sqlc.UpsertAMPSubInvoiceParams{
				SetID:     setID[:],
				CreatedAt: s.updateTime.UTC(),
				InvoiceID: int64(s.invoice.AddIndex),
			},
		)
		if err != nil {
			// Translate a unique constraint violation into the
			// package's typed duplicate set ID error.
			mappedSQLErr := sqldb.MapSQLError(err)
			var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
			if errors.As(mappedSQLErr, &uniqueConstraintErr) {
				return ErrDuplicateSetID{
					SetID: setID,
				}
			}

			return err
		}

		// If we're just inserting the AMP invoice, we'll get a non
		// zero rows affected count. NOTE(review): this relies on the
		// upsert reporting zero affected rows when the sub-invoice
		// already exists — confirm against the UpsertAMPSubInvoice
		// query.
		rowsAffected, err := upsertResult.RowsAffected()
		if err != nil {
			return err
		}
		if rowsAffected != 0 {
			// If we're inserting a new AMP invoice, we'll also
			// insert a new invoice event.
			err = s.db.OnAMPSubInvoiceCreated(
				s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
					AddedAt:   s.updateTime.UTC(),
					InvoiceID: int64(s.invoice.AddIndex),
					SetID:     setID[:],
				},
			)
			if err != nil {
				return err
			}
		}

		rootShare := newHtlc.AMP.Record.RootShare()

		// Link the htlc to its AMP sub-invoice.
		ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
			InvoiceID: int64(s.invoice.AddIndex),
			SetID:     setID[:],
			HtlcID:    htlcPrimaryKeyID,
			RootShare: rootShare[:],
			ChildIndex: int64(
				newHtlc.AMP.Record.ChildIndex(),
			),
			Hash: newHtlc.AMP.Hash[:],
		}

		// The preimage is optional; leave the column unset when it is
		// not yet known.
		if newHtlc.AMP.Preimage != nil {
			ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
		}

		err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
		if err != nil {
			return err
		}
	}

	return nil
}
1223

1224
// ResolveHtlc marks an htlc as resolved with the given state.
1225
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
1226
        state HtlcState, resolveTime time.Time) error {
×
1227

×
1228
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
×
1229
                HtlcID: int64(circuitKey.HtlcID),
×
1230
                ChanID: strconv.FormatUint(
×
1231
                        circuitKey.ChanID.ToUint64(), 10,
×
1232
                ),
×
1233
                InvoiceID:   int64(s.invoice.AddIndex),
×
1234
                State:       int16(state),
×
1235
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
×
1236
        })
×
1237
}
×
1238

1239
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1240
// identified by the setID.
1241
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
1242
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
×
1243

×
1244
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
×
1245
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
×
1246
                        InvoiceID: int64(s.invoice.AddIndex),
×
1247
                        SetID:     setID[:],
×
1248
                        HtlcID:    int64(circuitKey.HtlcID),
×
1249
                        Preimage:  preimage[:],
×
1250
                        ChanID: strconv.FormatUint(
×
1251
                                circuitKey.ChanID.ToUint64(), 10,
×
1252
                        ),
×
1253
                },
×
1254
        )
×
1255
        if err != nil {
×
1256
                return err
×
1257
        }
×
1258

1259
        rowsAffected, err := result.RowsAffected()
×
1260
        if err != nil {
×
1261
                return err
×
1262
        }
×
1263
        if rowsAffected == 0 {
×
1264
                return ErrInvoiceNotFound
×
1265
        }
×
1266

1267
        return nil
×
1268
}
1269

1270
// UpdateInvoiceState updates the invoice state to the new state.
1271
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
1272
        newState ContractState, preimage *lntypes.Preimage) error {
×
1273

×
1274
        var (
×
1275
                settleIndex sql.NullInt64
×
1276
                settledAt   sql.NullTime
×
1277
        )
×
1278

×
1279
        switch newState {
×
1280
        case ContractSettled:
×
1281
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
1282
                if err != nil {
×
1283
                        return err
×
1284
                }
×
1285

1286
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
1287

×
1288
                // If the invoice is settled, we'll also update the settle time.
×
1289
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
1290

×
1291
                err = s.db.OnInvoiceSettled(
×
1292
                        s.ctx, sqlc.OnInvoiceSettledParams{
×
1293
                                AddedAt:   s.updateTime.UTC(),
×
1294
                                InvoiceID: int64(s.invoice.AddIndex),
×
1295
                        },
×
1296
                )
×
1297
                if err != nil {
×
1298
                        return err
×
1299
                }
×
1300

1301
        case ContractCanceled:
×
1302
                err := s.db.OnInvoiceCanceled(
×
1303
                        s.ctx, sqlc.OnInvoiceCanceledParams{
×
1304
                                AddedAt:   s.updateTime.UTC(),
×
1305
                                InvoiceID: int64(s.invoice.AddIndex),
×
1306
                        },
×
1307
                )
×
1308
                if err != nil {
×
1309
                        return err
×
1310
                }
×
1311
        }
1312

1313
        params := sqlc.UpdateInvoiceStateParams{
×
1314
                ID:          int64(s.invoice.AddIndex),
×
1315
                State:       int16(newState),
×
1316
                SettleIndex: settleIndex,
×
1317
                SettledAt:   settledAt,
×
1318
        }
×
1319

×
1320
        if preimage != nil {
×
1321
                params.Preimage = preimage[:]
×
1322
        }
×
1323

1324
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
×
1325
        if err != nil {
×
1326
                return err
×
1327
        }
×
1328
        rowsAffected, err := result.RowsAffected()
×
1329
        if err != nil {
×
1330
                return err
×
1331
        }
×
1332

1333
        if rowsAffected == 0 {
×
1334
                return ErrInvoiceNotFound
×
1335
        }
×
1336

1337
        if settleIndex.Valid {
×
1338
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
×
1339
                s.invoice.SettleDate = s.updateTime
×
1340
        }
×
1341

1342
        return nil
×
1343
}
1344

1345
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
1346
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
1347
        amtPaid lnwire.MilliSatoshi) error {
×
1348

×
1349
        _, err := s.db.UpdateInvoiceAmountPaid(
×
1350
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
×
1351
                        ID:             int64(s.invoice.AddIndex),
×
1352
                        AmountPaidMsat: int64(amtPaid),
×
1353
                },
×
1354
        )
×
1355

×
1356
        return err
×
1357
}
×
1358

1359
// UpdateAmpState updates the state of the AMP sub invoice identified by the
1360
// setID.
1361
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
1362
        newState InvoiceStateAMP, _ models.CircuitKey) error {
×
1363

×
1364
        var (
×
1365
                settleIndex sql.NullInt64
×
1366
                settledAt   sql.NullTime
×
1367
        )
×
1368

×
1369
        switch newState.State {
×
1370
        case HtlcStateSettled:
×
1371
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
1372
                if err != nil {
×
1373
                        return err
×
1374
                }
×
1375

1376
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
1377

×
1378
                // If the invoice is settled, we'll also update the settle time.
×
1379
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
1380

×
1381
                err = s.db.OnAMPSubInvoiceSettled(
×
1382
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
×
1383
                                AddedAt:   s.updateTime.UTC(),
×
1384
                                InvoiceID: int64(s.invoice.AddIndex),
×
1385
                                SetID:     setID[:],
×
1386
                        },
×
1387
                )
×
1388
                if err != nil {
×
1389
                        return err
×
1390
                }
×
1391

1392
        case HtlcStateCanceled:
×
1393
                err := s.db.OnAMPSubInvoiceCanceled(
×
1394
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
×
1395
                                AddedAt:   s.updateTime.UTC(),
×
1396
                                InvoiceID: int64(s.invoice.AddIndex),
×
1397
                                SetID:     setID[:],
×
1398
                        },
×
1399
                )
×
1400
                if err != nil {
×
1401
                        return err
×
1402
                }
×
1403
        }
1404

1405
        err := s.db.UpdateAMPSubInvoiceState(
×
1406
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
×
1407
                        SetID:       setID[:],
×
1408
                        State:       int16(newState.State),
×
1409
                        SettleIndex: settleIndex,
×
1410
                        SettledAt:   settledAt,
×
1411
                },
×
1412
        )
×
1413
        if err != nil {
×
1414
                return err
×
1415
        }
×
1416

1417
        if settleIndex.Valid {
×
1418
                updatedState := s.invoice.AMPState[setID]
×
1419
                updatedState.SettleIndex = uint64(settleIndex.Int64)
×
1420
                updatedState.SettleDate = s.updateTime.UTC()
×
1421
                s.invoice.AMPState[setID] = updatedState
×
1422
        }
×
1423

1424
        return nil
×
1425
}
1426

1427
// Finalize finalizes the update before it is written to the database. Note that
1428
// we don't use this directly in the SQL implementation, so the function is just
1429
// a stub.
1430
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
×
1431
        return nil
×
1432
}
×
1433

1434
// UpdateInvoice attempts to update an invoice corresponding to the passed
1435
// reference. If an invoice matching the passed reference doesn't exist within
1436
// the database, then the action will fail with  ErrInvoiceNotFound error.
1437
//
1438
// The update is performed inside the same database transaction that fetches the
1439
// invoice and is therefore atomic. The fields to update are controlled by the
1440
// supplied callback.
1441
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
1442
        setID *SetID, callback InvoiceUpdateCallback) (
1443
        *Invoice, error) {
×
1444

×
1445
        var updatedInvoice *Invoice
×
1446

×
1447
        txOpt := sqldb.WriteTxOpt()
×
1448
        txErr := i.db.ExecTx(ctx, txOpt, func(db SQLInvoiceQueries) error {
×
1449
                switch {
×
1450
                // For the default case we fetch all HTLCs.
1451
                case setID == nil:
×
1452
                        ref.refModifier = DefaultModifier
×
1453

1454
                // If the setID is the blank but NOT nil, we set the
1455
                // refModifier to HtlcSetBlankModifier to fetch no HTLC for the
1456
                // AMP invoice.
1457
                case *setID == BlankPayAddr:
×
1458
                        ref.refModifier = HtlcSetBlankModifier
×
1459

1460
                // A setID is provided, we use the refModifier to fetch only
1461
                // the HTLCs for the given setID and also make sure we add the
1462
                // setID to the ref.
1463
                default:
×
1464
                        var setIDBytes [32]byte
×
1465
                        copy(setIDBytes[:], setID[:])
×
1466
                        ref.setID = &setIDBytes
×
1467

×
1468
                        // We only fetch the HTLCs for the given setID.
×
1469
                        ref.refModifier = HtlcSetOnlyModifier
×
1470
                }
1471

1472
                invoice, err := fetchInvoice(ctx, db, ref)
×
1473
                if err != nil {
×
1474
                        return err
×
1475
                }
×
1476

1477
                updateTime := i.clock.Now()
×
1478
                updater := &sqlInvoiceUpdater{
×
1479
                        db:         db,
×
1480
                        ctx:        ctx,
×
1481
                        invoice:    invoice,
×
1482
                        updateTime: updateTime,
×
1483
                }
×
1484

×
1485
                payHash := ref.PayHash()
×
1486
                updatedInvoice, err = UpdateInvoice(
×
1487
                        payHash, invoice, updateTime, callback, updater,
×
1488
                )
×
1489

×
1490
                return err
×
1491
        }, sqldb.NoOpReset)
1492
        if txErr != nil {
×
1493
                // If the invoice is already settled, we'll return the
×
1494
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
×
1495
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
×
1496
                        return updatedInvoice, txErr
×
1497
                }
×
1498

1499
                return nil, txErr
×
1500
        }
1501

1502
        return updatedInvoice, nil
×
1503
}
1504

1505
// DeleteInvoice attempts to delete the passed invoices and all their related
1506
// data from the database in one transaction.
1507
func (i *SQLStore) DeleteInvoice(ctx context.Context,
1508
        invoicesToDelete []InvoiceDeleteRef) error {
×
1509

×
1510
        // All the InvoiceDeleteRef instances include the add index of the
×
1511
        // invoice. The rest was added to ensure that the invoices were deleted
×
1512
        // properly in the kv database. When we have fully migrated we can
×
1513
        // remove the rest of the fields.
×
1514
        for _, ref := range invoicesToDelete {
×
1515
                if ref.AddIndex == 0 {
×
1516
                        return fmt.Errorf("unable to delete invoice using a "+
×
1517
                                "ref without AddIndex set: %v", ref)
×
1518
                }
×
1519
        }
1520

1521
        writeTxOpt := sqldb.WriteTxOpt()
×
1522
        err := i.db.ExecTx(ctx, writeTxOpt, func(db SQLInvoiceQueries) error {
×
1523
                for _, ref := range invoicesToDelete {
×
1524
                        params := sqlc.DeleteInvoiceParams{
×
1525
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
×
1526
                        }
×
1527

×
1528
                        if ref.SettleIndex != 0 {
×
1529
                                params.SettleIndex = sqldb.SQLInt64(
×
1530
                                        ref.SettleIndex,
×
1531
                                )
×
1532
                        }
×
1533

1534
                        if ref.PayHash != lntypes.ZeroHash {
×
1535
                                params.Hash = ref.PayHash[:]
×
1536
                        }
×
1537

1538
                        result, err := db.DeleteInvoice(ctx, params)
×
1539
                        if err != nil {
×
1540
                                return fmt.Errorf("unable to delete "+
×
1541
                                        "invoice(%v): %w", ref.AddIndex, err)
×
1542
                        }
×
1543
                        rowsAffected, err := result.RowsAffected()
×
1544
                        if err != nil {
×
1545
                                return fmt.Errorf("unable to get rows "+
×
1546
                                        "affected: %w", err)
×
1547
                        }
×
1548
                        if rowsAffected == 0 {
×
1549
                                return fmt.Errorf("%w: %v",
×
1550
                                        ErrInvoiceNotFound, ref.AddIndex)
×
1551
                        }
×
1552
                }
1553

1554
                return nil
×
1555
        }, sqldb.NoOpReset)
1556

1557
        if err != nil {
×
1558
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1559
        }
×
1560

1561
        return nil
×
1562
}
1563

1564
// DeleteCanceledInvoices removes all canceled invoices from the database.
1565
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
×
1566
        writeTxOpt := sqldb.WriteTxOpt()
×
1567
        err := i.db.ExecTx(ctx, writeTxOpt, func(db SQLInvoiceQueries) error {
×
1568
                _, err := db.DeleteCanceledInvoices(ctx)
×
1569
                if err != nil {
×
1570
                        return fmt.Errorf("unable to delete canceled "+
×
1571
                                "invoices: %w", err)
×
1572
                }
×
1573

1574
                return nil
×
1575
        }, sqldb.NoOpReset)
1576
        if err != nil {
×
1577
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1578
        }
×
1579

1580
        return nil
×
1581
}
1582

1583
// fetchInvoiceData fetches additional data for the given invoice. If the
1584
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
1585
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
1586
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
1587
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
1588
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
1589
        *Invoice, error) {
×
1590

×
1591
        // Unmarshal the common data.
×
1592
        hash, invoice, err := unmarshalInvoice(row)
×
1593
        if err != nil {
×
1594
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1595
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1596
        }
×
1597

1598
        // Fetch the invoice features.
1599
        features, err := getInvoiceFeatures(ctx, db, row.ID)
×
1600
        if err != nil {
×
1601
                return nil, nil, err
×
1602
        }
×
1603

1604
        invoice.Terms.Features = features
×
1605

×
1606
        // If this is an AMP invoice, we'll need fetch the AMP state along
×
1607
        // with the HTLCs (if requested).
×
1608
        if invoice.IsAMP() {
×
1609
                invoiceID := int64(invoice.AddIndex)
×
1610
                ampState, ampHtlcs, err := fetchAmpState(
×
1611
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
×
1612
                )
×
1613
                if err != nil {
×
1614
                        return nil, nil, err
×
1615
                }
×
1616

1617
                invoice.AMPState = ampState
×
1618
                invoice.Htlcs = ampHtlcs
×
1619

×
1620
                return hash, invoice, nil
×
1621
        }
1622

1623
        // Otherwise simply fetch the invoice HTLCs.
1624
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
×
1625
        if err != nil {
×
1626
                return nil, nil, err
×
1627
        }
×
1628

1629
        if len(htlcs) > 0 {
×
1630
                invoice.Htlcs = htlcs
×
1631
        }
×
1632

1633
        return hash, invoice, nil
×
1634
}
1635

1636
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1637
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
1638
        invoiceID int64) (*lnwire.FeatureVector, error) {
×
1639

×
1640
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
×
1641
        if err != nil {
×
1642
                return nil, fmt.Errorf("unable to get invoice features: %w",
×
1643
                        err)
×
1644
        }
×
1645

1646
        features := lnwire.EmptyFeatureVector()
×
1647
        for _, feature := range rows {
×
1648
                features.Set(lnwire.FeatureBit(feature.Feature))
×
1649
        }
×
1650

1651
        return features, nil
×
1652
}
1653

1654
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
1655
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
1656
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
×
1657

×
1658
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
×
1659
        if err != nil {
×
1660
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
×
1661
        }
×
1662

1663
        // We have no htlcs to unmarshal.
1664
        if len(htlcRows) == 0 {
×
1665
                return nil, nil
×
1666
        }
×
1667

1668
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
1669
        if err != nil {
×
1670
                return nil, fmt.Errorf("unable to get custom records for "+
×
1671
                        "invoice htlcs: %w", err)
×
1672
        }
×
1673

1674
        cr := make(map[int64]record.CustomSet, len(crRows))
×
1675
        for _, row := range crRows {
×
1676
                if _, ok := cr[row.HtlcID]; !ok {
×
1677
                        cr[row.HtlcID] = make(record.CustomSet)
×
1678
                }
×
1679

1680
                value := row.Value
×
1681
                if value == nil {
×
1682
                        value = []byte{}
×
1683
                }
×
1684
                cr[row.HtlcID][uint64(row.Key)] = value
×
1685
        }
1686

1687
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
×
1688

×
1689
        for _, row := range htlcRows {
×
1690
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
×
1691
                if err != nil {
×
1692
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1693
                                "htlc(%d): %w", row.ID, err)
×
1694
                }
×
1695

1696
                if customRecords, ok := cr[row.ID]; ok {
×
1697
                        htlc.CustomRecords = customRecords
×
1698
                } else {
×
1699
                        htlc.CustomRecords = make(record.CustomSet)
×
1700
                }
×
1701

1702
                htlcs[circuiteKey] = htlc
×
1703
        }
1704

1705
        return htlcs, nil
×
1706
}
1707

1708
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1709
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
1710
        error) {
×
1711

×
1712
        var (
×
1713
                settleIndex    int64
×
1714
                settledAt      time.Time
×
1715
                memo           []byte
×
1716
                paymentRequest []byte
×
1717
                preimage       *lntypes.Preimage
×
1718
                paymentAddr    [32]byte
×
1719
        )
×
1720

×
1721
        hash, err := lntypes.MakeHash(row.Hash)
×
1722
        if err != nil {
×
1723
                return nil, nil, err
×
1724
        }
×
1725

1726
        if row.SettleIndex.Valid {
×
1727
                settleIndex = row.SettleIndex.Int64
×
1728
        }
×
1729

1730
        if row.SettledAt.Valid {
×
1731
                settledAt = row.SettledAt.Time.Local()
×
1732
        }
×
1733

1734
        if row.Memo.Valid {
×
1735
                memo = []byte(row.Memo.String)
×
1736
        }
×
1737

1738
        // Keysend payments will have this field empty.
1739
        if row.PaymentRequest.Valid {
×
1740
                paymentRequest = []byte(row.PaymentRequest.String)
×
1741
        } else {
×
1742
                paymentRequest = []byte{}
×
1743
        }
×
1744

1745
        // We may not have the preimage if this a hodl invoice.
1746
        if row.Preimage != nil {
×
1747
                preimage = &lntypes.Preimage{}
×
1748
                copy(preimage[:], row.Preimage)
×
1749
        }
×
1750

1751
        copy(paymentAddr[:], row.PaymentAddr)
×
1752

×
1753
        var cltvDelta int32
×
1754
        if row.CltvDelta.Valid {
×
1755
                cltvDelta = row.CltvDelta.Int32
×
1756
        }
×
1757

1758
        expiry := time.Duration(row.Expiry) * time.Second
×
1759

×
1760
        invoice := &Invoice{
×
1761
                SettleIndex:    uint64(settleIndex),
×
1762
                SettleDate:     settledAt,
×
1763
                Memo:           memo,
×
1764
                PaymentRequest: paymentRequest,
×
1765
                CreationDate:   row.CreatedAt.Local(),
×
1766
                Terms: ContractTerm{
×
1767
                        FinalCltvDelta:  cltvDelta,
×
1768
                        Expiry:          expiry,
×
1769
                        PaymentPreimage: preimage,
×
1770
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
×
1771
                        PaymentAddr:     paymentAddr,
×
1772
                },
×
1773
                AddIndex:    uint64(row.ID),
×
1774
                State:       ContractState(row.State),
×
1775
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
×
1776
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
×
1777
                AMPState:    AMPInvoiceState{},
×
1778
                HodlInvoice: row.IsHodl,
×
1779
        }
×
1780

×
1781
        return &hash, invoice, nil
×
1782
}
1783

1784
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
1785
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
1786
        *InvoiceHTLC, error) {
×
1787

×
1788
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
1789
        if err != nil {
×
1790
                return CircuitKey{}, nil, err
×
1791
        }
×
1792

1793
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
1794

×
1795
        if row.HtlcID < 0 {
×
1796
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
×
1797
                        "value: %v", row.HtlcID)
×
1798
        }
×
1799

1800
        htlcID := uint64(row.HtlcID)
×
1801

×
1802
        circuitKey := CircuitKey{
×
1803
                ChanID: chanID,
×
1804
                HtlcID: htlcID,
×
1805
        }
×
1806

×
1807
        htlc := &InvoiceHTLC{
×
1808
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
1809
                AcceptHeight: uint32(row.AcceptHeight),
×
1810
                AcceptTime:   row.AcceptTime.Local(),
×
1811
                Expiry:       uint32(row.ExpiryHeight),
×
1812
                State:        HtlcState(row.State),
×
1813
        }
×
1814

×
1815
        if row.TotalMppMsat.Valid {
×
1816
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
×
1817
        }
×
1818

1819
        if row.ResolveTime.Valid {
×
1820
                htlc.ResolveTime = row.ResolveTime.Time.Local()
×
1821
        }
×
1822

1823
        return circuitKey, htlc, nil
×
1824
}
1825

1826
// queryWithLimit pages through a result set by repeatedly invoking query with
// an increasing offset (0, limit, 2*limit, ...). The query function receives
// the offset and must return the number of rows it processed for that page
// (expected to be at most limit) and an error, if any.
//
// Paging stops when query returns an error, which is passed through unchanged,
// or when a page with fewer than limit rows is returned. If the total is an
// exact multiple of limit, that last page is empty.
//
// limit must be positive. With limit <= 0, a row count (>= 0) can never fall
// below it and the offset never advances, so the loop would spin forever.
// Such limits are rejected with an error before query is called.
func queryWithLimit(query func(int) (int, error), limit int) error {
	if limit <= 0 {
		return fmt.Errorf("invalid query limit %d: must be positive",
			limit)
	}

	for offset := 0; ; offset += limit {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// A short page means there is nothing left to fetch.
		if rows < limit {
			return nil
		}
	}
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc