• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13035292482

29 Jan 2025 03:59PM UTC coverage: 49.3% (-9.5%) from 58.777%
13035292482

Pull #9456

github

mohamedawnallah
docs: update release-notes-0.19.0.md

In this commit, we warn users about the removal
of RPCs `SendToRoute`, `SendToRouteSync`, `SendPayment`,
and `SendPaymentSync` in the next release 0.20.
Pull Request #9456: lnrpc+docs: deprecate warning `SendToRoute`, `SendToRouteSync`, `SendPayment`, and `SendPaymentSync` in Release 0.19

100634 of 204126 relevant lines covered (49.3%)

1.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
// defaultQueryPaginationLimit is the page size placed in the LIMIT clause of
// paginated SQL queries, bounding how many rows a single query returns.
const defaultQueryPaginationLimit = 100
28

29
// SQLInvoiceQueries is an interface that defines the set of operations that can
30
// be executed against the invoice SQL database.
31
type SQLInvoiceQueries interface { //nolint:interfacebloat
32
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
33
                error)
34

35
        InsertInvoiceFeature(ctx context.Context,
36
                arg sqlc.InsertInvoiceFeatureParams) error
37

38
        InsertInvoiceHTLC(ctx context.Context,
39
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
40

41
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
42
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
43

44
        FilterInvoices(ctx context.Context,
45
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
46

47
        GetInvoice(ctx context.Context,
48
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
49

50
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
51
                error)
52

53
        GetInvoiceFeatures(ctx context.Context,
54
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
55

56
        GetInvoiceHTLCCustomRecords(ctx context.Context,
57
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
58

59
        GetInvoiceHTLCs(ctx context.Context,
60
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
61

62
        UpdateInvoiceState(ctx context.Context,
63
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
64

65
        UpdateInvoiceAmountPaid(ctx context.Context,
66
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
67

68
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
69

70
        UpdateInvoiceHTLC(ctx context.Context,
71
                arg sqlc.UpdateInvoiceHTLCParams) error
72

73
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
74
                sql.Result, error)
75

76
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
77

78
        // AMP sub invoice specific methods.
79
        UpsertAMPSubInvoice(ctx context.Context,
80
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
81

82
        UpdateAMPSubInvoiceState(ctx context.Context,
83
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
84

85
        InsertAMPSubInvoiceHTLC(ctx context.Context,
86
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
87

88
        FetchAMPSubInvoices(ctx context.Context,
89
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
90
                error)
91

92
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
93
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
94
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
95

96
        FetchSettledAMPSubInvoices(ctx context.Context,
97
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
98
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
99

100
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
101
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
102
                error)
103

104
        // Invoice events specific methods.
105
        OnInvoiceCreated(ctx context.Context,
106
                arg sqlc.OnInvoiceCreatedParams) error
107

108
        OnInvoiceCanceled(ctx context.Context,
109
                arg sqlc.OnInvoiceCanceledParams) error
110

111
        OnInvoiceSettled(ctx context.Context,
112
                arg sqlc.OnInvoiceSettledParams) error
113

114
        OnAMPSubInvoiceCreated(ctx context.Context,
115
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
116

117
        OnAMPSubInvoiceCanceled(ctx context.Context,
118
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
119

120
        OnAMPSubInvoiceSettled(ctx context.Context,
121
                arg sqlc.OnAMPSubInvoiceSettledParams) error
122
}
123

124
// Compile-time assertion that *SQLStore satisfies the InvoiceDB interface.
var _ InvoiceDB = (*SQLStore)(nil)
125

126
// SQLInvoiceQueriesTxOptions carries the database transaction options that
// SQLInvoiceQueries understands.
type SQLInvoiceQueriesTxOptions struct {
	// readOnly is true when the transaction only needs read access.
	readOnly bool
}

// ReadOnly reports whether the transaction should be opened read only.
//
// NOTE: This implements the TxOptions.
func (o *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
	return o.readOnly
}

// NewSQLInvoiceQueryReadTx returns a transaction option set that requests a
// read only transaction.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
	return SQLInvoiceQueriesTxOptions{readOnly: true}
}
146

147
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
148
// of batched database operations.
149
type BatchedSQLInvoiceQueries interface {
150
        SQLInvoiceQueries
151

152
        sqldb.BatchedTx[SQLInvoiceQueries]
153
}
154

155
// SQLStore represents a storage backend.
156
type SQLStore struct {
157
        db    BatchedSQLInvoiceQueries
158
        clock clock.Clock
159
        opts  SQLStoreOptions
160
}
×
161

×
162
// SQLStoreOptions groups the tunable settings of the SQL store.
type SQLStoreOptions struct {
	// paginationLimit is the page size used by store queries that
	// paginate their results.
	paginationLimit int
}
×
166

×
167
// defaultSQLStoreOptions returns the default options for the SQL store.
×
168
func defaultSQLStoreOptions() SQLStoreOptions {
×
169
        return SQLStoreOptions{
×
170
                paginationLimit: defaultQueryPaginationLimit,
171
        }
172
}
173

174
// SQLStoreOption is a functional option that can be used to optionally modify
175
// the behavior of the SQL store.
176
type SQLStoreOption func(*SQLStoreOptions)
177

178
// WithPaginationLimit sets the pagination limit for the SQL store queries that
179
// paginate results.
180
func WithPaginationLimit(limit int) SQLStoreOption {
181
        return func(o *SQLStoreOptions) {
182
                o.paginationLimit = limit
183
        }
184
}
185

186
// NewSQLStore creates a new SQLStore instance given a open
187
// BatchedSQLInvoiceQueries storage backend.
188
func NewSQLStore(db BatchedSQLInvoiceQueries,
189
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
190

191
        opts := defaultSQLStoreOptions()
192
        for _, applyOption := range options {
×
193
                applyOption(&opts)
×
194
        }
×
195

×
196
        return &SQLStore{
×
197
                db:    db,
198
                clock: clock,
199
                opts:  opts,
200
        }
201
}
202

203
// AddInvoice inserts the targeted invoice into the database. If the invoice has
204
// *any* payment hashes which already exists within the database, then the
×
205
// insertion will be aborted and rejected due to the strict policy banning any
×
206
// duplicate payment hashes.
×
207
//
×
208
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
209
func (i *SQLStore) AddInvoice(ctx context.Context,
210
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
211

212
        // Make sure this is a valid invoice before trying to store it in our
213
        // DB.
×
214
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
×
215
                return 0, err
×
216
        }
×
217

×
218
        var (
×
219
                writeTxOpts SQLInvoiceQueriesTxOptions
220
                invoiceID   int64
×
221
        )
×
222

×
223
        // Precompute the payment request hash so we can use it in the query.
×
224
        var paymentRequestHash []byte
×
225
        if len(newInvoice.PaymentRequest) > 0 {
226
                h := sha256.New()
227
                h.Write(newInvoice.PaymentRequest)
228
                paymentRequestHash = h.Sum(nil)
×
229
        }
×
230

×
231
        err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
×
232
                params := sqlc.InsertInvoiceParams{
×
233
                        Hash:       paymentHash[:],
×
234
                        Memo:       sqldb.SQLStr(string(newInvoice.Memo)),
×
235
                        AmountMsat: int64(newInvoice.Terms.Value),
×
236
                        // Note: BOLT12 invoices don't have a final cltv delta.
×
237
                        CltvDelta: sqldb.SQLInt32(
238
                                newInvoice.Terms.FinalCltvDelta,
×
239
                        ),
×
240
                        Expiry: int32(newInvoice.Terms.Expiry.Seconds()),
×
241
                        // Note: keysend invoices don't have a payment request.
×
242
                        PaymentRequest: sqldb.SQLStr(string(
×
243
                                newInvoice.PaymentRequest),
×
244
                        ),
×
245
                        PaymentRequestHash: paymentRequestHash,
×
246
                        State:              int16(newInvoice.State),
×
247
                        AmountPaidMsat:     int64(newInvoice.AmtPaid),
×
248
                        IsAmp:              newInvoice.IsAMP(),
×
249
                        IsHodl:             newInvoice.HodlInvoice,
×
250
                        IsKeysend:          newInvoice.IsKeysend(),
×
251
                        CreatedAt:          newInvoice.CreationDate.UTC(),
×
252
                }
×
253

×
254
                // Some invoices may not have a preimage, like in the case of
×
255
                // HODL invoices.
×
256
                if newInvoice.Terms.PaymentPreimage != nil {
×
257
                        preimage := *newInvoice.Terms.PaymentPreimage
×
258
                        if preimage == UnknownPreimage {
×
259
                                return errors.New("cannot use all-zeroes " +
×
260
                                        "preimage")
×
261
                        }
×
262
                        params.Preimage = preimage[:]
×
263
                }
×
264

×
265
                // Some non MPP payments may have the default (invalid) value.
×
266
                if newInvoice.Terms.PaymentAddr != BlankPayAddr {
×
267
                        params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
268
                }
269

270
                var err error
×
271
                invoiceID, err = db.InsertInvoice(ctx, params)
×
272
                if err != nil {
×
273
                        return fmt.Errorf("unable to insert invoice: %w", err)
×
274
                }
×
275

×
276
                // TODO(positiveblue): if invocies do not have custom features
×
277
                // maybe just store the "invoice type" and populate the features
278
                // based on that.
279
                for feature := range newInvoice.Terms.Features.Features() {
280
                        params := sqlc.InsertInvoiceFeatureParams{
×
281
                                InvoiceID: invoiceID,
×
282
                                Feature:   int32(feature),
×
283
                        }
284

×
285
                        err := db.InsertInvoiceFeature(ctx, params)
286
                        if err != nil {
287
                                return fmt.Errorf("unable to insert invoice "+
288
                                        "feature(%v): %w", feature, err)
289
                        }
290
                }
291

292
                // Finally add a new event for this invoice.
293
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
294
                        AddedAt:   newInvoice.CreationDate.UTC(),
×
295
                        InvoiceID: invoiceID,
×
296
                })
×
297
        }, func() {})
×
298
        if err != nil {
×
299
                mappedSQLErr := sqldb.MapSQLError(err)
×
300
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
×
301
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
302
                        // Add context to unique constraint errors.
×
303
                        return 0, ErrDuplicateInvoice
×
304
                }
×
305

×
306
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
×
307
                        paymentHash, err)
×
308
        }
×
309

×
310
        newInvoice.AddIndex = uint64(invoiceID)
×
311

×
312
        return newInvoice.AddIndex, nil
×
313
}
314

×
315
// fetchInvoice fetches the common invoice data and the AMP state for the
×
316
// invoice with the given reference.
×
317
func (i *SQLStore) fetchInvoice(ctx context.Context,
×
318
        db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) {
×
319

×
320
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
321
                return nil, ErrInvoiceNotFound
322
        }
323

324
        var (
×
325
                invoice *Invoice
×
326
                params  sqlc.GetInvoiceParams
×
327
        )
×
328

×
329
        // Given all invoices are uniquely identified by their payment hash,
×
330
        // we can use it to query a specific invoice.
×
331
        if ref.PayHash() != nil {
×
332
                params.Hash = ref.PayHash()[:]
×
333
        }
×
334

×
335
        // Newer invoices (0.11 and up) are indexed by payment address in
336
        // addition to payment hash, but pre 0.8 invoices do not have one at
337
        // all. Only allow lookups for payment address if it is not a blank
338
        // payment address, which is a special-cased value for legacy keysend
×
339
        // invoices.
×
340
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
×
341
                params.PaymentAddr = ref.PayAddr()[:]
×
342
        }
×
343

×
344
        // If the reference has a set ID we'll fetch the invoice which has the
×
345
        // corresponding AMP sub invoice.
×
346
        if ref.SetID() != nil {
×
347
                params.SetID = ref.SetID()[:]
×
348
        }
×
349

×
350
        var (
351
                rows []sqlc.Invoice
×
352
                err  error
×
353
        )
354

355
        // We need to split the query based on how we intend to look up the
×
356
        // invoice. If only the set ID is given then we want to have an exact
×
357
        // match on the set ID. If other fields are given, we want to match on
×
358
        // those fields and the set ID but with a less strict join condition.
359
        if params.Hash == nil && params.PaymentAddr == nil &&
360
                params.SetID != nil {
361

362
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
363
        } else {
×
364
                rows, err = db.GetInvoice(ctx, params)
×
365
        }
×
366
        switch {
×
367
        case len(rows) == 0:
×
368
                return nil, ErrInvoiceNotFound
×
369

370
        case len(rows) > 1:
371
                // In case the reference is ambiguous, meaning it matches more
372
                // than        one invoice, we'll return an error.
×
373
                return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
×
374
                        ref.String(), spew.Sdump(rows))
×
375

×
376
        case err != nil:
×
377
                return nil, fmt.Errorf("unable to fetch invoice: %w", err)
378
        }
×
379

380
        var (
381
                setID         *[32]byte
382
                fetchAmpHtlcs bool
383
        )
×
384

×
385
        // Now that we got the invoice itself, fetch the HTLCs as requested by
×
386
        // the modifier.
×
387
        switch ref.Modifier() {
×
388
        case DefaultModifier:
389
                // By default we'll fetch all AMP HTLCs.
390
                setID = nil
391
                fetchAmpHtlcs = true
392

393
        case HtlcSetOnlyModifier:
394
                // In this case we'll fetch all AMP HTLCs for the
×
395
                // specified set id.
×
396
                if ref.SetID() == nil {
×
397
                        return nil, fmt.Errorf("set ID is required to use " +
398
                                "the HTLC set only modifier")
399
                }
400

×
401
                setID = ref.SetID()
×
402
                fetchAmpHtlcs = true
×
403

404
        case HtlcSetBlankModifier:
×
405
                // No need to fetch any HTLCs.
×
406
                setID = nil
×
407
                fetchAmpHtlcs = false
×
408

×
409
        default:
×
410
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
×
411
                        ref.Modifier())
×
412
        }
×
413

×
414
        // Fetch the rest of the invoice data and fill the invoice struct.
×
415
        _, invoice, err = fetchInvoiceData(
×
416
                ctx, db, rows[0], setID, fetchAmpHtlcs,
×
417
        )
×
418
        if err != nil {
×
419
                return nil, err
×
420
        }
421

×
422
        return invoice, nil
×
423
}
×
424

425
// fetchAmpState fetches the AMP state for the invoice with the given ID.
×
426
// Optional setID can be provided to fetch the state for a specific AMP HTLC
×
427
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
×
428
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
×
429
// well.
×
430
//
431
//nolint:funlen
×
432
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
×
433
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
×
434
        HTLCSet, error) {
435

436
        var paramSetID []byte
×
437
        if setID != nil {
438
                paramSetID = setID[:]
439
        }
440

441
        // First fetch all the AMP sub invoices for this invoice or the one
442
        // matching the provided set ID.
×
443
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
×
444
                ctx, sqlc.FetchAMPSubInvoicesParams{
×
445
                        InvoiceID: invoiceID,
×
446
                        SetID:     paramSetID,
×
447
                },
×
448
        )
×
449
        if err != nil {
450
                return nil, nil, err
×
451
        }
×
452

×
453
        ampState := make(map[SetID]InvoiceStateAMP)
×
454
        for _, row := range ampInvoiceRows {
×
455
                var rowSetID [32]byte
×
456

×
457
                if len(row.SetID) != 32 {
×
458
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
×
459
                                len(row.SetID))
×
460
                }
×
461

×
462
                var settleDate time.Time
463
                if row.SettledAt.Valid {
×
464
                        settleDate = row.SettledAt.Time.Local()
×
465
                }
×
466

×
467
                copy(rowSetID[:], row.SetID)
×
468
                ampState[rowSetID] = InvoiceStateAMP{
×
469
                        State:       HtlcState(row.State),
×
470
                        SettleIndex: uint64(row.SettleIndex.Int64),
471
                        SettleDate:  settleDate,
×
472
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
×
473
                }
474
        }
×
475

×
476
        if !fetchHtlcs {
×
477
                return ampState, nil, nil
×
478
        }
479

×
480
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
481
        if err != nil {
×
482
                return nil, nil, fmt.Errorf("unable to get custom records for "+
483
                        "invoice HTLCs: %w", err)
484
        }
485

×
486
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
×
487
        for _, row := range customRecordRows {
×
488
                if _, ok := customRecords[row.HtlcID]; !ok {
×
489
                        customRecords[row.HtlcID] = make(record.CustomSet)
×
490
                }
×
491

492
                value := row.Value
×
493
                if value == nil {
494
                        value = []byte{}
495
                }
496

497
                customRecords[row.HtlcID][uint64(row.Key)] = value
498
        }
499

500
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
501
        // provided set ID.
502
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
503
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
504
                        InvoiceID: invoiceID,
×
505
                        SetID:     paramSetID,
×
506
                },
×
507
        )
×
508
        if err != nil {
×
509
                return nil, nil, err
×
510
        }
511

512
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
513
        for _, row := range ampHtlcRows {
×
514
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
515
                if err != nil {
×
516
                        return nil, nil, err
×
517
                }
×
518

×
519
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
520

×
521
                if row.HtlcID < 0 {
×
522
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
523
                                "value: %v", row.HtlcID)
×
524
                }
×
525

×
526
                htlcID := uint64(row.HtlcID)
×
527

×
528
                circuitKey := CircuitKey{
×
529
                        ChanID: chanID,
×
530
                        HtlcID: htlcID,
×
531
                }
532

×
533
                htlc := &InvoiceHTLC{
×
534
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
535
                        AcceptHeight: uint32(row.AcceptHeight),
×
536
                        AcceptTime:   row.AcceptTime.Local(),
537
                        Expiry:       uint32(row.ExpiryHeight),
×
538
                        State:        HtlcState(row.State),
×
539
                }
×
540

×
541
                if row.TotalMppMsat.Valid {
×
542
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
×
543
                                row.TotalMppMsat.Int64,
×
544
                        )
545
                }
546

×
547
                if row.ResolveTime.Valid {
×
548
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
×
549
                }
550

×
551
                var (
×
552
                        rootShare [32]byte
×
553
                        setID     [32]byte
×
554
                )
×
555

556
                if len(row.RootShare) != 32 {
×
557
                        return nil, nil, fmt.Errorf("invalid root share "+
×
558
                                "length: %d", len(row.RootShare))
×
559
                }
×
560
                copy(rootShare[:], row.RootShare)
×
561

562
                if len(row.SetID) != 32 {
×
563
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
564
                                len(row.SetID))
×
565
                }
×
566
                copy(setID[:], row.SetID)
567

×
568
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
569
                        return nil, nil, fmt.Errorf("invalid child index "+
570
                                "value: %v", row.ChildIndex)
571
                }
572

×
573
                ampRecord := record.NewAMP(
×
574
                        rootShare, setID, uint32(row.ChildIndex),
×
575
                )
×
576

×
577
                htlc.AMP = &InvoiceHtlcAMPData{
×
578
                        Record: *ampRecord,
×
579
                }
×
580

×
581
                if len(row.Hash) != 32 {
582
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
×
583
                                len(row.Hash))
×
584
                }
×
585
                copy(htlc.AMP.Hash[:], row.Hash)
×
586

×
587
                if row.Preimage != nil {
×
588
                        preimage, err := lntypes.MakePreimage(row.Preimage)
589
                        if err != nil {
×
590
                                return nil, nil, err
×
591
                        }
×
592

×
593
                        htlc.AMP.Preimage = &preimage
×
594
                }
×
595

596
                if _, ok := customRecords[row.ID]; ok {
×
597
                        htlc.CustomRecords = customRecords[row.ID]
×
598
                } else {
×
599
                        htlc.CustomRecords = make(record.CustomSet)
×
600
                }
×
601

×
602
                ampHtlcs[circuitKey] = htlc
×
603
        }
×
604

×
605
        if len(ampHtlcs) > 0 {
×
606
                for setID := range ampState {
×
607
                        var amtPaid lnwire.MilliSatoshi
×
608
                        invoiceKeys := make(
×
609
                                map[models.CircuitKey]struct{},
×
610
                        )
×
611

×
612
                        for key, htlc := range ampHtlcs {
×
613
                                if htlc.AMP.Record.SetID() != setID {
×
614
                                        continue
×
615
                                }
×
616

617
                                invoiceKeys[key] = struct{}{}
×
618

×
619
                                if htlc.State != HtlcStateCanceled { //nolint: ll
×
620
                                        amtPaid += htlc.Amt
621
                                }
×
622
                        }
×
623

×
624
                        setState := ampState[setID]
×
625
                        setState.InvoiceKeys = invoiceKeys
×
626
                        setState.AmtPaid = amtPaid
×
627
                        ampState[setID] = setState
×
628
                }
×
629
        }
×
630

×
631
        return ampState, ampHtlcs, nil
×
632
}
×
633

×
634
// LookupInvoice attempts to look up an invoice corresponding the passed in
×
635
// reference. The reference may be a payment hash, a payment address, or a set
×
636
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
×
637
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
×
638
// error.
×
639
func (i *SQLStore) LookupInvoice(ctx context.Context,
×
640
        ref InvoiceRef) (Invoice, error) {
×
641

×
642
        var (
643
                invoice *Invoice
×
644
                err     error
×
645
        )
×
646

×
647
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
648
        txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
649
                invoice, err = i.fetchInvoice(ctx, db, ref)
×
650

×
651
                return err
×
652
        }, func() {})
×
653
        if txErr != nil {
×
654
                return Invoice{}, txErr
×
655
        }
×
656

×
657
        return *invoice, nil
×
658
}
×
659

×
660
// FetchPendingInvoices returns all the invoices that are currently in a
×
661
// "pending" state. An invoice is pending if it has been created but not yet
×
662
// settled or canceled.
663
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
×
664
        map[lntypes.Hash]Invoice, error) {
665

666
        var invoices map[lntypes.Hash]Invoice
×
667

×
668
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
669
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
670
                return queryWithLimit(func(offset int) (int, error) {
×
671
                        params := sqlc.FilterInvoicesParams{
672
                                PendingOnly: true,
×
673
                                NumOffset:   int32(offset),
674
                                NumLimit:    int32(i.opts.paginationLimit),
675
                                Reverse:     false,
×
676
                        }
×
677

×
678
                        rows, err := db.FilterInvoices(ctx, params)
×
679
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
680
                                return 0, fmt.Errorf("unable to get invoices "+
×
681
                                        "from db: %w", err)
×
682
                        }
×
683

×
684
                        // Load all the information for the invoices.
×
685
                        for _, row := range rows {
686
                                hash, invoice, err := fetchInvoiceData(
687
                                        ctx, db, row, nil, true,
×
688
                                )
×
689
                                if err != nil {
×
690
                                        return 0, err
×
691
                                }
×
692

693
                                invoices[*hash] = *invoice
694
                        }
×
695

×
696
                        return len(rows), nil
×
697
                }, i.opts.paginationLimit)
×
698
        }, func() {
699
                invoices = make(map[lntypes.Hash]Invoice)
700
        })
701
        if err != nil {
×
702
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
703
                        err)
704
        }
705

706
        return invoices, nil
707
}
708

709
// InvoicesSettledSince can be used by callers to catch up any settled invoices
710
// they missed within the settled invoice time series. We'll return all known
×
711
// settled invoice that have a settle index higher than the passed idx.
×
712
//
×
713
// NOTE: The index starts from 1. As a result we enforce that specifying a value
×
714
// below the starting index value is a noop.
×
715
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
×
716
        []Invoice, error) {
×
717

×
718
        var invoices []Invoice
×
719

×
720
        if idx == 0 {
×
721
                return invoices, nil
×
722
        }
×
723

×
724
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
725
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
726
                err := queryWithLimit(func(offset int) (int, error) {
727
                        params := sqlc.FilterInvoicesParams{
×
728
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
729
                                NumOffset:      int32(offset),
730
                                NumLimit:       int32(i.opts.paginationLimit),
731
                                Reverse:        false,
732
                        }
733

734
                        rows, err := db.FilterInvoices(ctx, params)
×
735
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
736
                                return 0, fmt.Errorf("unable to get invoices "+
×
737
                                        "from db: %w", err)
×
738
                        }
×
739

×
740
                        // Load all the information for the invoices.
×
741
                        for _, row := range rows {
×
742
                                _, invoice, err := fetchInvoiceData(
×
743
                                        ctx, db, row, nil, true,
×
744
                                )
×
745
                                if err != nil {
×
746
                                        return 0, fmt.Errorf("unable to fetch "+
×
747
                                                "invoice(id=%d) from db: %w",
×
748
                                                row.ID, err)
×
749
                                }
×
750

×
751
                                invoices = append(invoices, *invoice)
×
752
                        }
×
753

754
                        return len(rows), nil
755
                }, i.opts.paginationLimit)
×
756
                if err != nil {
×
757
                        return err
×
758
                }
×
759

×
760
                // Now fetch all the AMP sub invoices that were settled since
×
761
                // the provided index.
×
762
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
763
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
×
764
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
765
                        },
766
                )
×
767
                if err != nil {
768
                        return err
×
769
                }
×
770

×
771
                for _, ampInvoice := range ampInvoices {
×
772
                        // Convert the row to a sqlc.Invoice so we can use the
×
773
                        // existing fetchInvoiceData function.
×
774
                        sqlInvoice := sqlc.Invoice{
×
775
                                ID:             ampInvoice.ID,
776
                                Hash:           ampInvoice.Hash,
×
777
                                Preimage:       ampInvoice.Preimage,
778
                                SettleIndex:    ampInvoice.AmpSettleIndex,
779
                                SettledAt:      ampInvoice.AmpSettledAt,
780
                                Memo:           ampInvoice.Memo,
781
                                AmountMsat:     ampInvoice.AmountMsat,
782
                                CltvDelta:      ampInvoice.CltvDelta,
783
                                Expiry:         ampInvoice.Expiry,
784
                                PaymentAddr:    ampInvoice.PaymentAddr,
785
                                PaymentRequest: ampInvoice.PaymentRequest,
786
                                State:          ampInvoice.State,
×
787
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
×
788
                                IsAmp:          ampInvoice.IsAmp,
×
789
                                IsHodl:         ampInvoice.IsHodl,
×
790
                                IsKeysend:      ampInvoice.IsKeysend,
×
791
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
×
792
                        }
×
793

794
                        // Fetch the state and HTLCs for this AMP sub invoice.
×
795
                        _, invoice, err := fetchInvoiceData(
×
796
                                ctx, db, sqlInvoice,
×
797
                                (*[32]byte)(ampInvoice.SetID), true,
×
798
                        )
×
799
                        if err != nil {
×
800
                                return fmt.Errorf("unable to fetch "+
×
801
                                        "AMP invoice(id=%d) from db: %w",
×
802
                                        ampInvoice.ID, err)
×
803
                        }
×
804

×
805
                        invoices = append(invoices, *invoice)
×
806
                }
×
807

×
808
                return nil
×
809
        }, func() {
810
                invoices = nil
811
        })
×
812
        if err != nil {
×
813
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
814
                        "index (excluding) %d: %w", idx, err)
×
815
        }
×
816

×
817
        return invoices, nil
×
818
}
×
819

×
820
// InvoicesAddedSince can be used by callers to seek into the event time series
821
// of all the invoices added in the database. This method will return all
×
822
// invoices with an add index greater than the specified idx.
823
//
824
// NOTE: The index starts from 1. As a result we enforce that specifying a value
×
825
// below the starting index value is a noop.
826
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
×
827
        []Invoice, error) {
×
828

×
829
        var result []Invoice
830

831
        if idx == 0 {
832
                return result, nil
×
833
        }
×
834

×
835
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
836
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
837
                return queryWithLimit(func(offset int) (int, error) {
×
838
                        params := sqlc.FilterInvoicesParams{
×
839
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
×
840
                                NumOffset:   int32(offset),
841
                                NumLimit:    int32(i.opts.paginationLimit),
×
842
                                Reverse:     false,
×
843
                        }
×
844

×
845
                        rows, err := db.FilterInvoices(ctx, params)
×
846
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
847
                                return 0, fmt.Errorf("unable to get invoices "+
×
848
                                        "from db: %w", err)
×
849
                        }
×
850

×
851
                        // Load all the information for the invoices.
×
852
                        for _, row := range rows {
×
853
                                _, invoice, err := fetchInvoiceData(
×
854
                                        ctx, db, row, nil, true,
×
855
                                )
×
856
                                if err != nil {
×
857
                                        return 0, err
×
858
                                }
×
859

×
860
                                result = append(result, *invoice)
×
861
                        }
×
862

×
863
                        return len(rows), nil
×
864
                }, i.opts.paginationLimit)
×
865
        }, func() {
×
866
                result = nil
×
867
        })
×
868

×
869
        if err != nil {
×
870
                return nil, fmt.Errorf("unable to get invoices added since "+
×
871
                        "index %d: %w", idx, err)
×
872
        }
×
873

×
874
        return result, nil
875
}
×
876

877
// QueryInvoices allows a caller to query the invoice database for invoices
878
// within the specified add index range.
×
879
func (i *SQLStore) QueryInvoices(ctx context.Context,
×
880
        q InvoiceQuery) (InvoiceSlice, error) {
×
881

×
882
        var invoices []Invoice
×
883

×
884
        if q.NumMaxInvoices == 0 {
×
885
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
886
                        "be non-zero")
887
        }
×
888

889
        readTxOpt := NewSQLInvoiceQueryReadTx()
890
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
891
                return queryWithLimit(func(offset int) (int, error) {
892
                        params := sqlc.FilterInvoicesParams{
893
                                NumOffset:   int32(offset),
894
                                NumLimit:    int32(i.opts.paginationLimit),
895
                                PendingOnly: q.PendingOnly,
896
                                Reverse:     q.Reversed,
897
                        }
×
898

×
899
                        if q.Reversed {
×
900
                                // If the index offset was not set, we want to
×
901
                                // fetch from the lastest invoice.
×
902
                                if q.IndexOffset == 0 {
×
903
                                        params.AddIndexLet = sqldb.SQLInt64(
×
904
                                                int64(math.MaxInt64),
905
                                        )
×
906
                                } else {
×
907
                                        // The invoice with index offset id must
×
908
                                        // not be included in the results.
×
909
                                        params.AddIndexLet = sqldb.SQLInt64(
×
910
                                                q.IndexOffset - 1,
×
911
                                        )
×
912
                                }
×
913
                        } else {
×
914
                                // The invoice with index offset id must not be
×
915
                                // included in the results.
×
916
                                params.AddIndexGet = sqldb.SQLInt64(
×
917
                                        q.IndexOffset + 1,
×
918
                                )
×
919
                        }
×
920

921
                        if q.CreationDateStart != 0 {
922
                                params.CreatedAfter = sqldb.SQLTime(
×
923
                                        time.Unix(q.CreationDateStart, 0).UTC(),
×
924
                                )
×
925
                        }
×
926

×
927
                        if q.CreationDateEnd != 0 {
×
928
                                // We need to add 1 to the end date as we're
×
929
                                // checking less than the end date in SQL.
930
                                params.CreatedBefore = sqldb.SQLTime(
×
931
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
932
                                )
933
                        }
×
934

935
                        rows, err := db.FilterInvoices(ctx, params)
×
936
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
937
                                return 0, fmt.Errorf("unable to get invoices "+
×
938
                                        "from db: %w", err)
939
                        }
×
940

×
941
                        // Load all the information for the invoices.
×
942
                        for _, row := range rows {
×
943
                                _, invoice, err := fetchInvoiceData(
944
                                        ctx, db, row, nil, true,
×
945
                                )
946
                                if err != nil {
947
                                        return 0, err
948
                                }
949

950
                                invoices = append(invoices, *invoice)
×
951

×
952
                                if len(invoices) == int(q.NumMaxInvoices) {
×
953
                                        return 0, nil
×
954
                                }
×
955
                        }
×
956

×
957
                        return len(rows), nil
×
958
                }, i.opts.paginationLimit)
959
        }, func() {
×
960
                invoices = nil
×
961
        })
×
962
        if err != nil {
×
963
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
964
                        "invoices: %w", err)
×
965
        }
×
966

×
967
        if len(invoices) == 0 {
×
968
                return InvoiceSlice{
×
969
                        InvoiceQuery: q,
×
970
                }, nil
×
971
        }
×
972

×
973
        // If we iterated through the add index in reverse order, then
×
974
        // we'll need to reverse the slice of invoices to return them in
×
975
        // forward order.
×
976
        if q.Reversed {
×
977
                numInvoices := len(invoices)
×
978
                for i := 0; i < numInvoices/2; i++ {
×
979
                        reverse := numInvoices - i - 1
×
980
                        invoices[i], invoices[reverse] =
×
981
                                invoices[reverse], invoices[i]
×
982
                }
×
983
        }
×
984

×
985
        res := InvoiceSlice{
×
986
                InvoiceQuery:     q,
×
987
                Invoices:         invoices,
×
988
                FirstIndexOffset: invoices[0].AddIndex,
×
989
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
×
990
        }
991

×
992
        return res, nil
×
993
}
×
994

×
995
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
×
996
// a SQL database as the backend.
997
type sqlInvoiceUpdater struct {
×
998
        db         SQLInvoiceQueries
×
999
        ctx        context.Context //nolint:containedctx
×
1000
        invoice    *Invoice
×
1001
        updateTime time.Time
×
1002
}
×
1003

×
1004
// AddHtlc adds a new htlc to the invoice.
1005
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
×
1006
        newHtlc *InvoiceHTLC) error {
×
1007

×
1008
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
×
1009
                s.ctx, sqlc.InsertInvoiceHTLCParams{
×
1010
                        HtlcID: int64(circuitKey.HtlcID),
1011
                        ChanID: strconv.FormatUint(
1012
                                circuitKey.ChanID.ToUint64(), 10,
×
1013
                        ),
×
1014
                        AmountMsat: int64(newHtlc.Amt),
×
1015
                        TotalMppMsat: sql.NullInt64{
×
1016
                                Int64: int64(newHtlc.MppTotalAmt),
×
1017
                                Valid: newHtlc.MppTotalAmt != 0,
×
1018
                        },
×
1019
                        AcceptHeight: int32(newHtlc.AcceptHeight),
1020
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
×
1021
                        ExpiryHeight: int32(newHtlc.Expiry),
×
1022
                        State:        int16(newHtlc.State),
×
1023
                        InvoiceID:    int64(s.invoice.AddIndex),
×
1024
                },
×
1025
        )
1026
        if err != nil {
1027
                return err
×
1028
        }
1029

×
1030
        for key, value := range newHtlc.CustomRecords {
×
1031
                err = s.db.InsertInvoiceHTLCCustomRecord(
×
1032
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
×
1033
                                // TODO(bhandras): schema might be wrong here
×
1034
                                // as the custom record key is an uint64.
×
1035
                                Key:    int64(key),
×
1036
                                Value:  value,
1037
                                HtlcID: htlcPrimaryKeyID,
×
1038
                        },
×
1039
                )
×
1040
                if err != nil {
×
1041
                        return err
×
1042
                }
1043
        }
1044

1045
        if newHtlc.AMP != nil {
1046
                setID := newHtlc.AMP.Record.SetID()
×
1047

×
1048
                upsertResult, err := s.db.UpsertAMPSubInvoice(
×
1049
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
×
1050
                                SetID:     setID[:],
×
1051
                                CreatedAt: s.updateTime.UTC(),
×
1052
                                InvoiceID: int64(s.invoice.AddIndex),
×
1053
                        },
1054
                )
1055
                if err != nil {
×
1056
                        mappedSQLErr := sqldb.MapSQLError(err)
×
1057
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
×
1058
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
1059
                                return ErrDuplicateSetID{
×
1060
                                        SetID: setID,
×
1061
                                }
×
1062
                        }
×
1063

1064
                        return err
1065
                }
1066

1067
                // If we're just inserting the AMP invoice, we'll get a non
1068
                // zero rows affected count.
1069
                rowsAffected, err := upsertResult.RowsAffected()
1070
                if err != nil {
1071
                        return err
1072
                }
1073
                if rowsAffected != 0 {
1074
                        // If we're inserting a new AMP invoice, we'll also
1075
                        // insert a new invoice event.
1076
                        err = s.db.OnAMPSubInvoiceCreated(
×
1077
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
×
1078
                                        AddedAt:   s.updateTime.UTC(),
×
1079
                                        InvoiceID: int64(s.invoice.AddIndex),
×
1080
                                        SetID:     setID[:],
×
1081
                                },
×
1082
                        )
×
1083
                        if err != nil {
×
1084
                                return err
×
1085
                        }
×
1086
                }
×
1087

×
1088
                rootShare := newHtlc.AMP.Record.RootShare()
×
1089

×
1090
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
×
1091
                        InvoiceID: int64(s.invoice.AddIndex),
×
1092
                        SetID:     setID[:],
×
1093
                        HtlcID:    htlcPrimaryKeyID,
×
1094
                        RootShare: rootShare[:],
×
1095
                        ChildIndex: int64(
×
1096
                                newHtlc.AMP.Record.ChildIndex(),
×
1097
                        ),
×
1098
                        Hash: newHtlc.AMP.Hash[:],
×
1099
                }
1100

×
1101
                if newHtlc.AMP.Preimage != nil {
×
1102
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
×
1103
                }
×
1104

×
1105
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
×
1106
                if err != nil {
×
1107
                        return err
×
1108
                }
×
1109
        }
×
1110

×
1111
        return nil
×
1112
}
×
1113

1114
// ResolveHtlc marks an htlc as resolved with the given state.
1115
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
×
1116
        state HtlcState, resolveTime time.Time) error {
×
1117

×
1118
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
×
1119
                HtlcID: int64(circuitKey.HtlcID),
×
1120
                ChanID: strconv.FormatUint(
×
1121
                        circuitKey.ChanID.ToUint64(), 10,
×
1122
                ),
×
1123
                InvoiceID:   int64(s.invoice.AddIndex),
×
1124
                State:       int16(state),
×
1125
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
×
1126
        })
×
1127
}
×
1128

×
1129
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
×
1130
// identified by the setID.
×
1131
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
×
1132
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
×
1133

1134
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
×
1135
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
1136
                        InvoiceID: int64(s.invoice.AddIndex),
1137
                        SetID:     setID[:],
1138
                        HtlcID:    int64(circuitKey.HtlcID),
1139
                        Preimage:  preimage[:],
×
1140
                        ChanID: strconv.FormatUint(
×
1141
                                circuitKey.ChanID.ToUint64(), 10,
×
1142
                        ),
×
1143
                },
×
1144
        )
×
1145
        if err != nil {
×
1146
                return err
×
1147
        }
×
1148

×
1149
        rowsAffected, err := result.RowsAffected()
×
1150
        if err != nil {
×
1151
                return err
×
1152
        }
×
1153
        if rowsAffected == 0 {
×
1154
                return ErrInvoiceNotFound
×
1155
        }
×
1156

1157
        return nil
1158
}
×
1159

×
1160
// UpdateInvoiceState updates the invoice state to the new state.
×
1161
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
×
1162
        newState ContractState, preimage *lntypes.Preimage) error {
×
1163

×
1164
        var (
×
1165
                settleIndex sql.NullInt64
×
1166
                settledAt   sql.NullTime
×
1167
        )
×
1168

×
1169
        switch newState {
×
1170
        case ContractSettled:
×
1171
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
1172
                if err != nil {
×
1173
                        return err
×
1174
                }
1175

×
1176
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
1177

×
1178
                // If the invoice is settled, we'll also update the settle time.
×
1179
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
1180

1181
                err = s.db.OnInvoiceSettled(
×
1182
                        s.ctx, sqlc.OnInvoiceSettledParams{
1183
                                AddedAt:   s.updateTime.UTC(),
1184
                                InvoiceID: int64(s.invoice.AddIndex),
1185
                        },
1186
                )
×
1187
                if err != nil {
×
1188
                        return err
×
1189
                }
×
1190

×
1191
        case ContractCanceled:
×
1192
                err := s.db.OnInvoiceCanceled(
×
1193
                        s.ctx, sqlc.OnInvoiceCanceledParams{
×
1194
                                AddedAt:   s.updateTime.UTC(),
×
1195
                                InvoiceID: int64(s.invoice.AddIndex),
×
1196
                        },
×
1197
                )
×
1198
                if err != nil {
1199
                        return err
1200
                }
1201
        }
1202

×
1203
        params := sqlc.UpdateInvoiceStateParams{
×
1204
                ID:          int64(s.invoice.AddIndex),
×
1205
                State:       int16(newState),
×
1206
                SettleIndex: settleIndex,
×
1207
                SettledAt:   settledAt,
×
1208
        }
×
1209

×
1210
        if preimage != nil {
×
1211
                params.Preimage = preimage[:]
×
1212
        }
×
1213

×
1214
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
×
1215
        if err != nil {
×
1216
                return err
×
1217
        }
×
1218
        rowsAffected, err := result.RowsAffected()
1219
        if err != nil {
×
1220
                return err
×
1221
        }
×
1222

×
1223
        if rowsAffected == 0 {
×
1224
                return ErrInvoiceNotFound
×
1225
        }
×
1226

1227
        if settleIndex.Valid {
×
1228
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
1229
                s.invoice.SettleDate = s.updateTime
1230
        }
1231

1232
        return nil
×
1233
}
×
1234

×
1235
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
×
1236
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
×
1237
        amtPaid lnwire.MilliSatoshi) error {
×
1238

×
1239
        _, err := s.db.UpdateInvoiceAmountPaid(
×
1240
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
×
1241
                        ID:             int64(s.invoice.AddIndex),
×
1242
                        AmountPaidMsat: int64(amtPaid),
×
1243
                },
×
1244
        )
×
1245

1246
        return err
×
1247
}
×
1248

×
1249
// UpdateAmpState updates the state of the AMP sub invoice identified by the
×
1250
// setID.
×
1251
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
×
1252
        newState InvoiceStateAMP, _ models.CircuitKey) error {
×
1253

×
1254
        var (
×
1255
                settleIndex sql.NullInt64
×
1256
                settledAt   sql.NullTime
×
1257
        )
×
1258

×
1259
        switch newState.State {
×
1260
        case HtlcStateSettled:
1261
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
1262
                if err != nil {
×
1263
                        return err
×
1264
                }
×
1265

×
1266
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
1267

×
1268
                // If the invoice is settled, we'll also update the settle time.
×
1269
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
1270

×
1271
                err = s.db.OnAMPSubInvoiceSettled(
1272
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
1273
                                AddedAt:   s.updateTime.UTC(),
×
1274
                                InvoiceID: int64(s.invoice.AddIndex),
×
1275
                                SetID:     setID[:],
×
1276
                        },
×
1277
                )
×
1278
                if err != nil {
×
1279
                        return err
×
1280
                }
×
1281

×
1282
        case HtlcStateCanceled:
×
1283
                err := s.db.OnAMPSubInvoiceCanceled(
1284
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
×
1285
                                AddedAt:   s.updateTime.UTC(),
×
1286
                                InvoiceID: int64(s.invoice.AddIndex),
×
1287
                                SetID:     setID[:],
×
1288
                        },
×
1289
                )
×
1290
                if err != nil {
×
1291
                        return err
×
1292
                }
1293
        }
×
1294

×
1295
        err := s.db.UpdateAMPSubInvoiceState(
×
1296
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
1297
                        SetID:       setID[:],
×
1298
                        State:       int16(newState.State),
×
1299
                        SettleIndex: settleIndex,
×
1300
                        SettledAt:   settledAt,
×
1301
                },
1302
        )
×
1303
        if err != nil {
1304
                return err
1305
        }
1306

1307
        if settleIndex.Valid {
×
1308
                updatedState := s.invoice.AMPState[setID]
×
1309
                updatedState.SettleIndex = uint64(settleIndex.Int64)
×
1310
                updatedState.SettleDate = s.updateTime.UTC()
×
1311
                s.invoice.AMPState[setID] = updatedState
×
1312
        }
×
1313

×
1314
        return nil
×
1315
}
×
1316

×
1317
// Finalize finalizes the update before it is written to the database. Note that
×
1318
// we don't use this directly in the SQL implementation, so the function is just
1319
// a stub.
1320
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
1321
        return nil
1322
}
×
1323

×
1324
// UpdateInvoice attempts to update an invoice corresponding to the passed
×
1325
// reference. If an invoice matching the passed reference doesn't exist within
×
1326
// the database, then the action will fail with  ErrInvoiceNotFound error.
×
1327
//
×
1328
// The update is performed inside the same database transaction that fetches the
×
1329
// invoice and is therefore atomic. The fields to update are controlled by the
×
1330
// supplied callback.
×
1331
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
×
1332
        setID *SetID, callback InvoiceUpdateCallback) (
×
1333
        *Invoice, error) {
×
1334

×
1335
        var updatedInvoice *Invoice
1336

×
1337
        txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
×
1338
        txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
×
1339
                if setID != nil {
×
1340
                        // Make sure to use the set ID if this is an AMP update.
×
1341
                        var setIDBytes [32]byte
×
1342
                        copy(setIDBytes[:], setID[:])
×
1343
                        ref.setID = &setIDBytes
×
1344

×
1345
                        // If we're updating an AMP invoice, we'll also only
×
1346
                        // need to fetch the HTLCs for the given set ID.
×
1347
                        ref.refModifier = HtlcSetOnlyModifier
×
1348
                }
×
1349

×
1350
                invoice, err := i.fetchInvoice(ctx, db, ref)
×
1351
                if err != nil {
1352
                        return err
×
1353
                }
×
1354

×
1355
                updateTime := i.clock.Now()
×
1356
                updater := &sqlInvoiceUpdater{
×
1357
                        db:         db,
×
1358
                        ctx:        ctx,
×
1359
                        invoice:    invoice,
×
1360
                        updateTime: updateTime,
×
1361
                }
×
1362

×
1363
                payHash := ref.PayHash()
1364
                updatedInvoice, err = UpdateInvoice(
1365
                        payHash, invoice, updateTime, callback, updater,
×
1366
                )
×
1367

×
1368
                return err
×
1369
        }, func() {})
×
1370
        if txErr != nil {
×
1371
                // If the invoice is already settled, we'll return the
×
1372
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
×
1373
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
×
1374
                        return updatedInvoice, txErr
×
1375
                }
×
1376

1377
                return nil, txErr
×
1378
        }
×
1379

×
1380
        return updatedInvoice, nil
×
1381
}
×
1382

×
1383
// DeleteInvoice attempts to delete the passed invoices and all their related
1384
// data from the database in one transaction.
×
1385
func (i *SQLStore) DeleteInvoice(ctx context.Context,
1386
        invoicesToDelete []InvoiceDeleteRef) error {
1387

1388
        // All the InvoiceDeleteRef instances include the add index of the
1389
        // invoice. The rest was added to ensure that the invoices were deleted
1390
        // properly in the kv database. When we have fully migrated we can
×
1391
        // remove the rest of the fields.
×
1392
        for _, ref := range invoicesToDelete {
×
1393
                if ref.AddIndex == 0 {
1394
                        return fmt.Errorf("unable to delete invoice using a "+
1395
                                "ref without AddIndex set: %v", ref)
1396
                }
1397
        }
1398

1399
        var writeTxOpt SQLInvoiceQueriesTxOptions
1400
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
1401
                for _, ref := range invoicesToDelete {
1402
                        params := sqlc.DeleteInvoiceParams{
1403
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
×
1404
                        }
×
1405

×
1406
                        if ref.SettleIndex != 0 {
×
1407
                                params.SettleIndex = sqldb.SQLInt64(
×
1408
                                        ref.SettleIndex,
×
1409
                                )
×
1410
                        }
×
1411

×
1412
                        if ref.PayHash != lntypes.ZeroHash {
×
1413
                                params.Hash = ref.PayHash[:]
×
1414
                        }
×
1415

×
1416
                        result, err := db.DeleteInvoice(ctx, params)
×
1417
                        if err != nil {
×
1418
                                return fmt.Errorf("unable to delete "+
×
1419
                                        "invoice(%v): %w", ref.AddIndex, err)
1420
                        }
×
1421
                        rowsAffected, err := result.RowsAffected()
×
1422
                        if err != nil {
×
1423
                                return fmt.Errorf("unable to get rows "+
×
1424
                                        "affected: %w", err)
1425
                        }
×
1426
                        if rowsAffected == 0 {
×
1427
                                return fmt.Errorf("%w: %v",
×
1428
                                        ErrInvoiceNotFound, ref.AddIndex)
×
1429
                        }
×
1430
                }
×
1431

×
1432
                return nil
×
1433
        }, func() {})
×
1434

×
1435
        if err != nil {
×
1436
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1437
        }
×
1438

×
1439
        return nil
×
1440
}
×
1441

×
1442
// DeleteCanceledInvoices removes all canceled invoices from the database.
×
1443
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
×
1444
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
1445
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
1446
                _, err := db.DeleteCanceledInvoices(ctx)
1447
                if err != nil {
×
1448
                        return fmt.Errorf("unable to delete canceled "+
1449
                                "invoices: %w", err)
1450
                }
×
1451

1452
                return nil
1453
        }, func() {})
1454
        if err != nil {
1455
                return fmt.Errorf("unable to delete invoices: %w", err)
1456
        }
×
1457

×
1458
        return nil
×
1459
}
×
1460

×
1461
// fetchInvoiceData fetches additional data for the given invoice. If the
×
1462
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
×
1463
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
×
1464
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
×
1465
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
×
1466
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
×
1467
        *Invoice, error) {
1468

1469
        // Unmarshal the common data.
×
1470
        hash, invoice, err := unmarshalInvoice(row)
×
1471
        if err != nil {
×
1472
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1473
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1474
        }
×
1475

×
1476
        // Fetch the invoice features.
×
1477
        features, err := getInvoiceFeatures(ctx, db, row.ID)
×
1478
        if err != nil {
×
1479
                return nil, nil, err
×
1480
        }
×
1481

1482
        invoice.Terms.Features = features
×
1483

×
1484
        // If this is an AMP invoice, we'll need fetch the AMP state along
×
1485
        // with the HTLCs (if requested).
1486
        if invoice.IsAMP() {
×
1487
                invoiceID := int64(invoice.AddIndex)
×
1488
                ampState, ampHtlcs, err := fetchAmpState(
×
1489
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
×
1490
                )
×
1491
                if err != nil {
×
1492
                        return nil, nil, err
×
1493
                }
×
1494

×
1495
                invoice.AMPState = ampState
×
1496
                invoice.Htlcs = ampHtlcs
×
1497

×
1498
                return hash, invoice, nil
×
1499
        }
×
1500

1501
        // Otherwise simply fetch the invoice HTLCs.
1502
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
×
1503
        if err != nil {
×
1504
                return nil, nil, err
1505
        }
×
1506

×
1507
        if len(htlcs) > 0 {
×
1508
                invoice.Htlcs = htlcs
1509
                var amountPaid lnwire.MilliSatoshi
×
1510
                for _, htlc := range htlcs {
1511
                        if htlc.State == HtlcStateSettled {
1512
                                amountPaid += htlc.Amt
1513
                        }
×
1514
                }
×
1515
                invoice.AmtPaid = amountPaid
×
1516
        }
×
1517

×
1518
        return hash, invoice, nil
×
1519
}
×
1520

×
1521
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1522
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
×
1523
        invoiceID int64) (*lnwire.FeatureVector, error) {
×
1524

×
1525
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
×
1526
        if err != nil {
×
1527
                return nil, fmt.Errorf("unable to get invoice features: %w",
1528
                        err)
×
1529
        }
1530

1531
        features := lnwire.EmptyFeatureVector()
1532
        for _, feature := range rows {
1533
                features.Set(lnwire.FeatureBit(feature.Feature))
1534
        }
1535

1536
        return features, nil
1537
}
×
1538

×
1539
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
×
1540
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
×
1541
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
×
1542

×
1543
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
×
1544
        if err != nil {
×
1545
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
1546
        }
1547

×
1548
        // We have no htlcs to unmarshal.
×
1549
        if len(htlcRows) == 0 {
×
1550
                return nil, nil
×
1551
        }
1552

×
1553
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
1554
        if err != nil {
×
1555
                return nil, fmt.Errorf("unable to get custom records for "+
×
1556
                        "invoice htlcs: %w", err)
×
1557
        }
×
1558

×
1559
        cr := make(map[int64]record.CustomSet, len(crRows))
×
1560
        for _, row := range crRows {
×
1561
                if _, ok := cr[row.HtlcID]; !ok {
×
1562
                        cr[row.HtlcID] = make(record.CustomSet)
×
1563
                }
×
1564

1565
                value := row.Value
×
1566
                if value == nil {
×
1567
                        value = []byte{}
×
1568
                }
×
1569
                cr[row.HtlcID][uint64(row.Key)] = value
1570
        }
1571

1572
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
×
1573

×
1574
        for _, row := range htlcRows {
×
1575
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
×
1576
                if err != nil {
1577
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1578
                                "htlc(%d): %w", row.ID, err)
×
1579
                }
×
1580

1581
                if customRecords, ok := cr[row.ID]; ok {
×
1582
                        htlc.CustomRecords = customRecords
1583
                } else {
1584
                        htlc.CustomRecords = make(record.CustomSet)
1585
                }
1586

×
1587
                htlcs[circuiteKey] = htlc
×
1588
        }
×
1589

×
1590
        return htlcs, nil
×
1591
}
×
1592

×
1593
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1594
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
×
1595
        error) {
×
1596

×
1597
        var (
×
1598
                settleIndex    int64
1599
                settledAt      time.Time
×
1600
                memo           []byte
1601
                paymentRequest []byte
1602
                preimage       *lntypes.Preimage
1603
                paymentAddr    [32]byte
1604
        )
×
1605

×
1606
        hash, err := lntypes.MakeHash(row.Hash)
×
1607
        if err != nil {
×
1608
                return nil, nil, err
×
1609
        }
×
1610

1611
        if row.SettleIndex.Valid {
1612
                settleIndex = row.SettleIndex.Int64
×
1613
        }
×
1614

×
1615
        if row.SettledAt.Valid {
1616
                settledAt = row.SettledAt.Time.Local()
×
1617
        }
×
1618

×
1619
        if row.Memo.Valid {
×
1620
                memo = []byte(row.Memo.String)
×
1621
        }
1622

×
1623
        // Keysend payments will have this field empty.
×
1624
        if row.PaymentRequest.Valid {
×
1625
                paymentRequest = []byte(row.PaymentRequest.String)
×
1626
        } else {
×
1627
                paymentRequest = []byte{}
1628
        }
×
1629

×
1630
        // We may not have the preimage if this a hodl invoice.
×
1631
        if row.Preimage != nil {
×
1632
                preimage = &lntypes.Preimage{}
×
1633
                copy(preimage[:], row.Preimage)
1634
        }
1635

×
1636
        copy(paymentAddr[:], row.PaymentAddr)
×
1637

×
1638
        var cltvDelta int32
×
1639
        if row.CltvDelta.Valid {
×
1640
                cltvDelta = row.CltvDelta.Int32
×
1641
        }
×
1642

×
1643
        expiry := time.Duration(row.Expiry) * time.Second
1644

×
1645
        invoice := &Invoice{
×
1646
                SettleIndex:    uint64(settleIndex),
×
1647
                SettleDate:     settledAt,
×
1648
                Memo:           memo,
×
1649
                PaymentRequest: paymentRequest,
1650
                CreationDate:   row.CreatedAt.Local(),
×
1651
                Terms: ContractTerm{
1652
                        FinalCltvDelta:  cltvDelta,
1653
                        Expiry:          expiry,
×
1654
                        PaymentPreimage: preimage,
1655
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
1656
                        PaymentAddr:     paymentAddr,
1657
                },
1658
                AddIndex:    uint64(row.ID),
×
1659
                State:       ContractState(row.State),
×
1660
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
×
1661
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
×
1662
                AMPState:    AMPInvoiceState{},
×
1663
                HodlInvoice: row.IsHodl,
×
1664
        }
×
1665

×
1666
        return &hash, invoice, nil
×
1667
}
×
1668

×
1669
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
×
1670
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
×
1671
        *InvoiceHTLC, error) {
×
1672

×
1673
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
1674
        if err != nil {
×
1675
                return CircuitKey{}, nil, err
×
1676
        }
×
1677

1678
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
1679

×
1680
        if row.HtlcID < 0 {
×
1681
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
1682
                        "value: %v", row.HtlcID)
×
1683
        }
×
1684

×
1685
        htlcID := uint64(row.HtlcID)
1686

1687
        circuitKey := CircuitKey{
×
1688
                ChanID: chanID,
×
1689
                HtlcID: htlcID,
×
1690
        }
×
1691

×
1692
        htlc := &InvoiceHTLC{
1693
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
1694
                AcceptHeight: uint32(row.AcceptHeight),
×
1695
                AcceptTime:   row.AcceptTime.Local(),
×
1696
                Expiry:       uint32(row.ExpiryHeight),
×
1697
                State:        HtlcState(row.State),
×
1698
        }
1699

×
1700
        if row.TotalMppMsat.Valid {
×
1701
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
×
1702
        }
×
1703

×
1704
        if row.ResolveTime.Valid {
×
1705
                htlc.ResolveTime = row.ResolveTime.Time.Local()
1706
        }
×
1707

×
1708
        return circuitKey, htlc, nil
×
1709
}
×
1710

×
1711
// queryWithLimit is a helper that pages through a result set using a limit
// and offset. The passed query function is called with the current offset,
// starting at zero, and should return the number of rows it fetched and an
// error if any. The offset is advanced by limit after every full page and the
// iteration stops once a page contains fewer than limit rows or the query
// fails, in which case the query's error is returned unchanged.
//
// The limit must be positive. A non-positive limit is rejected with an error
// before the query is ever called, since no page could be shorter than such a
// limit and the loop would keep repeating the same query without ever
// terminating.
func queryWithLimit(query func(int) (int, error), limit int) error {
	if limit <= 0 {
		return fmt.Errorf("invalid query limit: %d", limit)
	}

	offset := 0
	for {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// A short page means there is nothing left to fetch.
		if rows < limit {
			return nil
		}

		offset += limit
	}
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc