• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 12199391122

06 Dec 2024 01:10PM UTC coverage: 49.807% (-9.1%) from 58.933%
12199391122

push

github

web-flow
Merge pull request #9337 from Guayaba221/patch-1

chore: fix typo in ruby.md

100137 of 201051 relevant lines covered (49.81%)

2.07 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
// defaultQueryPaginationLimit is the number of rows requested per page by
// paginated SQL queries (the value of the LIMIT clause) when no other limit
// has been configured through the store options.
const defaultQueryPaginationLimit = 100
28

29
// SQLInvoiceQueries is an interface that defines the set of operations that can
30
// be executed against the invoice SQL database.
31
type SQLInvoiceQueries interface { //nolint:interfacebloat
32
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
33
                error)
34

35
        InsertInvoiceFeature(ctx context.Context,
36
                arg sqlc.InsertInvoiceFeatureParams) error
37

38
        InsertInvoiceHTLC(ctx context.Context,
39
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
40

41
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
42
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
43

44
        FilterInvoices(ctx context.Context,
45
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
46

47
        GetInvoice(ctx context.Context,
48
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
49

50
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
51
                error)
52

53
        GetInvoiceFeatures(ctx context.Context,
54
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
55

56
        GetInvoiceHTLCCustomRecords(ctx context.Context,
57
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
58

59
        GetInvoiceHTLCs(ctx context.Context,
60
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
61

62
        UpdateInvoiceState(ctx context.Context,
63
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
64

65
        UpdateInvoiceAmountPaid(ctx context.Context,
66
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
67

68
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
69

70
        UpdateInvoiceHTLC(ctx context.Context,
71
                arg sqlc.UpdateInvoiceHTLCParams) error
72

73
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
74
                sql.Result, error)
75

76
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
77

78
        // AMP sub invoice specific methods.
79
        UpsertAMPSubInvoice(ctx context.Context,
80
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
81

82
        UpdateAMPSubInvoiceState(ctx context.Context,
83
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
84

85
        InsertAMPSubInvoiceHTLC(ctx context.Context,
86
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
87

88
        FetchAMPSubInvoices(ctx context.Context,
89
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
90
                error)
91

92
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
93
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
94
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
95

96
        FetchSettledAMPSubInvoices(ctx context.Context,
97
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
98
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
99

100
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
101
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
102
                error)
103

104
        // Invoice events specific methods.
105
        OnInvoiceCreated(ctx context.Context,
106
                arg sqlc.OnInvoiceCreatedParams) error
107

108
        OnInvoiceCanceled(ctx context.Context,
109
                arg sqlc.OnInvoiceCanceledParams) error
110

111
        OnInvoiceSettled(ctx context.Context,
112
                arg sqlc.OnInvoiceSettledParams) error
113

114
        OnAMPSubInvoiceCreated(ctx context.Context,
115
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
116

117
        OnAMPSubInvoiceCanceled(ctx context.Context,
118
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
119

120
        OnAMPSubInvoiceSettled(ctx context.Context,
121
                arg sqlc.OnAMPSubInvoiceSettledParams) error
122
}
123

124
// Compile-time assertion that *SQLStore satisfies the InvoiceDB interface.
var _ InvoiceDB = (*SQLStore)(nil)
125

126
// SQLInvoiceQueriesTxOptions carries the database transaction options that
// SQLInvoiceQueries understands. The zero value describes a transaction that
// is not read only.
type SQLInvoiceQueriesTxOptions struct {
	// readOnly is true when a read only transaction is requested.
	readOnly bool
}
132

133
// ReadOnly returns true if the transaction should be read only.
134
//
135
// NOTE: This implements the TxOptions.
136
func (a *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
×
137
        return a.readOnly
×
138
}
×
139

140
// NewSQLInvoiceQueryReadTx creates a new read transaction option set.
141
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
×
142
        return SQLInvoiceQueriesTxOptions{
×
143
                readOnly: true,
×
144
        }
×
145
}
×
146

147
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
148
// of batched database operations.
149
type BatchedSQLInvoiceQueries interface {
150
        SQLInvoiceQueries
151

152
        sqldb.BatchedTx[SQLInvoiceQueries]
153
}
154

155
// SQLStore represents a storage backend.
156
type SQLStore struct {
157
        db    BatchedSQLInvoiceQueries
158
        clock clock.Clock
159
        opts  SQLStoreOptions
160
}
161

162
// SQLStoreOptions groups the configurable options of the SQL store.
type SQLStoreOptions struct {
	// paginationLimit is the number of rows requested per page by the
	// queries that paginate their results.
	paginationLimit int
}
166

167
// defaultSQLStoreOptions returns the default options for the SQL store.
168
func defaultSQLStoreOptions() SQLStoreOptions {
×
169
        return SQLStoreOptions{
×
170
                paginationLimit: defaultQueryPaginationLimit,
×
171
        }
×
172
}
×
173

174
// SQLStoreOption is a functional option that can be used to optionally modify
175
// the behavior of the SQL store.
176
type SQLStoreOption func(*SQLStoreOptions)
177

178
// WithPaginationLimit sets the pagination limit for the SQL store queries that
179
// paginate results.
180
func WithPaginationLimit(limit int) SQLStoreOption {
×
181
        return func(o *SQLStoreOptions) {
×
182
                o.paginationLimit = limit
×
183
        }
×
184
}
185

186
// NewSQLStore creates a new SQLStore instance given a open
187
// BatchedSQLInvoiceQueries storage backend.
188
func NewSQLStore(db BatchedSQLInvoiceQueries,
189
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
×
190

×
191
        opts := defaultSQLStoreOptions()
×
192
        for _, applyOption := range options {
×
193
                applyOption(&opts)
×
194
        }
×
195

196
        return &SQLStore{
×
197
                db:    db,
×
198
                clock: clock,
×
199
                opts:  opts,
×
200
        }
×
201
}
202

203
// AddInvoice inserts the targeted invoice into the database. If the invoice has
204
// *any* payment hashes which already exists within the database, then the
205
// insertion will be aborted and rejected due to the strict policy banning any
206
// duplicate payment hashes.
207
//
208
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
209
func (i *SQLStore) AddInvoice(ctx context.Context,
210
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
×
211

×
212
        // Make sure this is a valid invoice before trying to store it in our
×
213
        // DB.
×
214
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
×
215
                return 0, err
×
216
        }
×
217

218
        var (
×
219
                writeTxOpts SQLInvoiceQueriesTxOptions
×
220
                invoiceID   int64
×
221
        )
×
222

×
223
        // Precompute the payment request hash so we can use it in the query.
×
224
        var paymentRequestHash []byte
×
225
        if len(newInvoice.PaymentRequest) > 0 {
×
226
                h := sha256.New()
×
227
                h.Write(newInvoice.PaymentRequest)
×
228
                paymentRequestHash = h.Sum(nil)
×
229
        }
×
230

231
        err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
×
232
                params := sqlc.InsertInvoiceParams{
×
233
                        Hash:       paymentHash[:],
×
234
                        Memo:       sqldb.SQLStr(string(newInvoice.Memo)),
×
235
                        AmountMsat: int64(newInvoice.Terms.Value),
×
236
                        // Note: BOLT12 invoices don't have a final cltv delta.
×
237
                        CltvDelta: sqldb.SQLInt32(
×
238
                                newInvoice.Terms.FinalCltvDelta,
×
239
                        ),
×
240
                        Expiry: int32(newInvoice.Terms.Expiry.Seconds()),
×
241
                        // Note: keysend invoices don't have a payment request.
×
242
                        PaymentRequest: sqldb.SQLStr(string(
×
243
                                newInvoice.PaymentRequest),
×
244
                        ),
×
245
                        PaymentRequestHash: paymentRequestHash,
×
246
                        State:              int16(newInvoice.State),
×
247
                        AmountPaidMsat:     int64(newInvoice.AmtPaid),
×
248
                        IsAmp:              newInvoice.IsAMP(),
×
249
                        IsHodl:             newInvoice.HodlInvoice,
×
250
                        IsKeysend:          newInvoice.IsKeysend(),
×
251
                        CreatedAt:          newInvoice.CreationDate.UTC(),
×
252
                }
×
253

×
254
                // Some invoices may not have a preimage, like in the case of
×
255
                // HODL invoices.
×
256
                if newInvoice.Terms.PaymentPreimage != nil {
×
257
                        preimage := *newInvoice.Terms.PaymentPreimage
×
258
                        if preimage == UnknownPreimage {
×
259
                                return errors.New("cannot use all-zeroes " +
×
260
                                        "preimage")
×
261
                        }
×
262
                        params.Preimage = preimage[:]
×
263
                }
264

265
                // Some non MPP payments may have the default (invalid) value.
266
                if newInvoice.Terms.PaymentAddr != BlankPayAddr {
×
267
                        params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
×
268
                }
×
269

270
                var err error
×
271
                invoiceID, err = db.InsertInvoice(ctx, params)
×
272
                if err != nil {
×
273
                        return fmt.Errorf("unable to insert invoice: %w", err)
×
274
                }
×
275

276
                // TODO(positiveblue): if invocies do not have custom features
277
                // maybe just store the "invoice type" and populate the features
278
                // based on that.
279
                for feature := range newInvoice.Terms.Features.Features() {
×
280
                        params := sqlc.InsertInvoiceFeatureParams{
×
281
                                InvoiceID: invoiceID,
×
282
                                Feature:   int32(feature),
×
283
                        }
×
284

×
285
                        err := db.InsertInvoiceFeature(ctx, params)
×
286
                        if err != nil {
×
287
                                return fmt.Errorf("unable to insert invoice "+
×
288
                                        "feature(%v): %w", feature, err)
×
289
                        }
×
290
                }
291

292
                // Finally add a new event for this invoice.
293
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
×
294
                        AddedAt:   newInvoice.CreationDate.UTC(),
×
295
                        InvoiceID: invoiceID,
×
296
                })
×
297
        }, func() {})
×
298
        if err != nil {
×
299
                mappedSQLErr := sqldb.MapSQLError(err)
×
300
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
×
301
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
302
                        // Add context to unique constraint errors.
×
303
                        return 0, ErrDuplicateInvoice
×
304
                }
×
305

306
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
×
307
                        paymentHash, err)
×
308
        }
309

310
        newInvoice.AddIndex = uint64(invoiceID)
×
311

×
312
        return newInvoice.AddIndex, nil
×
313
}
314

315
// fetchInvoice fetches the common invoice data and the AMP state for the
316
// invoice with the given reference.
317
func (i *SQLStore) fetchInvoice(ctx context.Context,
318
        db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) {
×
319

×
320
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
×
321
                return nil, ErrInvoiceNotFound
×
322
        }
×
323

324
        var (
×
325
                invoice *Invoice
×
326
                params  sqlc.GetInvoiceParams
×
327
        )
×
328

×
329
        // Given all invoices are uniquely identified by their payment hash,
×
330
        // we can use it to query a specific invoice.
×
331
        if ref.PayHash() != nil {
×
332
                params.Hash = ref.PayHash()[:]
×
333
        }
×
334

335
        // Newer invoices (0.11 and up) are indexed by payment address in
336
        // addition to payment hash, but pre 0.8 invoices do not have one at
337
        // all. Only allow lookups for payment address if it is not a blank
338
        // payment address, which is a special-cased value for legacy keysend
339
        // invoices.
340
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
×
341
                params.PaymentAddr = ref.PayAddr()[:]
×
342
        }
×
343

344
        // If the reference has a set ID we'll fetch the invoice which has the
345
        // corresponding AMP sub invoice.
346
        if ref.SetID() != nil {
×
347
                params.SetID = ref.SetID()[:]
×
348
        }
×
349

350
        var (
×
351
                rows []sqlc.Invoice
×
352
                err  error
×
353
        )
×
354

×
355
        // We need to split the query based on how we intend to look up the
×
356
        // invoice. If only the set ID is given then we want to have an exact
×
357
        // match on the set ID. If other fields are given, we want to match on
×
358
        // those fields and the set ID but with a less strict join condition.
×
359
        if params.Hash == nil && params.PaymentAddr == nil &&
×
360
                params.SetID != nil {
×
361

×
362
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
×
363
        } else {
×
364
                rows, err = db.GetInvoice(ctx, params)
×
365
        }
×
366
        switch {
×
367
        case len(rows) == 0:
×
368
                return nil, ErrInvoiceNotFound
×
369

370
        case len(rows) > 1:
×
371
                // In case the reference is ambiguous, meaning it matches more
×
372
                // than        one invoice, we'll return an error.
×
373
                return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
×
374
                        ref.String(), spew.Sdump(rows))
×
375

376
        case err != nil:
×
377
                return nil, fmt.Errorf("unable to fetch invoice: %w", err)
×
378
        }
379

380
        var (
×
381
                setID         *[32]byte
×
382
                fetchAmpHtlcs bool
×
383
        )
×
384

×
385
        // Now that we got the invoice itself, fetch the HTLCs as requested by
×
386
        // the modifier.
×
387
        switch ref.Modifier() {
×
388
        case DefaultModifier:
×
389
                // By default we'll fetch all AMP HTLCs.
×
390
                setID = nil
×
391
                fetchAmpHtlcs = true
×
392

393
        case HtlcSetOnlyModifier:
×
394
                // In this case we'll fetch all AMP HTLCs for the
×
395
                // specified set id.
×
396
                if ref.SetID() == nil {
×
397
                        return nil, fmt.Errorf("set ID is required to use " +
×
398
                                "the HTLC set only modifier")
×
399
                }
×
400

401
                setID = ref.SetID()
×
402
                fetchAmpHtlcs = true
×
403

404
        case HtlcSetBlankModifier:
×
405
                // No need to fetch any HTLCs.
×
406
                setID = nil
×
407
                fetchAmpHtlcs = false
×
408

409
        default:
×
410
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
×
411
                        ref.Modifier())
×
412
        }
413

414
        // Fetch the rest of the invoice data and fill the invoice struct.
415
        _, invoice, err = fetchInvoiceData(
×
416
                ctx, db, rows[0], setID, fetchAmpHtlcs,
×
417
        )
×
418
        if err != nil {
×
419
                return nil, err
×
420
        }
×
421

422
        return invoice, nil
×
423
}
424

425
// fetchAmpState fetches the AMP state for the invoice with the given ID.
426
// Optional setID can be provided to fetch the state for a specific AMP HTLC
427
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
428
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
429
// well.
430
//
431
//nolint:funlen
432
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
433
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
434
        HTLCSet, error) {
×
435

×
436
        var paramSetID []byte
×
437
        if setID != nil {
×
438
                paramSetID = setID[:]
×
439
        }
×
440

441
        // First fetch all the AMP sub invoices for this invoice or the one
442
        // matching the provided set ID.
443
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
×
444
                ctx, sqlc.FetchAMPSubInvoicesParams{
×
445
                        InvoiceID: invoiceID,
×
446
                        SetID:     paramSetID,
×
447
                },
×
448
        )
×
449
        if err != nil {
×
450
                return nil, nil, err
×
451
        }
×
452

453
        ampState := make(map[SetID]InvoiceStateAMP)
×
454
        for _, row := range ampInvoiceRows {
×
455
                var rowSetID [32]byte
×
456

×
457
                if len(row.SetID) != 32 {
×
458
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
×
459
                                len(row.SetID))
×
460
                }
×
461

462
                var settleDate time.Time
×
463
                if row.SettledAt.Valid {
×
464
                        settleDate = row.SettledAt.Time.Local()
×
465
                }
×
466

467
                copy(rowSetID[:], row.SetID)
×
468
                ampState[rowSetID] = InvoiceStateAMP{
×
469
                        State:       HtlcState(row.State),
×
470
                        SettleIndex: uint64(row.SettleIndex.Int64),
×
471
                        SettleDate:  settleDate,
×
472
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
×
473
                }
×
474
        }
475

476
        if !fetchHtlcs {
×
477
                return ampState, nil, nil
×
478
        }
×
479

480
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
481
        if err != nil {
×
482
                return nil, nil, fmt.Errorf("unable to get custom records for "+
×
483
                        "invoice HTLCs: %w", err)
×
484
        }
×
485

486
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
×
487
        for _, row := range customRecordRows {
×
488
                if _, ok := customRecords[row.HtlcID]; !ok {
×
489
                        customRecords[row.HtlcID] = make(record.CustomSet)
×
490
                }
×
491

492
                value := row.Value
×
493
                if value == nil {
×
494
                        value = []byte{}
×
495
                }
×
496

497
                customRecords[row.HtlcID][uint64(row.Key)] = value
×
498
        }
499

500
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
501
        // provided set ID.
502
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
×
503
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
×
504
                        InvoiceID: invoiceID,
×
505
                        SetID:     paramSetID,
×
506
                },
×
507
        )
×
508
        if err != nil {
×
509
                return nil, nil, err
×
510
        }
×
511

512
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
×
513
        for _, row := range ampHtlcRows {
×
514
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
515
                if err != nil {
×
516
                        return nil, nil, err
×
517
                }
×
518

519
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
520

×
521
                if row.HtlcID < 0 {
×
522
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
×
523
                                "value: %v", row.HtlcID)
×
524
                }
×
525

526
                htlcID := uint64(row.HtlcID)
×
527

×
528
                circuitKey := CircuitKey{
×
529
                        ChanID: chanID,
×
530
                        HtlcID: htlcID,
×
531
                }
×
532

×
533
                htlc := &InvoiceHTLC{
×
534
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
535
                        AcceptHeight: uint32(row.AcceptHeight),
×
536
                        AcceptTime:   row.AcceptTime.Local(),
×
537
                        Expiry:       uint32(row.ExpiryHeight),
×
538
                        State:        HtlcState(row.State),
×
539
                }
×
540

×
541
                if row.TotalMppMsat.Valid {
×
542
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
×
543
                                row.TotalMppMsat.Int64,
×
544
                        )
×
545
                }
×
546

547
                if row.ResolveTime.Valid {
×
548
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
×
549
                }
×
550

551
                var (
×
552
                        rootShare [32]byte
×
553
                        setID     [32]byte
×
554
                )
×
555

×
556
                if len(row.RootShare) != 32 {
×
557
                        return nil, nil, fmt.Errorf("invalid root share "+
×
558
                                "length: %d", len(row.RootShare))
×
559
                }
×
560
                copy(rootShare[:], row.RootShare)
×
561

×
562
                if len(row.SetID) != 32 {
×
563
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
564
                                len(row.SetID))
×
565
                }
×
566
                copy(setID[:], row.SetID)
×
567

×
568
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
×
569
                        return nil, nil, fmt.Errorf("invalid child index "+
×
570
                                "value: %v", row.ChildIndex)
×
571
                }
×
572

573
                ampRecord := record.NewAMP(
×
574
                        rootShare, setID, uint32(row.ChildIndex),
×
575
                )
×
576

×
577
                htlc.AMP = &InvoiceHtlcAMPData{
×
578
                        Record: *ampRecord,
×
579
                }
×
580

×
581
                if len(row.Hash) != 32 {
×
582
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
×
583
                                len(row.Hash))
×
584
                }
×
585
                copy(htlc.AMP.Hash[:], row.Hash)
×
586

×
587
                if row.Preimage != nil {
×
588
                        preimage, err := lntypes.MakePreimage(row.Preimage)
×
589
                        if err != nil {
×
590
                                return nil, nil, err
×
591
                        }
×
592

593
                        htlc.AMP.Preimage = &preimage
×
594
                }
595

596
                if _, ok := customRecords[row.ID]; ok {
×
597
                        htlc.CustomRecords = customRecords[row.ID]
×
598
                } else {
×
599
                        htlc.CustomRecords = make(record.CustomSet)
×
600
                }
×
601

602
                ampHtlcs[circuitKey] = htlc
×
603
        }
604

605
        if len(ampHtlcs) > 0 {
×
606
                for setID := range ampState {
×
607
                        var amtPaid lnwire.MilliSatoshi
×
608
                        invoiceKeys := make(
×
609
                                map[models.CircuitKey]struct{},
×
610
                        )
×
611

×
612
                        for key, htlc := range ampHtlcs {
×
613
                                if htlc.AMP.Record.SetID() != setID {
×
614
                                        continue
×
615
                                }
616

617
                                invoiceKeys[key] = struct{}{}
×
618

×
619
                                if htlc.State != HtlcStateCanceled { //nolint: ll
×
620
                                        amtPaid += htlc.Amt
×
621
                                }
×
622
                        }
623

624
                        setState := ampState[setID]
×
625
                        setState.InvoiceKeys = invoiceKeys
×
626
                        setState.AmtPaid = amtPaid
×
627
                        ampState[setID] = setState
×
628
                }
629
        }
630

631
        return ampState, ampHtlcs, nil
×
632
}
633

634
// LookupInvoice attempts to look up an invoice corresponding the passed in
635
// reference. The reference may be a payment hash, a payment address, or a set
636
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
637
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
638
// error.
639
func (i *SQLStore) LookupInvoice(ctx context.Context,
640
        ref InvoiceRef) (Invoice, error) {
×
641

×
642
        var (
×
643
                invoice *Invoice
×
644
                err     error
×
645
        )
×
646

×
647
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
648
        txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
649
                invoice, err = i.fetchInvoice(ctx, db, ref)
×
650

×
651
                return err
×
652
        }, func() {})
×
653
        if txErr != nil {
×
654
                return Invoice{}, txErr
×
655
        }
×
656

657
        return *invoice, nil
×
658
}
659

660
// FetchPendingInvoices returns all the invoices that are currently in a
661
// "pending" state. An invoice is pending if it has been created but not yet
662
// settled or canceled.
663
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
664
        map[lntypes.Hash]Invoice, error) {
×
665

×
666
        var invoices map[lntypes.Hash]Invoice
×
667

×
668
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
669
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
670
                return queryWithLimit(func(offset int) (int, error) {
×
671
                        params := sqlc.FilterInvoicesParams{
×
672
                                PendingOnly: true,
×
673
                                NumOffset:   int32(offset),
×
674
                                NumLimit:    int32(i.opts.paginationLimit),
×
675
                                Reverse:     false,
×
676
                        }
×
677

×
678
                        rows, err := db.FilterInvoices(ctx, params)
×
679
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
680
                                return 0, fmt.Errorf("unable to get invoices "+
×
681
                                        "from db: %w", err)
×
682
                        }
×
683

684
                        // Load all the information for the invoices.
685
                        for _, row := range rows {
×
686
                                hash, invoice, err := fetchInvoiceData(
×
687
                                        ctx, db, row, nil, true,
×
688
                                )
×
689
                                if err != nil {
×
690
                                        return 0, err
×
691
                                }
×
692

693
                                invoices[*hash] = *invoice
×
694
                        }
695

696
                        return len(rows), nil
×
697
                }, i.opts.paginationLimit)
698
        }, func() {
×
699
                invoices = make(map[lntypes.Hash]Invoice)
×
700
        })
×
701
        if err != nil {
×
702
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
×
703
                        err)
×
704
        }
×
705

706
        return invoices, nil
×
707
}
708

709
// InvoicesSettledSince can be used by callers to catch up any settled invoices
710
// they missed within the settled invoice time series. We'll return all known
711
// settled invoice that have a settle index higher than the passed idx.
712
//
713
// NOTE: The index starts from 1. As a result we enforce that specifying a value
714
// below the starting index value is a noop.
715
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
716
        []Invoice, error) {
×
717

×
718
        var invoices []Invoice
×
719

×
720
        if idx == 0 {
×
721
                return invoices, nil
×
722
        }
×
723

724
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
725
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
726
                err := queryWithLimit(func(offset int) (int, error) {
×
727
                        params := sqlc.FilterInvoicesParams{
×
728
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
729
                                NumOffset:      int32(offset),
×
730
                                NumLimit:       int32(i.opts.paginationLimit),
×
731
                                Reverse:        false,
×
732
                        }
×
733

×
734
                        rows, err := db.FilterInvoices(ctx, params)
×
735
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
736
                                return 0, fmt.Errorf("unable to get invoices "+
×
737
                                        "from db: %w", err)
×
738
                        }
×
739

740
                        // Load all the information for the invoices.
741
                        for _, row := range rows {
×
742
                                _, invoice, err := fetchInvoiceData(
×
743
                                        ctx, db, row, nil, true,
×
744
                                )
×
745
                                if err != nil {
×
746
                                        return 0, fmt.Errorf("unable to fetch "+
×
747
                                                "invoice(id=%d) from db: %w",
×
748
                                                row.ID, err)
×
749
                                }
×
750

751
                                invoices = append(invoices, *invoice)
×
752
                        }
753

754
                        return len(rows), nil
×
755
                }, i.opts.paginationLimit)
756
                if err != nil {
×
757
                        return err
×
758
                }
×
759

760
                // Now fetch all the AMP sub invoices that were settled since
761
                // the provided index.
762
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
×
763
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
×
764
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
765
                        },
×
766
                )
×
767
                if err != nil {
×
768
                        return err
×
769
                }
×
770

771
                for _, ampInvoice := range ampInvoices {
×
772
                        // Convert the row to a sqlc.Invoice so we can use the
×
773
                        // existing fetchInvoiceData function.
×
774
                        sqlInvoice := sqlc.Invoice{
×
775
                                ID:             ampInvoice.ID,
×
776
                                Hash:           ampInvoice.Hash,
×
777
                                Preimage:       ampInvoice.Preimage,
×
778
                                SettleIndex:    ampInvoice.AmpSettleIndex,
×
779
                                SettledAt:      ampInvoice.AmpSettledAt,
×
780
                                Memo:           ampInvoice.Memo,
×
781
                                AmountMsat:     ampInvoice.AmountMsat,
×
782
                                CltvDelta:      ampInvoice.CltvDelta,
×
783
                                Expiry:         ampInvoice.Expiry,
×
784
                                PaymentAddr:    ampInvoice.PaymentAddr,
×
785
                                PaymentRequest: ampInvoice.PaymentRequest,
×
786
                                State:          ampInvoice.State,
×
787
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
×
788
                                IsAmp:          ampInvoice.IsAmp,
×
789
                                IsHodl:         ampInvoice.IsHodl,
×
790
                                IsKeysend:      ampInvoice.IsKeysend,
×
791
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
×
792
                        }
×
793

×
794
                        // Fetch the state and HTLCs for this AMP sub invoice.
×
795
                        _, invoice, err := fetchInvoiceData(
×
796
                                ctx, db, sqlInvoice,
×
797
                                (*[32]byte)(ampInvoice.SetID), true,
×
798
                        )
×
799
                        if err != nil {
×
800
                                return fmt.Errorf("unable to fetch "+
×
801
                                        "AMP invoice(id=%d) from db: %w",
×
802
                                        ampInvoice.ID, err)
×
803
                        }
×
804

805
                        invoices = append(invoices, *invoice)
×
806
                }
807

808
                return nil
×
809
        }, func() {
×
810
                invoices = nil
×
811
        })
×
812
        if err != nil {
×
813
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
814
                        "index (excluding) %d: %w", idx, err)
×
815
        }
×
816

817
        return invoices, nil
×
818
}
819

820
// InvoicesAddedSince can be used by callers to seek into the event time series
821
// of all the invoices added in the database. This method will return all
822
// invoices with an add index greater than the specified idx.
823
//
824
// NOTE: The index starts from 1. As a result we enforce that specifying a value
825
// below the starting index value is a noop.
826
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
827
        []Invoice, error) {
×
828

×
829
        var result []Invoice
×
830

×
831
        if idx == 0 {
×
832
                return result, nil
×
833
        }
×
834

835
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
836
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
837
                return queryWithLimit(func(offset int) (int, error) {
×
838
                        params := sqlc.FilterInvoicesParams{
×
839
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
×
840
                                NumOffset:   int32(offset),
×
841
                                NumLimit:    int32(i.opts.paginationLimit),
×
842
                                Reverse:     false,
×
843
                        }
×
844

×
845
                        rows, err := db.FilterInvoices(ctx, params)
×
846
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
847
                                return 0, fmt.Errorf("unable to get invoices "+
×
848
                                        "from db: %w", err)
×
849
                        }
×
850

851
                        // Load all the information for the invoices.
852
                        for _, row := range rows {
×
853
                                _, invoice, err := fetchInvoiceData(
×
854
                                        ctx, db, row, nil, true,
×
855
                                )
×
856
                                if err != nil {
×
857
                                        return 0, err
×
858
                                }
×
859

860
                                result = append(result, *invoice)
×
861
                        }
862

863
                        return len(rows), nil
×
864
                }, i.opts.paginationLimit)
865
        }, func() {
×
866
                result = nil
×
867
        })
×
868

869
        if err != nil {
×
870
                return nil, fmt.Errorf("unable to get invoices added since "+
×
871
                        "index %d: %w", idx, err)
×
872
        }
×
873

874
        return result, nil
×
875
}
876

877
// QueryInvoices allows a caller to query the invoice database for invoices
878
// within the specified add index range.
879
func (i *SQLStore) QueryInvoices(ctx context.Context,
880
        q InvoiceQuery) (InvoiceSlice, error) {
×
881

×
882
        var invoices []Invoice
×
883

×
884
        if q.NumMaxInvoices == 0 {
×
885
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
886
                        "be non-zero")
×
887
        }
×
888

889
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
890
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
891
                return queryWithLimit(func(offset int) (int, error) {
×
892
                        params := sqlc.FilterInvoicesParams{
×
893
                                NumOffset:   int32(offset),
×
894
                                NumLimit:    int32(i.opts.paginationLimit),
×
895
                                PendingOnly: q.PendingOnly,
×
896
                                Reverse:     q.Reversed,
×
897
                        }
×
898

×
899
                        if q.Reversed {
×
900
                                // If the index offset was not set, we want to
×
901
                                // fetch from the lastest invoice.
×
902
                                if q.IndexOffset == 0 {
×
903
                                        params.AddIndexLet = sqldb.SQLInt64(
×
904
                                                int64(math.MaxInt64),
×
905
                                        )
×
906
                                } else {
×
907
                                        // The invoice with index offset id must
×
908
                                        // not be included in the results.
×
909
                                        params.AddIndexLet = sqldb.SQLInt64(
×
910
                                                q.IndexOffset - 1,
×
911
                                        )
×
912
                                }
×
913
                        } else {
×
914
                                // The invoice with index offset id must not be
×
915
                                // included in the results.
×
916
                                params.AddIndexGet = sqldb.SQLInt64(
×
917
                                        q.IndexOffset + 1,
×
918
                                )
×
919
                        }
×
920

921
                        if q.CreationDateStart != 0 {
×
922
                                params.CreatedAfter = sqldb.SQLTime(
×
923
                                        time.Unix(q.CreationDateStart, 0).UTC(),
×
924
                                )
×
925
                        }
×
926

927
                        if q.CreationDateEnd != 0 {
×
928
                                // We need to add 1 to the end date as we're
×
929
                                // checking less than the end date in SQL.
×
930
                                params.CreatedBefore = sqldb.SQLTime(
×
931
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
×
932
                                )
×
933
                        }
×
934

935
                        rows, err := db.FilterInvoices(ctx, params)
×
936
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
937
                                return 0, fmt.Errorf("unable to get invoices "+
×
938
                                        "from db: %w", err)
×
939
                        }
×
940

941
                        // Load all the information for the invoices.
942
                        for _, row := range rows {
×
943
                                _, invoice, err := fetchInvoiceData(
×
944
                                        ctx, db, row, nil, true,
×
945
                                )
×
946
                                if err != nil {
×
947
                                        return 0, err
×
948
                                }
×
949

950
                                invoices = append(invoices, *invoice)
×
951

×
952
                                if len(invoices) == int(q.NumMaxInvoices) {
×
953
                                        return 0, nil
×
954
                                }
×
955
                        }
956

957
                        return len(rows), nil
×
958
                }, i.opts.paginationLimit)
959
        }, func() {
×
960
                invoices = nil
×
961
        })
×
962
        if err != nil {
×
963
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
964
                        "invoices: %w", err)
×
965
        }
×
966

967
        if len(invoices) == 0 {
×
968
                return InvoiceSlice{
×
969
                        InvoiceQuery: q,
×
970
                }, nil
×
971
        }
×
972

973
        // If we iterated through the add index in reverse order, then
974
        // we'll need to reverse the slice of invoices to return them in
975
        // forward order.
976
        if q.Reversed {
×
977
                numInvoices := len(invoices)
×
978
                for i := 0; i < numInvoices/2; i++ {
×
979
                        reverse := numInvoices - i - 1
×
980
                        invoices[i], invoices[reverse] =
×
981
                                invoices[reverse], invoices[i]
×
982
                }
×
983
        }
984

985
        res := InvoiceSlice{
×
986
                InvoiceQuery:     q,
×
987
                Invoices:         invoices,
×
988
                FirstIndexOffset: invoices[0].AddIndex,
×
989
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
×
990
        }
×
991

×
992
        return res, nil
×
993
}
994

995
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
996
// a SQL database as the backend.
997
type sqlInvoiceUpdater struct {
998
        db         SQLInvoiceQueries
999
        ctx        context.Context //nolint:containedctx
1000
        invoice    *Invoice
1001
        updateTime time.Time
1002
}
1003

1004
// AddHtlc adds a new htlc to the invoice.
1005
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
1006
        newHtlc *InvoiceHTLC) error {
×
1007

×
1008
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
×
1009
                s.ctx, sqlc.InsertInvoiceHTLCParams{
×
1010
                        HtlcID: int64(circuitKey.HtlcID),
×
1011
                        ChanID: strconv.FormatUint(
×
1012
                                circuitKey.ChanID.ToUint64(), 10,
×
1013
                        ),
×
1014
                        AmountMsat: int64(newHtlc.Amt),
×
1015
                        TotalMppMsat: sql.NullInt64{
×
1016
                                Int64: int64(newHtlc.MppTotalAmt),
×
1017
                                Valid: newHtlc.MppTotalAmt != 0,
×
1018
                        },
×
1019
                        AcceptHeight: int32(newHtlc.AcceptHeight),
×
1020
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
×
1021
                        ExpiryHeight: int32(newHtlc.Expiry),
×
1022
                        State:        int16(newHtlc.State),
×
1023
                        InvoiceID:    int64(s.invoice.AddIndex),
×
1024
                },
×
1025
        )
×
1026
        if err != nil {
×
1027
                return err
×
1028
        }
×
1029

1030
        for key, value := range newHtlc.CustomRecords {
×
1031
                err = s.db.InsertInvoiceHTLCCustomRecord(
×
1032
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
×
1033
                                // TODO(bhandras): schema might be wrong here
×
1034
                                // as the custom record key is an uint64.
×
1035
                                Key:    int64(key),
×
1036
                                Value:  value,
×
1037
                                HtlcID: htlcPrimaryKeyID,
×
1038
                        },
×
1039
                )
×
1040
                if err != nil {
×
1041
                        return err
×
1042
                }
×
1043
        }
1044

1045
        if newHtlc.AMP != nil {
×
1046
                setID := newHtlc.AMP.Record.SetID()
×
1047

×
1048
                upsertResult, err := s.db.UpsertAMPSubInvoice(
×
1049
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
×
1050
                                SetID:     setID[:],
×
1051
                                CreatedAt: s.updateTime.UTC(),
×
1052
                                InvoiceID: int64(s.invoice.AddIndex),
×
1053
                        },
×
1054
                )
×
1055
                if err != nil {
×
1056
                        mappedSQLErr := sqldb.MapSQLError(err)
×
1057
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
×
1058
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
1059
                                return ErrDuplicateSetID{
×
1060
                                        SetID: setID,
×
1061
                                }
×
1062
                        }
×
1063

1064
                        return err
×
1065
                }
1066

1067
                // If we're just inserting the AMP invoice, we'll get a non
1068
                // zero rows affected count.
1069
                rowsAffected, err := upsertResult.RowsAffected()
×
1070
                if err != nil {
×
1071
                        return err
×
1072
                }
×
1073
                if rowsAffected != 0 {
×
1074
                        // If we're inserting a new AMP invoice, we'll also
×
1075
                        // insert a new invoice event.
×
1076
                        err = s.db.OnAMPSubInvoiceCreated(
×
1077
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
×
1078
                                        AddedAt:   s.updateTime.UTC(),
×
1079
                                        InvoiceID: int64(s.invoice.AddIndex),
×
1080
                                        SetID:     setID[:],
×
1081
                                },
×
1082
                        )
×
1083
                        if err != nil {
×
1084
                                return err
×
1085
                        }
×
1086
                }
1087

1088
                rootShare := newHtlc.AMP.Record.RootShare()
×
1089

×
1090
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
×
1091
                        InvoiceID: int64(s.invoice.AddIndex),
×
1092
                        SetID:     setID[:],
×
1093
                        HtlcID:    htlcPrimaryKeyID,
×
1094
                        RootShare: rootShare[:],
×
1095
                        ChildIndex: int64(
×
1096
                                newHtlc.AMP.Record.ChildIndex(),
×
1097
                        ),
×
1098
                        Hash: newHtlc.AMP.Hash[:],
×
1099
                }
×
1100

×
1101
                if newHtlc.AMP.Preimage != nil {
×
1102
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
×
1103
                }
×
1104

1105
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
×
1106
                if err != nil {
×
1107
                        return err
×
1108
                }
×
1109
        }
1110

1111
        return nil
×
1112
}
1113

1114
// ResolveHtlc marks an htlc as resolved with the given state.
1115
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
1116
        state HtlcState, resolveTime time.Time) error {
×
1117

×
1118
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
×
1119
                HtlcID: int64(circuitKey.HtlcID),
×
1120
                ChanID: strconv.FormatUint(
×
1121
                        circuitKey.ChanID.ToUint64(), 10,
×
1122
                ),
×
1123
                InvoiceID:   int64(s.invoice.AddIndex),
×
1124
                State:       int16(state),
×
1125
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
×
1126
        })
×
1127
}
×
1128

1129
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1130
// identified by the setID.
1131
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
1132
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
×
1133

×
1134
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
×
1135
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
×
1136
                        InvoiceID: int64(s.invoice.AddIndex),
×
1137
                        SetID:     setID[:],
×
1138
                        HtlcID:    int64(circuitKey.HtlcID),
×
1139
                        Preimage:  preimage[:],
×
1140
                        ChanID: strconv.FormatUint(
×
1141
                                circuitKey.ChanID.ToUint64(), 10,
×
1142
                        ),
×
1143
                },
×
1144
        )
×
1145
        if err != nil {
×
1146
                return err
×
1147
        }
×
1148

1149
        rowsAffected, err := result.RowsAffected()
×
1150
        if err != nil {
×
1151
                return err
×
1152
        }
×
1153
        if rowsAffected == 0 {
×
1154
                return ErrInvoiceNotFound
×
1155
        }
×
1156

1157
        return nil
×
1158
}
1159

1160
// UpdateInvoiceState updates the invoice state to the new state.
1161
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
1162
        newState ContractState, preimage *lntypes.Preimage) error {
×
1163

×
1164
        var (
×
1165
                settleIndex sql.NullInt64
×
1166
                settledAt   sql.NullTime
×
1167
        )
×
1168

×
1169
        switch newState {
×
1170
        case ContractSettled:
×
1171
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
1172
                if err != nil {
×
1173
                        return err
×
1174
                }
×
1175

1176
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
1177

×
1178
                // If the invoice is settled, we'll also update the settle time.
×
1179
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
1180

×
1181
                err = s.db.OnInvoiceSettled(
×
1182
                        s.ctx, sqlc.OnInvoiceSettledParams{
×
1183
                                AddedAt:   s.updateTime.UTC(),
×
1184
                                InvoiceID: int64(s.invoice.AddIndex),
×
1185
                        },
×
1186
                )
×
1187
                if err != nil {
×
1188
                        return err
×
1189
                }
×
1190

1191
        case ContractCanceled:
×
1192
                err := s.db.OnInvoiceCanceled(
×
1193
                        s.ctx, sqlc.OnInvoiceCanceledParams{
×
1194
                                AddedAt:   s.updateTime.UTC(),
×
1195
                                InvoiceID: int64(s.invoice.AddIndex),
×
1196
                        },
×
1197
                )
×
1198
                if err != nil {
×
1199
                        return err
×
1200
                }
×
1201
        }
1202

1203
        params := sqlc.UpdateInvoiceStateParams{
×
1204
                ID:          int64(s.invoice.AddIndex),
×
1205
                State:       int16(newState),
×
1206
                SettleIndex: settleIndex,
×
1207
                SettledAt:   settledAt,
×
1208
        }
×
1209

×
1210
        if preimage != nil {
×
1211
                params.Preimage = preimage[:]
×
1212
        }
×
1213

1214
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
×
1215
        if err != nil {
×
1216
                return err
×
1217
        }
×
1218
        rowsAffected, err := result.RowsAffected()
×
1219
        if err != nil {
×
1220
                return err
×
1221
        }
×
1222

1223
        if rowsAffected == 0 {
×
1224
                return ErrInvoiceNotFound
×
1225
        }
×
1226

1227
        if settleIndex.Valid {
×
1228
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
×
1229
                s.invoice.SettleDate = s.updateTime
×
1230
        }
×
1231

1232
        return nil
×
1233
}
1234

1235
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
1236
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
1237
        amtPaid lnwire.MilliSatoshi) error {
×
1238

×
1239
        _, err := s.db.UpdateInvoiceAmountPaid(
×
1240
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
×
1241
                        ID:             int64(s.invoice.AddIndex),
×
1242
                        AmountPaidMsat: int64(amtPaid),
×
1243
                },
×
1244
        )
×
1245

×
1246
        return err
×
1247
}
×
1248

1249
// UpdateAmpState updates the state of the AMP sub invoice identified by the
1250
// setID.
1251
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
1252
        newState InvoiceStateAMP, _ models.CircuitKey) error {
×
1253

×
1254
        var (
×
1255
                settleIndex sql.NullInt64
×
1256
                settledAt   sql.NullTime
×
1257
        )
×
1258

×
1259
        switch newState.State {
×
1260
        case HtlcStateSettled:
×
1261
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
1262
                if err != nil {
×
1263
                        return err
×
1264
                }
×
1265

1266
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
1267

×
1268
                // If the invoice is settled, we'll also update the settle time.
×
1269
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
1270

×
1271
                err = s.db.OnAMPSubInvoiceSettled(
×
1272
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
×
1273
                                AddedAt:   s.updateTime.UTC(),
×
1274
                                InvoiceID: int64(s.invoice.AddIndex),
×
1275
                                SetID:     setID[:],
×
1276
                        },
×
1277
                )
×
1278
                if err != nil {
×
1279
                        return err
×
1280
                }
×
1281

1282
        case HtlcStateCanceled:
×
1283
                err := s.db.OnAMPSubInvoiceCanceled(
×
1284
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
×
1285
                                AddedAt:   s.updateTime.UTC(),
×
1286
                                InvoiceID: int64(s.invoice.AddIndex),
×
1287
                                SetID:     setID[:],
×
1288
                        },
×
1289
                )
×
1290
                if err != nil {
×
1291
                        return err
×
1292
                }
×
1293
        }
1294

1295
        err := s.db.UpdateAMPSubInvoiceState(
×
1296
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
×
1297
                        SetID:       setID[:],
×
1298
                        State:       int16(newState.State),
×
1299
                        SettleIndex: settleIndex,
×
1300
                        SettledAt:   settledAt,
×
1301
                },
×
1302
        )
×
1303
        if err != nil {
×
1304
                return err
×
1305
        }
×
1306

1307
        if settleIndex.Valid {
×
1308
                updatedState := s.invoice.AMPState[setID]
×
1309
                updatedState.SettleIndex = uint64(settleIndex.Int64)
×
1310
                updatedState.SettleDate = s.updateTime.UTC()
×
1311
                s.invoice.AMPState[setID] = updatedState
×
1312
        }
×
1313

1314
        return nil
×
1315
}
1316

1317
// Finalize finalizes the update before it is written to the database. Note that
1318
// we don't use this directly in the SQL implementation, so the function is just
1319
// a stub.
1320
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
×
1321
        return nil
×
1322
}
×
1323

1324
// UpdateInvoice attempts to update an invoice corresponding to the passed
1325
// reference. If an invoice matching the passed reference doesn't exist within
1326
// the database, then the action will fail with  ErrInvoiceNotFound error.
1327
//
1328
// The update is performed inside the same database transaction that fetches the
1329
// invoice and is therefore atomic. The fields to update are controlled by the
1330
// supplied callback.
1331
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
1332
        setID *SetID, callback InvoiceUpdateCallback) (
1333
        *Invoice, error) {
×
1334

×
1335
        var updatedInvoice *Invoice
×
1336

×
1337
        txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
×
1338
        txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
×
1339
                if setID != nil {
×
1340
                        // Make sure to use the set ID if this is an AMP update.
×
1341
                        var setIDBytes [32]byte
×
1342
                        copy(setIDBytes[:], setID[:])
×
1343
                        ref.setID = &setIDBytes
×
1344

×
1345
                        // If we're updating an AMP invoice, we'll also only
×
1346
                        // need to fetch the HTLCs for the given set ID.
×
1347
                        ref.refModifier = HtlcSetOnlyModifier
×
1348
                }
×
1349

1350
                invoice, err := i.fetchInvoice(ctx, db, ref)
×
1351
                if err != nil {
×
1352
                        return err
×
1353
                }
×
1354

1355
                updateTime := i.clock.Now()
×
1356
                updater := &sqlInvoiceUpdater{
×
1357
                        db:         db,
×
1358
                        ctx:        ctx,
×
1359
                        invoice:    invoice,
×
1360
                        updateTime: updateTime,
×
1361
                }
×
1362

×
1363
                payHash := ref.PayHash()
×
1364
                updatedInvoice, err = UpdateInvoice(
×
1365
                        payHash, invoice, updateTime, callback, updater,
×
1366
                )
×
1367

×
1368
                return err
×
1369
        }, func() {})
×
1370
        if txErr != nil {
×
1371
                // If the invoice is already settled, we'll return the
×
1372
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
×
1373
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
×
1374
                        return updatedInvoice, txErr
×
1375
                }
×
1376

1377
                return nil, txErr
×
1378
        }
1379

1380
        return updatedInvoice, nil
×
1381
}
1382

1383
// DeleteInvoice attempts to delete the passed invoices and all their related
1384
// data from the database in one transaction.
1385
func (i *SQLStore) DeleteInvoice(ctx context.Context,
1386
        invoicesToDelete []InvoiceDeleteRef) error {
×
1387

×
1388
        // All the InvoiceDeleteRef instances include the add index of the
×
1389
        // invoice. The rest was added to ensure that the invoices were deleted
×
1390
        // properly in the kv database. When we have fully migrated we can
×
1391
        // remove the rest of the fields.
×
1392
        for _, ref := range invoicesToDelete {
×
1393
                if ref.AddIndex == 0 {
×
1394
                        return fmt.Errorf("unable to delete invoice using a "+
×
1395
                                "ref without AddIndex set: %v", ref)
×
1396
                }
×
1397
        }
1398

1399
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
1400
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
1401
                for _, ref := range invoicesToDelete {
×
1402
                        params := sqlc.DeleteInvoiceParams{
×
1403
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
×
1404
                        }
×
1405

×
1406
                        if ref.SettleIndex != 0 {
×
1407
                                params.SettleIndex = sqldb.SQLInt64(
×
1408
                                        ref.SettleIndex,
×
1409
                                )
×
1410
                        }
×
1411

1412
                        if ref.PayHash != lntypes.ZeroHash {
×
1413
                                params.Hash = ref.PayHash[:]
×
1414
                        }
×
1415

1416
                        result, err := db.DeleteInvoice(ctx, params)
×
1417
                        if err != nil {
×
1418
                                return fmt.Errorf("unable to delete "+
×
1419
                                        "invoice(%v): %w", ref.AddIndex, err)
×
1420
                        }
×
1421
                        rowsAffected, err := result.RowsAffected()
×
1422
                        if err != nil {
×
1423
                                return fmt.Errorf("unable to get rows "+
×
1424
                                        "affected: %w", err)
×
1425
                        }
×
1426
                        if rowsAffected == 0 {
×
1427
                                return fmt.Errorf("%w: %v",
×
1428
                                        ErrInvoiceNotFound, ref.AddIndex)
×
1429
                        }
×
1430
                }
1431

1432
                return nil
×
1433
        }, func() {})
×
1434

1435
        if err != nil {
×
1436
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1437
        }
×
1438

1439
        return nil
×
1440
}
1441

1442
// DeleteCanceledInvoices removes all canceled invoices from the database.
1443
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
×
1444
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
1445
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
1446
                _, err := db.DeleteCanceledInvoices(ctx)
×
1447
                if err != nil {
×
1448
                        return fmt.Errorf("unable to delete canceled "+
×
1449
                                "invoices: %w", err)
×
1450
                }
×
1451

1452
                return nil
×
1453
        }, func() {})
×
1454
        if err != nil {
×
1455
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1456
        }
×
1457

1458
        return nil
×
1459
}
1460

1461
// fetchInvoiceData fetches additional data for the given invoice. If the
1462
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
1463
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
1464
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
1465
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
1466
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
1467
        *Invoice, error) {
×
1468

×
1469
        // Unmarshal the common data.
×
1470
        hash, invoice, err := unmarshalInvoice(row)
×
1471
        if err != nil {
×
1472
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1473
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1474
        }
×
1475

1476
        // Fetch the invoice features.
1477
        features, err := getInvoiceFeatures(ctx, db, row.ID)
×
1478
        if err != nil {
×
1479
                return nil, nil, err
×
1480
        }
×
1481

1482
        invoice.Terms.Features = features
×
1483

×
1484
        // If this is an AMP invoice, we'll need fetch the AMP state along
×
1485
        // with the HTLCs (if requested).
×
1486
        if invoice.IsAMP() {
×
1487
                invoiceID := int64(invoice.AddIndex)
×
1488
                ampState, ampHtlcs, err := fetchAmpState(
×
1489
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
×
1490
                )
×
1491
                if err != nil {
×
1492
                        return nil, nil, err
×
1493
                }
×
1494

1495
                invoice.AMPState = ampState
×
1496
                invoice.Htlcs = ampHtlcs
×
1497

×
1498
                return hash, invoice, nil
×
1499
        }
1500

1501
        // Otherwise simply fetch the invoice HTLCs.
1502
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
×
1503
        if err != nil {
×
1504
                return nil, nil, err
×
1505
        }
×
1506

1507
        if len(htlcs) > 0 {
×
1508
                invoice.Htlcs = htlcs
×
1509
                var amountPaid lnwire.MilliSatoshi
×
1510
                for _, htlc := range htlcs {
×
1511
                        if htlc.State == HtlcStateSettled {
×
1512
                                amountPaid += htlc.Amt
×
1513
                        }
×
1514
                }
1515
                invoice.AmtPaid = amountPaid
×
1516
        }
1517

1518
        return hash, invoice, nil
×
1519
}
1520

1521
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1522
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
1523
        invoiceID int64) (*lnwire.FeatureVector, error) {
×
1524

×
1525
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
×
1526
        if err != nil {
×
1527
                return nil, fmt.Errorf("unable to get invoice features: %w",
×
1528
                        err)
×
1529
        }
×
1530

1531
        features := lnwire.EmptyFeatureVector()
×
1532
        for _, feature := range rows {
×
1533
                features.Set(lnwire.FeatureBit(feature.Feature))
×
1534
        }
×
1535

1536
        return features, nil
×
1537
}
1538

1539
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
1540
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
1541
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
×
1542

×
1543
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
×
1544
        if err != nil {
×
1545
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
×
1546
        }
×
1547

1548
        // We have no htlcs to unmarshal.
1549
        if len(htlcRows) == 0 {
×
1550
                return nil, nil
×
1551
        }
×
1552

1553
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
1554
        if err != nil {
×
1555
                return nil, fmt.Errorf("unable to get custom records for "+
×
1556
                        "invoice htlcs: %w", err)
×
1557
        }
×
1558

1559
        cr := make(map[int64]record.CustomSet, len(crRows))
×
1560
        for _, row := range crRows {
×
1561
                if _, ok := cr[row.HtlcID]; !ok {
×
1562
                        cr[row.HtlcID] = make(record.CustomSet)
×
1563
                }
×
1564

1565
                value := row.Value
×
1566
                if value == nil {
×
1567
                        value = []byte{}
×
1568
                }
×
1569
                cr[row.HtlcID][uint64(row.Key)] = value
×
1570
        }
1571

1572
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
×
1573

×
1574
        for _, row := range htlcRows {
×
1575
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
×
1576
                if err != nil {
×
1577
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1578
                                "htlc(%d): %w", row.ID, err)
×
1579
                }
×
1580

1581
                if customRecords, ok := cr[row.ID]; ok {
×
1582
                        htlc.CustomRecords = customRecords
×
1583
                } else {
×
1584
                        htlc.CustomRecords = make(record.CustomSet)
×
1585
                }
×
1586

1587
                htlcs[circuiteKey] = htlc
×
1588
        }
1589

1590
        return htlcs, nil
×
1591
}
1592

1593
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1594
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
1595
        error) {
×
1596

×
1597
        var (
×
1598
                settleIndex    int64
×
1599
                settledAt      time.Time
×
1600
                memo           []byte
×
1601
                paymentRequest []byte
×
1602
                preimage       *lntypes.Preimage
×
1603
                paymentAddr    [32]byte
×
1604
        )
×
1605

×
1606
        hash, err := lntypes.MakeHash(row.Hash)
×
1607
        if err != nil {
×
1608
                return nil, nil, err
×
1609
        }
×
1610

1611
        if row.SettleIndex.Valid {
×
1612
                settleIndex = row.SettleIndex.Int64
×
1613
        }
×
1614

1615
        if row.SettledAt.Valid {
×
1616
                settledAt = row.SettledAt.Time.Local()
×
1617
        }
×
1618

1619
        if row.Memo.Valid {
×
1620
                memo = []byte(row.Memo.String)
×
1621
        }
×
1622

1623
        // Keysend payments will have this field empty.
1624
        if row.PaymentRequest.Valid {
×
1625
                paymentRequest = []byte(row.PaymentRequest.String)
×
1626
        } else {
×
1627
                paymentRequest = []byte{}
×
1628
        }
×
1629

1630
        // We may not have the preimage if this a hodl invoice.
1631
        if row.Preimage != nil {
×
1632
                preimage = &lntypes.Preimage{}
×
1633
                copy(preimage[:], row.Preimage)
×
1634
        }
×
1635

1636
        copy(paymentAddr[:], row.PaymentAddr)
×
1637

×
1638
        var cltvDelta int32
×
1639
        if row.CltvDelta.Valid {
×
1640
                cltvDelta = row.CltvDelta.Int32
×
1641
        }
×
1642

1643
        expiry := time.Duration(row.Expiry) * time.Second
×
1644

×
1645
        invoice := &Invoice{
×
1646
                SettleIndex:    uint64(settleIndex),
×
1647
                SettleDate:     settledAt,
×
1648
                Memo:           memo,
×
1649
                PaymentRequest: paymentRequest,
×
1650
                CreationDate:   row.CreatedAt.Local(),
×
1651
                Terms: ContractTerm{
×
1652
                        FinalCltvDelta:  cltvDelta,
×
1653
                        Expiry:          expiry,
×
1654
                        PaymentPreimage: preimage,
×
1655
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
×
1656
                        PaymentAddr:     paymentAddr,
×
1657
                },
×
1658
                AddIndex:    uint64(row.ID),
×
1659
                State:       ContractState(row.State),
×
1660
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
×
1661
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
×
1662
                AMPState:    AMPInvoiceState{},
×
1663
                HodlInvoice: row.IsHodl,
×
1664
        }
×
1665

×
1666
        return &hash, invoice, nil
×
1667
}
1668

1669
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
1670
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
1671
        *InvoiceHTLC, error) {
×
1672

×
1673
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
1674
        if err != nil {
×
1675
                return CircuitKey{}, nil, err
×
1676
        }
×
1677

1678
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
1679

×
1680
        if row.HtlcID < 0 {
×
1681
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
×
1682
                        "value: %v", row.HtlcID)
×
1683
        }
×
1684

1685
        htlcID := uint64(row.HtlcID)
×
1686

×
1687
        circuitKey := CircuitKey{
×
1688
                ChanID: chanID,
×
1689
                HtlcID: htlcID,
×
1690
        }
×
1691

×
1692
        htlc := &InvoiceHTLC{
×
1693
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
1694
                AcceptHeight: uint32(row.AcceptHeight),
×
1695
                AcceptTime:   row.AcceptTime.Local(),
×
1696
                Expiry:       uint32(row.ExpiryHeight),
×
1697
                State:        HtlcState(row.State),
×
1698
        }
×
1699

×
1700
        if row.TotalMppMsat.Valid {
×
1701
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
×
1702
        }
×
1703

1704
        if row.ResolveTime.Valid {
×
1705
                htlc.ResolveTime = row.ResolveTime.Time.Local()
×
1706
        }
×
1707

1708
        return circuitKey, htlc, nil
×
1709
}
1710

1711
// queryWithLimit is a helper that pages through the database using a limit
// and offset. The passed query function is called with the current offset,
// starting at zero, and should return the number of rows it fetched and an
// error if any. After each call the offset is advanced by limit, and paging
// stops as soon as a call returns fewer than limit rows or an error, which is
// returned unchanged.
//
// The limit must be positive, otherwise an error is returned without invoking
// query.
func queryWithLimit(query func(int) (int, error), limit int) error {
	// A short page is the only signal for the end of the result set. With
	// a non-positive limit no page can ever be short (and with a zero limit
	// the offset would not advance either), so the loop below would never
	// terminate.
	if limit <= 0 {
		return fmt.Errorf("invalid query limit: %d", limit)
	}

	for offset := 0; ; offset += limit {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// Fewer rows than requested means this was the last page.
		if rows < limit {
			return nil
		}
	}
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc