• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13043384202

30 Jan 2025 12:49AM UTC coverage: 48.841% (-9.9%) from 58.777%
13043384202

Pull #9459

github

ziggie1984
docs: add release notes.
Pull Request #9459: invoices: amp invoices bugfix.

28 of 45 new or added lines in 3 files covered. (62.22%)

28177 existing lines in 437 files now uncovered.

99712 of 204157 relevant lines covered (48.84%)

1.02 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.51
/channeldb/invoices.go
1
package channeldb
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/binary"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "time"
11

12
        "github.com/lightningnetwork/lnd/graph/db/models"
13
        "github.com/lightningnetwork/lnd/htlcswitch/hop"
14
        invpkg "github.com/lightningnetwork/lnd/invoices"
15
        "github.com/lightningnetwork/lnd/kvdb"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/tlv"
20
)
21

22
var (
	// invoiceBucket is the name of the bucket within the database that
	// stores all data related to invoices no matter their final state.
	// Within the invoice bucket, each invoice is keyed by its invoice ID
	// which is a monotonically increasing uint32.
	invoiceBucket = []byte("invoices")

	// invoiceIndexBucket is the name of the sub-bucket within the
	// invoiceBucket which indexes all invoices by their payment hash. The
	// payment hash is the sha256 of the invoice's payment preimage. This
	// index is used to detect duplicates, and also to provide a fast path
	// for looking up incoming HTLCs to determine if we're able to settle
	// them fully.
	//
	// maps: payHash => invoiceKey
	invoiceIndexBucket = []byte("paymenthashes")

	// payAddrIndexBucket is the name of the top-level bucket that maps
	// payment addresses to their invoice number. This can be used
	// to efficiently query or update non-legacy invoices. Note that legacy
	// invoices will not be included in this index since they all have the
	// same, all-zero payment address, however all newly generated invoices
	// will end up in this index.
	//
	// maps: payAddr => invoiceKey
	payAddrIndexBucket = []byte("pay-addr-index")

	// setIDIndexBucket is the name of the top-level bucket that maps set
	// ids to their invoice number. This can be used to efficiently query or
	// update AMP invoices. Note that legacy or MPP invoices will not be
	// included in this index, since their HTLCs do not have a set id.
	//
	// maps: setID => invoiceKey
	setIDIndexBucket = []byte("set-id-index")

	// numInvoicesKey is the name of the key which houses the
	// auto-incrementing invoice ID which is essentially used as a primary
	// key. With each invoice inserted, the primary key is incremented by
	// one. This key is stored within the invoiceIndexBucket. Within the
	// invoiceBucket invoices are uniquely identified by the invoice ID.
	numInvoicesKey = []byte("nik")

	// addIndexBucket is an index bucket that we'll use to create a
	// monotonically increasing set of add indexes. Each time we add a new
	// invoice, this sequence number will be incremented and then populated
	// within the new invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   addIndexNo => invoiceKey
	addIndexBucket = []byte("invoice-add-index")

	// settleIndexBucket is an index bucket that we'll use to create a
	// monotonically increasing integer for tracking a "settle index". Each
	// time an invoice is settled, this sequence number will be incremented
	// and populated within the newly settled invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   settleIndexNo => invoiceKey
	settleIndexBucket = []byte("invoice-settle-index")
)
84

85
const (
86
        // A set of tlv type definitions used to serialize invoice htlcs to the
87
        // database.
88
        //
89
        // NOTE: A migration should be added whenever this list changes. This
90
        // prevents against the database being rolled back to an older
91
        // format where the surrounding logic might assume a different set of
92
        // fields are known.
93
        chanIDType       tlv.Type = 1
94
        htlcIDType       tlv.Type = 3
95
        amtType          tlv.Type = 5
96
        acceptHeightType tlv.Type = 7
97
        acceptTimeType   tlv.Type = 9
98
        resolveTimeType  tlv.Type = 11
99
        expiryHeightType tlv.Type = 13
100
        htlcStateType    tlv.Type = 15
101
        mppTotalAmtType  tlv.Type = 17
102
        htlcAMPType      tlv.Type = 19
103
        htlcHashType     tlv.Type = 21
104
        htlcPreimageType tlv.Type = 23
105

106
        // A set of tlv type definitions used to serialize invoice bodiees.
107
        //
108
        // NOTE: A migration should be added whenever this list changes. This
109
        // prevents against the database being rolled back to an older
110
        // format where the surrounding logic might assume a different set of
111
        // fields are known.
112
        memoType            tlv.Type = 0
113
        payReqType          tlv.Type = 1
114
        createTimeType      tlv.Type = 2
115
        settleTimeType      tlv.Type = 3
116
        addIndexType        tlv.Type = 4
117
        settleIndexType     tlv.Type = 5
118
        preimageType        tlv.Type = 6
119
        valueType           tlv.Type = 7
120
        cltvDeltaType       tlv.Type = 8
121
        expiryType          tlv.Type = 9
122
        paymentAddrType     tlv.Type = 10
123
        featuresType        tlv.Type = 11
124
        invStateType        tlv.Type = 12
125
        amtPaidType         tlv.Type = 13
126
        hodlInvoiceType     tlv.Type = 14
127
        invoiceAmpStateType tlv.Type = 15
128

129
        // A set of tlv type definitions used to serialize the invoice AMP
130
        // state along-side the main invoice body.
131
        ampStateSetIDType       tlv.Type = 0
132
        ampStateHtlcStateType   tlv.Type = 1
133
        ampStateSettleIndexType tlv.Type = 2
134
        ampStateSettleDateType  tlv.Type = 3
135
        ampStateCircuitKeysType tlv.Type = 4
136
        ampStateAmtPaidType     tlv.Type = 5
137
)
138

139
// AddInvoice inserts the targeted invoice into the database. If the invoice has
140
// *any* payment hashes which already exists within the database, then the
141
// insertion will be aborted and rejected due to the strict policy banning any
142
// duplicate payment hashes. A side effect of this function is that it sets
143
// AddIndex on newInvoice.
144
func (d *DB) AddInvoice(_ context.Context, newInvoice *invpkg.Invoice,
145
        paymentHash lntypes.Hash) (uint64, error) {
2✔
146

2✔
147
        if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
2✔
UNCOV
148
                return 0, err
×
UNCOV
149
        }
×
150

151
        var invoiceAddIndex uint64
2✔
152
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
4✔
153
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
2✔
154
                if err != nil {
2✔
155
                        return err
×
156
                }
×
157

158
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
2✔
159
                        invoiceIndexBucket,
2✔
160
                )
2✔
161
                if err != nil {
2✔
162
                        return err
×
163
                }
×
164
                addIndex, err := invoices.CreateBucketIfNotExists(
2✔
165
                        addIndexBucket,
2✔
166
                )
2✔
167
                if err != nil {
2✔
168
                        return err
×
169
                }
×
170

171
                // Ensure that an invoice an identical payment hash doesn't
172
                // already exist within the index.
173
                if invoiceIndex.Get(paymentHash[:]) != nil {
2✔
UNCOV
174
                        return invpkg.ErrDuplicateInvoice
×
UNCOV
175
                }
×
176

177
                // Check that we aren't inserting an invoice with a duplicate
178
                // payment address. The all-zeros payment address is
179
                // special-cased to support legacy keysend invoices which don't
180
                // assign one. This is safe since later we also will avoid
181
                // indexing them and avoid collisions.
182
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
2✔
183
                if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
4✔
184
                        paymentAddr := newInvoice.Terms.PaymentAddr[:]
2✔
185
                        if payAddrIndex.Get(paymentAddr) != nil {
4✔
186
                                return invpkg.ErrDuplicatePayAddr
2✔
187
                        }
2✔
188
                }
189

190
                // If the current running payment ID counter hasn't yet been
191
                // created, then create it now.
192
                var invoiceNum uint32
2✔
193
                invoiceCounter := invoiceIndex.Get(numInvoicesKey)
2✔
194
                if invoiceCounter == nil {
4✔
195
                        var scratch [4]byte
2✔
196
                        byteOrder.PutUint32(scratch[:], invoiceNum)
2✔
197
                        err := invoiceIndex.Put(numInvoicesKey, scratch[:])
2✔
198
                        if err != nil {
2✔
199
                                return err
×
200
                        }
×
201
                } else {
2✔
202
                        invoiceNum = byteOrder.Uint32(invoiceCounter)
2✔
203
                }
2✔
204

205
                newIndex, err := putInvoice(
2✔
206
                        invoices, invoiceIndex, payAddrIndex, addIndex,
2✔
207
                        newInvoice, invoiceNum, paymentHash,
2✔
208
                )
2✔
209
                if err != nil {
2✔
210
                        return err
×
211
                }
×
212

213
                invoiceAddIndex = newIndex
2✔
214
                return nil
2✔
215
        }, func() {
2✔
216
                invoiceAddIndex = 0
2✔
217
        })
2✔
218
        if err != nil {
4✔
219
                return 0, err
2✔
220
        }
2✔
221

222
        return invoiceAddIndex, err
2✔
223
}
224

225
// InvoicesAddedSince can be used by callers to seek into the event time series
226
// of all the invoices added in the database. The specified sinceAddIndex
227
// should be the highest add index that the caller knows of. This method will
228
// return all invoices with an add index greater than the specified
229
// sinceAddIndex.
230
//
231
// NOTE: The index starts from 1, as a result. We enforce that specifying a
232
// value below the starting index value is a noop.
233
func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
234
        []invpkg.Invoice, error) {
2✔
235

2✔
236
        var newInvoices []invpkg.Invoice
2✔
237

2✔
238
        // If an index of zero was specified, then in order to maintain
2✔
239
        // backwards compat, we won't send out any new invoices.
2✔
240
        if sinceAddIndex == 0 {
4✔
241
                return newInvoices, nil
2✔
242
        }
2✔
243

244
        var startIndex [8]byte
2✔
245
        byteOrder.PutUint64(startIndex[:], sinceAddIndex)
2✔
246

2✔
247
        err := kvdb.View(d, func(tx kvdb.RTx) error {
4✔
248
                invoices := tx.ReadBucket(invoiceBucket)
2✔
249
                if invoices == nil {
2✔
250
                        return nil
×
251
                }
×
252

253
                addIndex := invoices.NestedReadBucket(addIndexBucket)
2✔
254
                if addIndex == nil {
2✔
255
                        return nil
×
256
                }
×
257

258
                // We'll now run through each entry in the add index starting
259
                // at our starting index. We'll continue until we reach the
260
                // very end of the current key space.
261
                invoiceCursor := addIndex.ReadCursor()
2✔
262

2✔
263
                // We'll seek to the starting index, then manually advance the
2✔
264
                // cursor in order to skip the entry with the since add index.
2✔
265
                invoiceCursor.Seek(startIndex[:])
2✔
266
                addSeqNo, invoiceKey := invoiceCursor.Next()
2✔
267

2✔
268
                for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
4✔
269

2✔
270
                        // For each key found, we'll look up the actual
2✔
271
                        // invoice, then accumulate it into our return value.
2✔
272
                        invoice, err := fetchInvoice(
2✔
273
                                invoiceKey, invoices, nil, false,
2✔
274
                        )
2✔
275
                        if err != nil {
2✔
276
                                return err
×
277
                        }
×
278

279
                        newInvoices = append(newInvoices, invoice)
2✔
280
                }
281

282
                return nil
2✔
283
        }, func() {
2✔
284
                newInvoices = nil
2✔
285
        })
2✔
286
        if err != nil {
2✔
287
                return nil, err
×
288
        }
×
289

290
        return newInvoices, nil
2✔
291
}
292

293
// LookupInvoice attempts to look up an invoice according to its 32 byte
294
// payment hash. If an invoice which can settle the HTLC identified by the
295
// passed payment hash isn't found, then an error is returned. Otherwise, the
296
// full invoice is returned. Before setting the incoming HTLC, the values
297
// SHOULD be checked to ensure the payer meets the agreed upon contractual
298
// terms of the payment.
299
func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
300
        invpkg.Invoice, error) {
2✔
301

2✔
302
        var invoice invpkg.Invoice
2✔
303
        err := kvdb.View(d, func(tx kvdb.RTx) error {
4✔
304
                invoices := tx.ReadBucket(invoiceBucket)
2✔
305
                if invoices == nil {
2✔
306
                        return invpkg.ErrNoInvoicesCreated
×
307
                }
×
308
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
2✔
309
                if invoiceIndex == nil {
4✔
310
                        return invpkg.ErrNoInvoicesCreated
2✔
311
                }
2✔
312
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
2✔
313
                setIDIndex := tx.ReadBucket(setIDIndexBucket)
2✔
314

2✔
315
                // Retrieve the invoice number for this invoice using
2✔
316
                // the provided invoice reference.
2✔
317
                invoiceNum, err := fetchInvoiceNumByRef(
2✔
318
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
2✔
319
                )
2✔
320
                if err != nil {
4✔
321
                        return err
2✔
322
                }
2✔
323

324
                var setID *invpkg.SetID
2✔
325
                switch {
2✔
326
                // If this is a payment address ref, and the blank modified was
327
                // specified, then we'll use the zero set ID to indicate that
328
                // we won't want any HTLCs returned.
329
                case ref.PayAddr() != nil &&
330
                        ref.Modifier() == invpkg.HtlcSetBlankModifier:
2✔
331

2✔
332
                        var zeroSetID invpkg.SetID
2✔
333
                        setID = &zeroSetID
2✔
334

335
                // If this is a set ID ref, and the htlc set only modified was
336
                // specified, then we'll pass through the specified setID so
337
                // only that will be returned.
338
                case ref.SetID() != nil &&
339
                        ref.Modifier() == invpkg.HtlcSetOnlyModifier:
2✔
340

2✔
341
                        setID = (*invpkg.SetID)(ref.SetID())
2✔
342
                }
343

344
                // An invoice was found, retrieve the remainder of the invoice
345
                // body.
346
                i, err := fetchInvoice(
2✔
347
                        invoiceNum, invoices, []*invpkg.SetID{setID}, true,
2✔
348
                )
2✔
349
                if err != nil {
2✔
350
                        return err
×
351
                }
×
352
                invoice = i
2✔
353

2✔
354
                return nil
2✔
355
        }, func() {})
2✔
356
        if err != nil {
4✔
357
                return invoice, err
2✔
358
        }
2✔
359

360
        return invoice, nil
2✔
361
}
362

363
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
364
// reference. The payment address will be treated as the primary key, falling
365
// back to the payment hash if nothing is found for the payment address. An
366
// error is returned if the invoice is not found.
367
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
368
        ref invpkg.InvoiceRef) ([]byte, error) {
2✔
369

2✔
370
        // If the set id is present, we only consult the set id index for this
2✔
371
        // invoice. This type of query is only used to facilitate user-facing
2✔
372
        // requests to lookup, settle or cancel an AMP invoice.
2✔
373
        setID := ref.SetID()
2✔
374
        if setID != nil {
4✔
375
                invoiceNumBySetID := setIDIndex.Get(setID[:])
2✔
376
                if invoiceNumBySetID == nil {
2✔
UNCOV
377
                        return nil, invpkg.ErrInvoiceNotFound
×
UNCOV
378
                }
×
379

380
                return invoiceNumBySetID, nil
2✔
381
        }
382

383
        payHash := ref.PayHash()
2✔
384
        payAddr := ref.PayAddr()
2✔
385

2✔
386
        getInvoiceNumByHash := func() []byte {
4✔
387
                if payHash != nil {
4✔
388
                        return invoiceIndex.Get(payHash[:])
2✔
389
                }
2✔
390
                return nil
2✔
391
        }
392

393
        getInvoiceNumByAddr := func() []byte {
4✔
394
                if payAddr != nil {
4✔
395
                        // Only allow lookups for payment address if it is not a
2✔
396
                        // blank payment address, which is a special-cased value
2✔
397
                        // for legacy keysend invoices.
2✔
398
                        if *payAddr != invpkg.BlankPayAddr {
4✔
399
                                return payAddrIndex.Get(payAddr[:])
2✔
400
                        }
2✔
401
                }
402
                return nil
2✔
403
        }
404

405
        invoiceNumByHash := getInvoiceNumByHash()
2✔
406
        invoiceNumByAddr := getInvoiceNumByAddr()
2✔
407
        switch {
2✔
408
        // If payment address and payment hash both reference an existing
409
        // invoice, ensure they reference the _same_ invoice.
410
        case invoiceNumByAddr != nil && invoiceNumByHash != nil:
2✔
411
                if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
2✔
UNCOV
412
                        return nil, invpkg.ErrInvRefEquivocation
×
UNCOV
413
                }
×
414

415
                return invoiceNumByAddr, nil
2✔
416

417
        // Return invoices by payment addr only.
418
        //
419
        // NOTE: We constrain this lookup to only apply if the invoice ref does
420
        // not contain a payment hash. Legacy and MPP payments depend on the
421
        // payment hash index to enforce that the HTLCs payment hash matches the
422
        // payment hash for the invoice, without this check we would
423
        // inadvertently assume the invoice contains the correct preimage for
424
        // the HTLC, which we only enforce via the lookup by the invoice index.
425
        case invoiceNumByAddr != nil && payHash == nil:
2✔
426
                return invoiceNumByAddr, nil
2✔
427

428
        // If we were only able to reference the invoice by hash, return the
429
        // corresponding invoice number. This can happen when no payment address
430
        // was provided, or if it didn't match anything in our records.
431
        case invoiceNumByHash != nil:
2✔
432
                return invoiceNumByHash, nil
2✔
433

434
        // Otherwise we don't know of the target invoice.
435
        default:
2✔
436
                return nil, invpkg.ErrInvoiceNotFound
2✔
437
        }
438
}
439

440
// FetchPendingInvoices returns all invoices that have not yet been settled or
441
// canceled. The returned map is keyed by the payment hash of each respective
442
// invoice.
443
func (d *DB) FetchPendingInvoices(_ context.Context) (
444
        map[lntypes.Hash]invpkg.Invoice, error) {
2✔
445

2✔
446
        result := make(map[lntypes.Hash]invpkg.Invoice)
2✔
447

2✔
448
        err := kvdb.View(d, func(tx kvdb.RTx) error {
4✔
449
                invoices := tx.ReadBucket(invoiceBucket)
2✔
450
                if invoices == nil {
2✔
451
                        return nil
×
452
                }
×
453

454
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
2✔
455
                if invoiceIndex == nil {
4✔
456
                        // Mask the error if there's no invoice
2✔
457
                        // index as that simply means there are no
2✔
458
                        // invoices added yet to the DB. In this case
2✔
459
                        // we simply return an empty list.
2✔
460
                        return nil
2✔
461
                }
2✔
462

463
                return invoiceIndex.ForEach(func(k, v []byte) error {
4✔
464
                        // Skip the special numInvoicesKey as that does not
2✔
465
                        // point to a valid invoice.
2✔
466
                        if bytes.Equal(k, numInvoicesKey) {
4✔
467
                                return nil
2✔
468
                        }
2✔
469

470
                        // Skip sub-buckets.
471
                        if v == nil {
2✔
472
                                return nil
×
473
                        }
×
474

475
                        invoice, err := fetchInvoice(v, invoices, nil, false)
2✔
476
                        if err != nil {
2✔
477
                                return err
×
478
                        }
×
479

480
                        if invoice.IsPending() {
4✔
481
                                var paymentHash lntypes.Hash
2✔
482
                                copy(paymentHash[:], k)
2✔
483
                                result[paymentHash] = invoice
2✔
484
                        }
2✔
485

486
                        return nil
2✔
487
                })
488
        }, func() {
2✔
489
                result = make(map[lntypes.Hash]invpkg.Invoice)
2✔
490
        })
2✔
491

492
        if err != nil {
2✔
493
                return nil, err
×
494
        }
×
495

496
        return result, nil
2✔
497
}
498

499
// QueryInvoices allows a caller to query the invoice database for invoices
500
// within the specified add index range.
501
func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
502
        invpkg.InvoiceSlice, error) {
2✔
503

2✔
504
        var resp invpkg.InvoiceSlice
2✔
505

2✔
506
        err := kvdb.View(d, func(tx kvdb.RTx) error {
4✔
507
                // If the bucket wasn't found, then there aren't any invoices
2✔
508
                // within the database yet, so we can simply exit.
2✔
509
                invoices := tx.ReadBucket(invoiceBucket)
2✔
510
                if invoices == nil {
2✔
511
                        return invpkg.ErrNoInvoicesCreated
×
512
                }
×
513

514
                // Get the add index bucket which we will use to iterate through
515
                // our indexed invoices.
516
                invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
2✔
517
                if invoiceAddIndex == nil {
4✔
518
                        return invpkg.ErrNoInvoicesCreated
2✔
519
                }
2✔
520

521
                // Create a paginator which reads from our add index bucket with
522
                // the parameters provided by the invoice query.
523
                paginator := newPaginator(
2✔
524
                        invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
2✔
525
                        q.NumMaxInvoices,
2✔
526
                )
2✔
527

2✔
528
                // accumulateInvoices looks up an invoice based on the index we
2✔
529
                // are given, adds it to our set of invoices if it has the right
2✔
530
                // characteristics for our query and returns the number of items
2✔
531
                // we have added to our set of invoices.
2✔
532
                accumulateInvoices := func(_, indexValue []byte) (bool, error) {
4✔
533
                        invoice, err := fetchInvoice(
2✔
534
                                indexValue, invoices, nil, false,
2✔
535
                        )
2✔
536
                        if err != nil {
2✔
537
                                return false, err
×
538
                        }
×
539

540
                        // Skip any settled or canceled invoices if the caller
541
                        // is only interested in pending ones.
542
                        if q.PendingOnly && !invoice.IsPending() {
2✔
UNCOV
543
                                return false, nil
×
UNCOV
544
                        }
×
545

546
                        // Get the creation time in Unix seconds, this always
547
                        // rounds down the nanoseconds to full seconds.
548
                        createTime := invoice.CreationDate.Unix()
2✔
549

2✔
550
                        // Skip any invoices that were created before the
2✔
551
                        // specified time.
2✔
552
                        if createTime < q.CreationDateStart {
4✔
553
                                return false, nil
2✔
554
                        }
2✔
555

556
                        // Skip any invoices that were created after the
557
                        // specified time.
558
                        if q.CreationDateEnd != 0 &&
2✔
559
                                createTime > q.CreationDateEnd {
4✔
560

2✔
561
                                return false, nil
2✔
562
                        }
2✔
563

564
                        // At this point, we've exhausted the offset, so we'll
565
                        // begin collecting invoices found within the range.
566
                        resp.Invoices = append(resp.Invoices, invoice)
2✔
567

2✔
568
                        return true, nil
2✔
569
                }
570

571
                // Query our paginator using accumulateInvoices to build up a
572
                // set of invoices.
573
                if err := paginator.query(accumulateInvoices); err != nil {
2✔
574
                        return err
×
575
                }
×
576

577
                // If we iterated through the add index in reverse order, then
578
                // we'll need to reverse the slice of invoices to return them in
579
                // forward order.
580
                if q.Reversed {
2✔
UNCOV
581
                        numInvoices := len(resp.Invoices)
×
UNCOV
582
                        for i := 0; i < numInvoices/2; i++ {
×
UNCOV
583
                                reverse := numInvoices - i - 1
×
UNCOV
584
                                resp.Invoices[i], resp.Invoices[reverse] =
×
UNCOV
585
                                        resp.Invoices[reverse], resp.Invoices[i]
×
UNCOV
586
                        }
×
587
                }
588

589
                return nil
2✔
590
        }, func() {
2✔
591
                resp = invpkg.InvoiceSlice{
2✔
592
                        InvoiceQuery: q,
2✔
593
                }
2✔
594
        })
2✔
595
        if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) {
2✔
596
                return resp, err
×
597
        }
×
598

599
        // Finally, record the indexes of the first and last invoices returned
600
        // so that the caller can resume from this point later on.
601
        if len(resp.Invoices) > 0 {
4✔
602
                resp.FirstIndexOffset = resp.Invoices[0].AddIndex
2✔
603
                lastIdx := len(resp.Invoices) - 1
2✔
604
                resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex
2✔
605
        }
2✔
606

607
        return resp, nil
2✔
608
}
609

610
// UpdateInvoice attempts to update an invoice corresponding to the passed
611
// payment hash. If an invoice matching the passed payment hash doesn't exist
612
// within the database, then the action will fail with a "not found" error.
613
//
614
// The update is performed inside the same database transaction that fetches the
615
// invoice and is therefore atomic. The fields to update are controlled by the
616
// supplied callback.  When updating an invoice, the update itself happens
617
// in-memory on a copy of the invoice. Once it is written successfully to the
618
// database, the in-memory copy is returned to the caller.
619
func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
620
        setIDHint *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
621
        *invpkg.Invoice, error) {
2✔
622

2✔
623
        var updatedInvoice *invpkg.Invoice
2✔
624
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
4✔
625
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
2✔
626
                if err != nil {
2✔
627
                        return err
×
628
                }
×
629
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
2✔
630
                        invoiceIndexBucket,
2✔
631
                )
2✔
632
                if err != nil {
2✔
633
                        return err
×
634
                }
×
635
                settleIndex, err := invoices.CreateBucketIfNotExists(
2✔
636
                        settleIndexBucket,
2✔
637
                )
2✔
638
                if err != nil {
2✔
639
                        return err
×
640
                }
×
641
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
2✔
642
                setIDIndex := tx.ReadWriteBucket(setIDIndexBucket)
2✔
643

2✔
644
                // Retrieve the invoice number for this invoice using the
2✔
645
                // provided invoice reference.
2✔
646
                invoiceNum, err := fetchInvoiceNumByRef(
2✔
647
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
2✔
648
                )
2✔
649
                if err != nil {
2✔
UNCOV
650
                        return err
×
UNCOV
651
                }
×
652

653
                invoice, err := fetchInvoice(
2✔
654
                        invoiceNum, invoices, []*invpkg.SetID{setIDHint}, false,
2✔
655
                )
2✔
656
                if err != nil {
2✔
657
                        return err
×
658
                }
×
659

660
                now := d.clock.Now()
2✔
661
                updater := &kvInvoiceUpdater{
2✔
662
                        db:                d,
2✔
663
                        invoicesBucket:    invoices,
2✔
664
                        settleIndexBucket: settleIndex,
2✔
665
                        setIDIndexBucket:  setIDIndex,
2✔
666
                        updateTime:        now,
2✔
667
                        invoiceNum:        invoiceNum,
2✔
668
                        invoice:           &invoice,
2✔
669
                        updatedAmpHtlcs:   make(ampHTLCsMap),
2✔
670
                        settledSetIDs:     make(map[invpkg.SetID]struct{}),
2✔
671
                }
2✔
672

2✔
673
                payHash := ref.PayHash()
2✔
674
                updatedInvoice, err = invpkg.UpdateInvoice(
2✔
675
                        payHash, updater.invoice, now, callback, updater,
2✔
676
                )
2✔
677
                if err != nil {
4✔
678
                        return err
2✔
679
                }
2✔
680

681
                // If this is an AMP update, then limit the returned AMP state
682
                // to only the requested set ID.
683
                if setIDHint != nil {
4✔
684
                        filterInvoiceAMPState(updatedInvoice, setIDHint)
2✔
685
                }
2✔
686

687
                return nil
2✔
688
        }, func() {
2✔
689
                updatedInvoice = nil
2✔
690
        })
2✔
691

692
        return updatedInvoice, err
2✔
693
}
694

695
// filterInvoiceAMPState filters the AMP state of the invoice to only include
696
// state for the specified set IDs.
697
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
2✔
698
        filteredAMPState := make(invpkg.AMPInvoiceState)
2✔
699

2✔
700
        for _, setID := range setIDs {
4✔
701
                if setID == nil {
4✔
702
                        return
2✔
703
                }
2✔
704

705
                ampState, ok := invoice.AMPState[*setID]
2✔
706
                if ok {
4✔
707
                        filteredAMPState[*setID] = ampState
2✔
708
                }
2✔
709
        }
710

711
        invoice.AMPState = filteredAMPState
2✔
712
}
713

714
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
715
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
716

717
// kvInvoiceUpdater is an implementation of the InvoiceUpdater interface that
718
// is used with the kv implementation of the invoice database. Note that this
719
// updater is not concurrency safe and synchronizaton is expected to be handled
720
// on the DB level.
721
type kvInvoiceUpdater struct {
722
        db                *DB
723
        invoicesBucket    kvdb.RwBucket
724
        settleIndexBucket kvdb.RwBucket
725
        setIDIndexBucket  kvdb.RwBucket
726

727
        // updateTime is the timestamp for the update.
728
        updateTime time.Time
729

730
        // invoiceNum is a legacy key similar to the add index that is used
731
        // only in the kv implementation.
732
        invoiceNum []byte
733

734
        // invoice is the invoice that we're updating. As a side effect of the
735
        // update this invoice will be mutated.
736
        invoice *invpkg.Invoice
737

738
        // updatedAmpHtlcs holds the set of AMP HTLCs that were added or
739
        // cancelled as part of this update.
740
        updatedAmpHtlcs ampHTLCsMap
741

742
        // settledSetIDs holds the set IDs that are settled with this update.
743
        settledSetIDs map[invpkg.SetID]struct{}
744
}
745

746
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
747
func (k *kvInvoiceUpdater) AddHtlc(_ models.CircuitKey,
748
        _ *invpkg.InvoiceHTLC) error {
2✔
749

2✔
750
        return nil
2✔
751
}
2✔
752

753
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
754
func (k *kvInvoiceUpdater) ResolveHtlc(_ models.CircuitKey, _ invpkg.HtlcState,
755
        _ time.Time) error {
2✔
756

2✔
757
        return nil
2✔
758
}
2✔
759

760
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
761
func (k *kvInvoiceUpdater) AddAmpHtlcPreimage(_ [32]byte, _ models.CircuitKey,
762
        _ lntypes.Preimage) error {
2✔
763

2✔
764
        return nil
2✔
765
}
2✔
766

767
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
768
func (k *kvInvoiceUpdater) UpdateInvoiceState(_ invpkg.ContractState,
769
        _ *lntypes.Preimage) error {
2✔
770

2✔
771
        return nil
2✔
772
}
2✔
773

774
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
775
func (k *kvInvoiceUpdater) UpdateInvoiceAmtPaid(_ lnwire.MilliSatoshi) error {
2✔
776
        return nil
2✔
777
}
2✔
778

779
// UpdateAmpState updates the state of the AMP invoice identified by the setID.
780
func (k *kvInvoiceUpdater) UpdateAmpState(setID [32]byte,
781
        state invpkg.InvoiceStateAMP, circuitKey models.CircuitKey) error {
2✔
782

2✔
783
        if _, ok := k.updatedAmpHtlcs[setID]; !ok {
4✔
784
                switch state.State {
2✔
785
                case invpkg.HtlcStateAccepted:
2✔
786
                        // If we're just now creating the HTLCs for this set
2✔
787
                        // then we'll also pull in the existing HTLCs that are
2✔
788
                        // part of this set, so we can write them all to disk
2✔
789
                        // together (same value)
2✔
790
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
2✔
791
                                &setID, invpkg.HtlcStateAccepted,
2✔
792
                        )
2✔
793

UNCOV
794
                case invpkg.HtlcStateCanceled:
×
UNCOV
795
                        // Only HTLCs in the accepted state, can be cancelled,
×
UNCOV
796
                        // but we also want to merge that with HTLCs that may be
×
UNCOV
797
                        // canceled as well since it can be cancelled one by
×
UNCOV
798
                        // one.
×
UNCOV
799
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
×
UNCOV
800
                                &setID, invpkg.HtlcStateAccepted,
×
UNCOV
801
                        )
×
UNCOV
802

×
UNCOV
803
                        cancelledHtlcs := k.invoice.HTLCSet(
×
UNCOV
804
                                &setID, invpkg.HtlcStateCanceled,
×
UNCOV
805
                        )
×
UNCOV
806
                        for htlcKey, htlc := range cancelledHtlcs {
×
UNCOV
807
                                k.updatedAmpHtlcs[setID][htlcKey] = htlc
×
UNCOV
808
                        }
×
809

UNCOV
810
                case invpkg.HtlcStateSettled:
×
UNCOV
811
                        k.updatedAmpHtlcs[setID] = make(
×
UNCOV
812
                                map[models.CircuitKey]*invpkg.InvoiceHTLC,
×
UNCOV
813
                        )
×
814
                }
815
        }
816

817
        if state.State == invpkg.HtlcStateSettled {
4✔
818
                // Add the set ID to the set that was settled in this invoice
2✔
819
                // update. We'll use this later to update the settle index.
2✔
820
                k.settledSetIDs[setID] = struct{}{}
2✔
821
        }
2✔
822

823
        k.updatedAmpHtlcs[setID][circuitKey] = k.invoice.Htlcs[circuitKey]
2✔
824

2✔
825
        return nil
2✔
826
}
827

828
// Finalize finalizes the update before it is written to the database.
829
func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error {
2✔
830
        switch updateType {
2✔
831
        case invpkg.AddHTLCsUpdate:
2✔
832
                return k.storeAddHtlcsUpdate()
2✔
833

UNCOV
834
        case invpkg.CancelHTLCsUpdate:
×
UNCOV
835
                return k.storeCancelHtlcsUpdate()
×
836

837
        case invpkg.SettleHodlInvoiceUpdate:
2✔
838
                return k.storeSettleHodlInvoiceUpdate()
2✔
839

840
        case invpkg.CancelInvoiceUpdate:
2✔
841
                err := k.serializeAndStoreInvoice()
2✔
842
                if err != nil {
2✔
NEW
843
                        return err
×
NEW
844
                }
×
845

846
                // If this is an AMP invoice, then we'll actually store the rest
847
                // of the HTLCs in-line with the invoice, using the invoice ID
848
                // as a prefix, and the AMP key as a suffix: invoiceNum ||
849
                // setID.
850
                if k.invoice.IsAMP() {
2✔
NEW
851
                        k.updateAMPInvoices()
×
NEW
852
                }
×
853

854
                return nil
2✔
855

856
        }
857

858
        return fmt.Errorf("unknown update type: %v", updateType)
×
859
}
860

861
// storeCancelHtlcsUpdate updates the invoice in the database after cancelling a
862
// set of HTLCs.
UNCOV
863
func (k *kvInvoiceUpdater) storeCancelHtlcsUpdate() error {
×
UNCOV
864
        err := k.serializeAndStoreInvoice()
×
UNCOV
865
        if err != nil {
×
866
                return err
×
867
        }
×
868

869
        // If this is an AMP invoice, then we'll actually store the rest
870
        // of the HTLCs in-line with the invoice, using the invoice ID
871
        // as a prefix, and the AMP key as a suffix: invoiceNum ||
872
        // setID.
UNCOV
873
        if k.invoice.IsAMP() {
×
UNCOV
874
                return k.updateAMPInvoices()
×
UNCOV
875
        }
×
876

UNCOV
877
        return nil
×
878
}
879

880
// storeAddHtlcsUpdate updates the invoice in the database after adding a set of
881
// HTLCs.
882
func (k *kvInvoiceUpdater) storeAddHtlcsUpdate() error {
2✔
883
        invoiceIsAMP := k.invoice.IsAMP()
2✔
884

2✔
885
        for htlcSetID := range k.updatedAmpHtlcs {
4✔
886
                // Check if this SetID already exist.
2✔
887
                setIDInvNum := k.setIDIndexBucket.Get(htlcSetID[:])
2✔
888

2✔
889
                if setIDInvNum == nil {
4✔
890
                        err := k.setIDIndexBucket.Put(
2✔
891
                                htlcSetID[:], k.invoiceNum,
2✔
892
                        )
2✔
893
                        if err != nil {
2✔
894
                                return err
×
895
                        }
×
896
                } else if !bytes.Equal(setIDInvNum, k.invoiceNum) {
2✔
UNCOV
897
                        return invpkg.ErrDuplicateSetID{
×
UNCOV
898
                                SetID: htlcSetID,
×
UNCOV
899
                        }
×
UNCOV
900
                }
×
901
        }
902

903
        // If this is a non-AMP invoice, then the state can eventually go to
904
        // ContractSettled, so we pass in nil value as part of
905
        // setSettleMetaFields.
906
        if !invoiceIsAMP && k.invoice.State == invpkg.ContractSettled {
4✔
907
                err := k.setSettleMetaFields(nil)
2✔
908
                if err != nil {
2✔
909
                        return err
×
910
                }
×
911
        }
912

913
        // As we don't update the settle index above for AMP invoices, we'll do
914
        // it here for each sub-AMP invoice that was settled.
915
        for settledSetID := range k.settledSetIDs {
4✔
916
                settledSetID := settledSetID
2✔
917
                err := k.setSettleMetaFields(&settledSetID)
2✔
918
                if err != nil {
2✔
919
                        return err
×
920
                }
×
921
        }
922

923
        err := k.serializeAndStoreInvoice()
2✔
924
        if err != nil {
2✔
925
                return err
×
926
        }
×
927

928
        // If this is an AMP invoice, then we'll actually store the rest of the
929
        // HTLCs in-line with the invoice, using the invoice ID as a prefix,
930
        // and the AMP key as a suffix: invoiceNum || setID.
931
        if invoiceIsAMP {
4✔
932
                return k.updateAMPInvoices()
2✔
933
        }
2✔
934

935
        return nil
2✔
936
}
937

938
// storeSettleHodlInvoiceUpdate updates the invoice in the database after
939
// settling a hodl invoice.
940
func (k *kvInvoiceUpdater) storeSettleHodlInvoiceUpdate() error {
2✔
941
        err := k.setSettleMetaFields(nil)
2✔
942
        if err != nil {
2✔
943
                return err
×
944
        }
×
945

946
        return k.serializeAndStoreInvoice()
2✔
947
}
948

949
// setSettleMetaFields updates the metadata associated with settlement of an
950
// invoice. If a non-nil setID is passed in, then the value will be append to
951
// the invoice number as well, in order to allow us to detect repeated payments
952
// to the same AMP invoices "across time".
953
func (k *kvInvoiceUpdater) setSettleMetaFields(setID *invpkg.SetID) error {
2✔
954
        // Now that we know the invoice hasn't already been settled, we'll
2✔
955
        // update the settle index so we can place this settle event in the
2✔
956
        // proper location within our time series.
2✔
957
        nextSettleSeqNo, err := k.settleIndexBucket.NextSequence()
2✔
958
        if err != nil {
2✔
959
                return err
×
960
        }
×
961

962
        // Make a new byte array on the stack that can potentially store the 4
963
        // byte invoice number along w/ the 32 byte set ID. We capture valueLen
964
        // here which is the number of bytes copied so we can only store the 4
965
        // bytes if this is a non-AMP invoice.
966
        var indexKey [invoiceSetIDKeyLen]byte
2✔
967
        valueLen := copy(indexKey[:], k.invoiceNum)
2✔
968

2✔
969
        if setID != nil {
4✔
970
                valueLen += copy(indexKey[valueLen:], setID[:])
2✔
971
        }
2✔
972

973
        var seqNoBytes [8]byte
2✔
974
        byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
2✔
975
        err = k.settleIndexBucket.Put(seqNoBytes[:], indexKey[:valueLen])
2✔
976
        if err != nil {
2✔
977
                return err
×
978
        }
×
979

980
        // If the setID is nil, then this means that this is a non-AMP settle,
981
        // so we'll update the invoice settle index directly.
982
        if setID == nil {
4✔
983
                k.invoice.SettleDate = k.updateTime
2✔
984
                k.invoice.SettleIndex = nextSettleSeqNo
2✔
985
        } else {
4✔
986
                // If the set ID isn't blank, we'll update the AMP state map
2✔
987
                // which tracks when each of the setIDs associated with a given
2✔
988
                // AMP invoice are settled.
2✔
989
                ampState := k.invoice.AMPState[*setID]
2✔
990

2✔
991
                ampState.SettleDate = k.updateTime
2✔
992
                ampState.SettleIndex = nextSettleSeqNo
2✔
993

2✔
994
                k.invoice.AMPState[*setID] = ampState
2✔
995
        }
2✔
996

997
        return nil
2✔
998
}
999

1000
// updateAMPInvoices updates the set of AMP invoices in-place. For AMP, rather
1001
// then continually write the invoices to the end of the invoice value, we
1002
// instead write the invoices into a new key preifx that follows the main
1003
// invoice number. This ensures that we don't need to continually decode a
1004
// potentially massive HTLC set, and also allows us to quickly find the HLTCs
1005
// associated with a particular HTLC set.
1006
func (k *kvInvoiceUpdater) updateAMPInvoices() error {
2✔
1007
        for setID, htlcSet := range k.updatedAmpHtlcs {
4✔
1008
                // First write out the set of HTLCs including all the relevant
2✔
1009
                // TLV values.
2✔
1010
                var b bytes.Buffer
2✔
1011
                if err := serializeHtlcs(&b, htlcSet); err != nil {
2✔
1012
                        return err
×
1013
                }
×
1014

1015
                // Next store each HTLC in-line, using a prefix based off the
1016
                // invoice number.
1017
                invoiceSetIDKey := makeInvoiceSetIDKey(k.invoiceNum, setID[:])
2✔
1018

2✔
1019
                err := k.invoicesBucket.Put(invoiceSetIDKey[:], b.Bytes())
2✔
1020
                if err != nil {
2✔
1021
                        return err
×
1022
                }
×
1023
        }
1024

1025
        return nil
2✔
1026
}
1027

1028
// serializeAndStoreInvoice is a helper function used to store invoices.
1029
func (k *kvInvoiceUpdater) serializeAndStoreInvoice() error {
2✔
1030
        var buf bytes.Buffer
2✔
1031
        if err := serializeInvoice(&buf, k.invoice); err != nil {
2✔
1032
                return err
×
1033
        }
×
1034

1035
        err := k.invoicesBucket.Put(k.invoiceNum, buf.Bytes())
2✔
1036
        if err != nil {
2✔
NEW
1037
                return err
×
NEW
1038
        }
×
1039

1040
        return nil
2✔
1041
}
1042

1043
// InvoicesSettledSince can be used by callers to catch up any settled invoices
1044
// they missed within the settled invoice time series. We'll return all known
1045
// settled invoice that have a settle index higher than the passed
1046
// sinceSettleIndex.
1047
//
1048
// NOTE: The index starts from 1, as a result. We enforce that specifying a
1049
// value below the starting index value is a noop.
1050
func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
1051
        []invpkg.Invoice, error) {
2✔
1052

2✔
1053
        var settledInvoices []invpkg.Invoice
2✔
1054

2✔
1055
        // If an index of zero was specified, then in order to maintain
2✔
1056
        // backwards compat, we won't send out any new invoices.
2✔
1057
        if sinceSettleIndex == 0 {
4✔
1058
                return settledInvoices, nil
2✔
1059
        }
2✔
1060

1061
        var startIndex [8]byte
2✔
1062
        byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
2✔
1063

2✔
1064
        err := kvdb.View(d, func(tx kvdb.RTx) error {
4✔
1065
                invoices := tx.ReadBucket(invoiceBucket)
2✔
1066
                if invoices == nil {
2✔
1067
                        return nil
×
1068
                }
×
1069

1070
                settleIndex := invoices.NestedReadBucket(settleIndexBucket)
2✔
1071
                if settleIndex == nil {
2✔
1072
                        return nil
×
1073
                }
×
1074

1075
                // We'll now run through each entry in the add index starting
1076
                // at our starting index. We'll continue until we reach the
1077
                // very end of the current key space.
1078
                invoiceCursor := settleIndex.ReadCursor()
2✔
1079

2✔
1080
                // We'll seek to the starting index, then manually advance the
2✔
1081
                // cursor in order to skip the entry with the since add index.
2✔
1082
                invoiceCursor.Seek(startIndex[:])
2✔
1083
                seqNo, indexValue := invoiceCursor.Next()
2✔
1084

2✔
1085
                for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, indexValue = invoiceCursor.Next() {
4✔
1086
                        // Depending on the length of the index value, this may
2✔
1087
                        // or may not be an AMP invoice, so we'll extract the
2✔
1088
                        // invoice value into two components: the invoice num,
2✔
1089
                        // and the setID (may not be there).
2✔
1090
                        var (
2✔
1091
                                invoiceKey [4]byte
2✔
1092
                                setID      *invpkg.SetID
2✔
1093
                        )
2✔
1094

2✔
1095
                        valueLen := copy(invoiceKey[:], indexValue)
2✔
1096
                        if len(indexValue) == invoiceSetIDKeyLen {
4✔
1097
                                setID = new(invpkg.SetID)
2✔
1098
                                copy(setID[:], indexValue[valueLen:])
2✔
1099
                        }
2✔
1100

1101
                        // For each key found, we'll look up the actual
1102
                        // invoice, then accumulate it into our return value.
1103
                        invoice, err := fetchInvoice(
2✔
1104
                                invoiceKey[:], invoices, []*invpkg.SetID{setID},
2✔
1105
                                true,
2✔
1106
                        )
2✔
1107
                        if err != nil {
2✔
1108
                                return err
×
1109
                        }
×
1110

1111
                        settledInvoices = append(settledInvoices, invoice)
2✔
1112
                }
1113

1114
                return nil
2✔
1115
        }, func() {
2✔
1116
                settledInvoices = nil
2✔
1117
        })
2✔
1118
        if err != nil {
2✔
1119
                return nil, err
×
1120
        }
×
1121

1122
        return settledInvoices, nil
2✔
1123
}
1124

1125
func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
1126
        i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
1127
        uint64, error) {
2✔
1128

2✔
1129
        // Create the invoice key which is just the big-endian representation
2✔
1130
        // of the invoice number.
2✔
1131
        var invoiceKey [4]byte
2✔
1132
        byteOrder.PutUint32(invoiceKey[:], invoiceNum)
2✔
1133

2✔
1134
        // Increment the num invoice counter index so the next invoice bares
2✔
1135
        // the proper ID.
2✔
1136
        var scratch [4]byte
2✔
1137
        invoiceCounter := invoiceNum + 1
2✔
1138
        byteOrder.PutUint32(scratch[:], invoiceCounter)
2✔
1139
        if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
2✔
1140
                return 0, err
×
1141
        }
×
1142

1143
        // Add the payment hash to the invoice index. This will let us quickly
1144
        // identify if we can settle an incoming payment, and also to possibly
1145
        // allow a single invoice to have multiple payment installations.
1146
        err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
2✔
1147
        if err != nil {
2✔
1148
                return 0, err
×
1149
        }
×
1150

1151
        // Add the invoice to the payment address index, but only if the invoice
1152
        // has a non-zero payment address. The all-zero payment address is still
1153
        // in use by legacy keysend, so we special-case here to avoid
1154
        // collisions.
1155
        if i.Terms.PaymentAddr != invpkg.BlankPayAddr {
4✔
1156
                err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
2✔
1157
                if err != nil {
2✔
1158
                        return 0, err
×
1159
                }
×
1160
        }
1161

1162
        // Next, we'll obtain the next add invoice index (sequence
1163
        // number), so we can properly place this invoice within this
1164
        // event stream.
1165
        nextAddSeqNo, err := addIndex.NextSequence()
2✔
1166
        if err != nil {
2✔
1167
                return 0, err
×
1168
        }
×
1169

1170
        // With the next sequence obtained, we'll updating the event series in
1171
        // the add index bucket to map this current add counter to the index of
1172
        // this new invoice.
1173
        var seqNoBytes [8]byte
2✔
1174
        byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
2✔
1175
        if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
2✔
1176
                return 0, err
×
1177
        }
×
1178

1179
        i.AddIndex = nextAddSeqNo
2✔
1180

2✔
1181
        // Finally, serialize the invoice itself to be written to the disk.
2✔
1182
        var buf bytes.Buffer
2✔
1183
        if err := serializeInvoice(&buf, i); err != nil {
2✔
1184
                return 0, err
×
1185
        }
×
1186

1187
        if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
2✔
1188
                return 0, err
×
1189
        }
×
1190

1191
        return nextAddSeqNo, nil
2✔
1192
}
1193

1194
// recordSize returns the amount of bytes this TLV record will occupy when
1195
// encoded.
1196
func ampRecordSize(a *invpkg.AMPInvoiceState) func() uint64 {
2✔
1197
        var (
2✔
1198
                b   bytes.Buffer
2✔
1199
                buf [8]byte
2✔
1200
        )
2✔
1201

2✔
1202
        // We know that encoding works since the tests pass in the build this
2✔
1203
        // file is checked into, so we'll simplify things and simply encode it
2✔
1204
        // ourselves then report the total amount of bytes used.
2✔
1205
        if err := ampStateEncoder(&b, a, &buf); err != nil {
2✔
1206
                // This should never error out, but we log it just in case it
×
1207
                // does.
×
1208
                log.Errorf("encoding the amp invoice state failed: %v", err)
×
1209
        }
×
1210

1211
        return func() uint64 {
4✔
1212
                return uint64(len(b.Bytes()))
2✔
1213
        }
2✔
1214
}
1215

1216
// serializeInvoice serializes an invoice to a writer.
1217
//
1218
// Note: this function is in use for a migration. Before making changes that
1219
// would modify the on disk format, make a copy of the original code and store
1220
// it with the migration.
1221
func serializeInvoice(w io.Writer, i *invpkg.Invoice) error {
2✔
1222
        creationDateBytes, err := i.CreationDate.MarshalBinary()
2✔
1223
        if err != nil {
2✔
1224
                return err
×
1225
        }
×
1226

1227
        settleDateBytes, err := i.SettleDate.MarshalBinary()
2✔
1228
        if err != nil {
2✔
1229
                return err
×
1230
        }
×
1231

1232
        var fb bytes.Buffer
2✔
1233
        err = i.Terms.Features.EncodeBase256(&fb)
2✔
1234
        if err != nil {
2✔
1235
                return err
×
1236
        }
×
1237
        featureBytes := fb.Bytes()
2✔
1238

2✔
1239
        preimage := [32]byte(invpkg.UnknownPreimage)
2✔
1240
        if i.Terms.PaymentPreimage != nil {
4✔
1241
                preimage = *i.Terms.PaymentPreimage
2✔
1242
                if preimage == invpkg.UnknownPreimage {
2✔
1243
                        return errors.New("cannot use all-zeroes preimage")
×
1244
                }
×
1245
        }
1246
        value := uint64(i.Terms.Value)
2✔
1247
        cltvDelta := uint32(i.Terms.FinalCltvDelta)
2✔
1248
        expiry := uint64(i.Terms.Expiry)
2✔
1249

2✔
1250
        amtPaid := uint64(i.AmtPaid)
2✔
1251
        state := uint8(i.State)
2✔
1252

2✔
1253
        var hodlInvoice uint8
2✔
1254
        if i.HodlInvoice {
4✔
1255
                hodlInvoice = 1
2✔
1256
        }
2✔
1257

1258
        tlvStream, err := tlv.NewStream(
2✔
1259
                // Memo and payreq.
2✔
1260
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
2✔
1261
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
2✔
1262

2✔
1263
                // Add/settle metadata.
2✔
1264
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
2✔
1265
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
2✔
1266
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
2✔
1267
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
2✔
1268

2✔
1269
                // Terms.
2✔
1270
                tlv.MakePrimitiveRecord(preimageType, &preimage),
2✔
1271
                tlv.MakePrimitiveRecord(valueType, &value),
2✔
1272
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
2✔
1273
                tlv.MakePrimitiveRecord(expiryType, &expiry),
2✔
1274
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
2✔
1275
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
2✔
1276

2✔
1277
                // Invoice state.
2✔
1278
                tlv.MakePrimitiveRecord(invStateType, &state),
2✔
1279
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
2✔
1280

2✔
1281
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
2✔
1282

2✔
1283
                // Invoice AMP state.
2✔
1284
                tlv.MakeDynamicRecord(
2✔
1285
                        invoiceAmpStateType, &i.AMPState,
2✔
1286
                        ampRecordSize(&i.AMPState),
2✔
1287
                        ampStateEncoder, ampStateDecoder,
2✔
1288
                ),
2✔
1289
        )
2✔
1290
        if err != nil {
2✔
1291
                return err
×
1292
        }
×
1293

1294
        var b bytes.Buffer
2✔
1295
        if err = tlvStream.Encode(&b); err != nil {
2✔
1296
                return err
×
1297
        }
×
1298

1299
        err = binary.Write(w, byteOrder, uint64(b.Len()))
2✔
1300
        if err != nil {
2✔
1301
                return err
×
1302
        }
×
1303

1304
        if _, err = w.Write(b.Bytes()); err != nil {
2✔
1305
                return err
×
1306
        }
×
1307

1308
        // Only if this is a _non_ AMP invoice do we serialize the HTLCs
1309
        // in-line with the rest of the invoice.
1310
        if i.IsAMP() {
4✔
1311
                return nil
2✔
1312
        }
2✔
1313

1314
        return serializeHtlcs(w, i.Htlcs)
2✔
1315
}
1316

1317
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
1318
// a writer.
1319
func serializeHtlcs(w io.Writer,
1320
        htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error {
2✔
1321

2✔
1322
        for key, htlc := range htlcs {
4✔
1323
                // Encode the htlc in a tlv stream.
2✔
1324
                chanID := key.ChanID.ToUint64()
2✔
1325
                amt := uint64(htlc.Amt)
2✔
1326
                mppTotalAmt := uint64(htlc.MppTotalAmt)
2✔
1327
                acceptTime := putNanoTime(htlc.AcceptTime)
2✔
1328
                resolveTime := putNanoTime(htlc.ResolveTime)
2✔
1329
                state := uint8(htlc.State)
2✔
1330

2✔
1331
                var records []tlv.Record
2✔
1332
                records = append(records,
2✔
1333
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
2✔
1334
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
2✔
1335
                        tlv.MakePrimitiveRecord(amtType, &amt),
2✔
1336
                        tlv.MakePrimitiveRecord(
2✔
1337
                                acceptHeightType, &htlc.AcceptHeight,
2✔
1338
                        ),
2✔
1339
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
2✔
1340
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
2✔
1341
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
2✔
1342
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
2✔
1343
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
2✔
1344
                )
2✔
1345

2✔
1346
                if htlc.AMP != nil {
4✔
1347
                        setIDRecord := tlv.MakeDynamicRecord(
2✔
1348
                                htlcAMPType, &htlc.AMP.Record,
2✔
1349
                                htlc.AMP.Record.PayloadSize,
2✔
1350
                                record.AMPEncoder, record.AMPDecoder,
2✔
1351
                        )
2✔
1352
                        records = append(records, setIDRecord)
2✔
1353

2✔
1354
                        hash32 := [32]byte(htlc.AMP.Hash)
2✔
1355
                        hashRecord := tlv.MakePrimitiveRecord(
2✔
1356
                                htlcHashType, &hash32,
2✔
1357
                        )
2✔
1358
                        records = append(records, hashRecord)
2✔
1359

2✔
1360
                        if htlc.AMP.Preimage != nil {
4✔
1361
                                preimage32 := [32]byte(*htlc.AMP.Preimage)
2✔
1362
                                preimageRecord := tlv.MakePrimitiveRecord(
2✔
1363
                                        htlcPreimageType, &preimage32,
2✔
1364
                                )
2✔
1365
                                records = append(records, preimageRecord)
2✔
1366
                        }
2✔
1367
                }
1368

1369
                // Convert the custom records to tlv.Record types that are ready
1370
                // for serialization.
1371
                customRecords := tlv.MapToRecords(htlc.CustomRecords)
2✔
1372

2✔
1373
                // Append the custom records. Their ids are in the experimental
2✔
1374
                // range and sorted, so there is no need to sort again.
2✔
1375
                records = append(records, customRecords...)
2✔
1376

2✔
1377
                tlvStream, err := tlv.NewStream(records...)
2✔
1378
                if err != nil {
2✔
1379
                        return err
×
1380
                }
×
1381

1382
                var b bytes.Buffer
2✔
1383
                if err := tlvStream.Encode(&b); err != nil {
2✔
1384
                        return err
×
1385
                }
×
1386

1387
                // Write the length of the tlv stream followed by the stream
1388
                // bytes.
1389
                err = binary.Write(w, byteOrder, uint64(b.Len()))
2✔
1390
                if err != nil {
2✔
1391
                        return err
×
1392
                }
×
1393

1394
                if _, err := w.Write(b.Bytes()); err != nil {
2✔
1395
                        return err
×
1396
                }
×
1397
        }
1398

1399
        return nil
2✔
1400
}
1401

1402
// putNanoTime converts t to nanoseconds since the Unix epoch. The zero
// time.Time maps to 0, because UnixNano is undefined for it.
func putNanoTime(t time.Time) uint64 {
	var ns uint64
	if !t.IsZero() {
		ns = uint64(t.UnixNano())
	}

	return ns
}
1411

1412
// getNanoTime returns a timestamp for the given number of nano seconds. If zero
1413
// is provided, an zero-value time stamp is returned.
1414
func getNanoTime(ns uint64) time.Time {
2✔
1415
        if ns == 0 {
4✔
1416
                return time.Time{}
2✔
1417
        }
2✔
1418
        return time.Unix(0, int64(ns))
2✔
1419
}
1420

1421
// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices
1422
// identified by the setID value.
1423
func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1424
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1425
        error) {
2✔
1426

2✔
1427
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
2✔
1428
        for _, setID := range setIDs {
4✔
1429
                invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:])
2✔
1430

2✔
1431
                htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:])
2✔
1432
                if htlcSetBytes == nil {
4✔
1433
                        // A set ID was passed in, but we don't have this
2✔
1434
                        // stored yet, meaning that the setID is being added
2✔
1435
                        // for the first time.
2✔
1436
                        return htlcs, invpkg.ErrInvoiceNotFound
2✔
1437
                }
2✔
1438

1439
                htlcSetReader := bytes.NewReader(htlcSetBytes)
2✔
1440
                htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
2✔
1441
                if err != nil {
2✔
1442
                        return nil, err
×
1443
                }
×
1444

1445
                for key, htlc := range htlcsBySetID {
4✔
1446
                        htlcs[key] = htlc
2✔
1447
                }
2✔
1448
        }
1449

1450
        return htlcs, nil
2✔
1451
}
1452

1453
// forEachAMPInvoice is a helper function that attempts to iterate over each of
1454
// the HTLC sets (based on their set ID) for the given AMP invoice identified
1455
// by its invoiceNum. The callback closure is called for each key within the
1456
// prefix range.
1457
func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1458
        callback func(key, htlcSet []byte) error) error {
2✔
1459

2✔
1460
        invoiceCursor := invoiceBucket.ReadCursor()
2✔
1461

2✔
1462
        // Seek to the first key that includes the invoice data itself.
2✔
1463
        invoiceCursor.Seek(invoiceNum)
2✔
1464

2✔
1465
        // Advance to the very first key _after_ the invoice data, as this is
2✔
1466
        // where we'll encounter our first HTLC (if any are present).
2✔
1467
        cursorKey, htlcSet := invoiceCursor.Next()
2✔
1468

2✔
1469
        // If at this point, the cursor key doesn't match the invoice num
2✔
1470
        // prefix, then we know that this HTLC doesn't have any set ID HTLCs
2✔
1471
        // associated with it.
2✔
1472
        if !bytes.HasPrefix(cursorKey, invoiceNum) {
4✔
1473
                return nil
2✔
1474
        }
2✔
1475

1476
        // Otherwise continue to iterate until we no longer match the prefix,
1477
        // executing the call back at each step.
1478
        for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() {
4✔
1479
                err := callback(cursorKey, htlcSet)
2✔
1480
                if err != nil {
2✔
1481
                        return err
×
1482
                }
×
1483
        }
1484

1485
        return nil
2✔
1486
}
1487

1488
// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix  within the
1489
// AMP bucket to find all the individual HTLCs (by setID) associated with a
1490
// given invoice. If a list of set IDs are specified, then only HTLCs
1491
// associated with that setID will be retrieved.
1492
func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1493
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1494
        error) {
2✔
1495

2✔
1496
        // If a set of setIDs was specified, then we can skip the cursor and
2✔
1497
        // just read out exactly what we need.
2✔
1498
        if len(setIDs) != 0 && setIDs[0] != nil {
4✔
1499
                return fetchFilteredAmpInvoices(
2✔
1500
                        invoiceBucket, invoiceNum, setIDs...,
2✔
1501
                )
2✔
1502
        }
2✔
1503

1504
        // Otherwise, iterate over all the htlc sets that are prefixed beside
1505
        // this invoice in the main invoice bucket.
1506
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
2✔
1507
        err := forEachAMPInvoice(invoiceBucket, invoiceNum,
2✔
1508
                func(key, htlcSet []byte) error {
4✔
1509
                        htlcSetReader := bytes.NewReader(htlcSet)
2✔
1510
                        htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
2✔
1511
                        if err != nil {
2✔
1512
                                return err
×
1513
                        }
×
1514

1515
                        for key, htlc := range htlcsBySetID {
4✔
1516
                                htlcs[key] = htlc
2✔
1517
                        }
2✔
1518

1519
                        return nil
2✔
1520
                },
1521
        )
1522

1523
        if err != nil {
2✔
1524
                return nil, err
×
1525
        }
×
1526

1527
        return htlcs, nil
2✔
1528
}
1529

1530
// fetchInvoice attempts to read out the relevant state for the invoice as
1531
// specified by the invoice number. If the setID fields are set, then only the
1532
// HTLC information pertaining to those set IDs is returned.
1533
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1534
        setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
2✔
1535

2✔
1536
        invoiceBytes := invoices.Get(invoiceNum)
2✔
1537
        if invoiceBytes == nil {
2✔
1538
                return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound
×
1539
        }
×
1540

1541
        invoiceReader := bytes.NewReader(invoiceBytes)
2✔
1542

2✔
1543
        invoice, err := deserializeInvoice(invoiceReader)
2✔
1544
        if err != nil {
2✔
1545
                return invpkg.Invoice{}, err
×
1546
        }
×
1547

1548
        // If this is an AMP invoice we'll also attempt to read out the set of
1549
        // HTLCs that were paid to prior set IDs, if needed.
1550
        if !invoice.IsAMP() {
4✔
1551
                return invoice, nil
2✔
1552
        }
2✔
1553

1554
        if shouldFetchAMPHTLCs(invoice, setIDs) {
4✔
1555
                invoice.Htlcs, err = fetchAmpSubInvoices(
2✔
1556
                        invoices, invoiceNum, setIDs...,
2✔
1557
                )
2✔
1558
                // TODO(positiveblue): we should fail when we are not able to
2✔
1559
                // fetch all the HTLCs for an AMP invoice. Multiple tests in
2✔
1560
                // the invoice and channeldb package break if we return this
2✔
1561
                // error. We need to update them when we migrate this logic to
2✔
1562
                // the sql implementation.
2✔
1563
                if err != nil {
4✔
1564
                        log.Errorf("unable to fetch amp htlcs for inv "+
2✔
1565
                                "%v and setIDs %v: %w", invoiceNum, setIDs, err)
2✔
1566
                }
2✔
1567

1568
                if filterAMPState {
4✔
1569
                        filterInvoiceAMPState(&invoice, setIDs...)
2✔
1570
                }
2✔
1571
        }
1572

1573
        return invoice, nil
2✔
1574
}
1575

1576
// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that
1577
// were paid to the relevant set IDs.
1578
func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool {
2✔
1579
        // For AMP invoice that already have HTLCs populated (created before
2✔
1580
        // recurring invoices), then we don't need to read from the prefix
2✔
1581
        // keyed section of the bucket.
2✔
1582
        if len(invoice.Htlcs) != 0 {
2✔
1583
                return false
×
1584
        }
×
1585

1586
        // If the "zero" setID was specified, then this means that no HTLC data
1587
        // should be returned alongside of it.
1588
        if len(setIDs) != 0 && setIDs[0] != nil &&
2✔
1589
                *setIDs[0] == invpkg.BlankPayAddr {
4✔
1590

2✔
1591
                return false
2✔
1592
        }
2✔
1593

1594
        return true
2✔
1595
}
1596

1597
// fetchInvoiceStateAMP retrieves the state of all the relevant sub-invoice for
1598
// an AMP invoice. This methods only decode the relevant state vs the entire
1599
// invoice.
1600
func fetchInvoiceStateAMP(invoiceNum []byte,
UNCOV
1601
        invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) {
×
UNCOV
1602

×
UNCOV
1603
        // Fetch the raw invoice bytes.
×
UNCOV
1604
        invoiceBytes := invoices.Get(invoiceNum)
×
UNCOV
1605
        if invoiceBytes == nil {
×
1606
                return nil, invpkg.ErrInvoiceNotFound
×
1607
        }
×
1608

UNCOV
1609
        r := bytes.NewReader(invoiceBytes)
×
UNCOV
1610

×
UNCOV
1611
        var bodyLen int64
×
UNCOV
1612
        err := binary.Read(r, byteOrder, &bodyLen)
×
UNCOV
1613
        if err != nil {
×
1614
                return nil, err
×
1615
        }
×
1616

1617
        // Next, we'll make a new TLV stream that only attempts to decode the
1618
        // bytes we actually need.
UNCOV
1619
        ampState := make(invpkg.AMPInvoiceState)
×
UNCOV
1620
        tlvStream, err := tlv.NewStream(
×
UNCOV
1621
                // Invoice AMP state.
×
UNCOV
1622
                tlv.MakeDynamicRecord(
×
UNCOV
1623
                        invoiceAmpStateType, &ampState, nil,
×
UNCOV
1624
                        ampStateEncoder, ampStateDecoder,
×
UNCOV
1625
                ),
×
UNCOV
1626
        )
×
UNCOV
1627
        if err != nil {
×
1628
                return nil, err
×
1629
        }
×
1630

UNCOV
1631
        invoiceReader := io.LimitReader(r, bodyLen)
×
UNCOV
1632
        if err = tlvStream.Decode(invoiceReader); err != nil {
×
1633
                return nil, err
×
1634
        }
×
1635

UNCOV
1636
        return ampState, nil
×
1637
}
1638

1639
func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) {
2✔
1640
        var (
2✔
1641
                preimageBytes [32]byte
2✔
1642
                value         uint64
2✔
1643
                cltvDelta     uint32
2✔
1644
                expiry        uint64
2✔
1645
                amtPaid       uint64
2✔
1646
                state         uint8
2✔
1647
                hodlInvoice   uint8
2✔
1648

2✔
1649
                creationDateBytes []byte
2✔
1650
                settleDateBytes   []byte
2✔
1651
                featureBytes      []byte
2✔
1652
        )
2✔
1653

2✔
1654
        var i invpkg.Invoice
2✔
1655
        i.AMPState = make(invpkg.AMPInvoiceState)
2✔
1656
        tlvStream, err := tlv.NewStream(
2✔
1657
                // Memo and payreq.
2✔
1658
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
2✔
1659
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
2✔
1660

2✔
1661
                // Add/settle metadata.
2✔
1662
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
2✔
1663
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
2✔
1664
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
2✔
1665
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
2✔
1666

2✔
1667
                // Terms.
2✔
1668
                tlv.MakePrimitiveRecord(preimageType, &preimageBytes),
2✔
1669
                tlv.MakePrimitiveRecord(valueType, &value),
2✔
1670
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
2✔
1671
                tlv.MakePrimitiveRecord(expiryType, &expiry),
2✔
1672
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
2✔
1673
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
2✔
1674

2✔
1675
                // Invoice state.
2✔
1676
                tlv.MakePrimitiveRecord(invStateType, &state),
2✔
1677
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
2✔
1678

2✔
1679
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
2✔
1680

2✔
1681
                // Invoice AMP state.
2✔
1682
                tlv.MakeDynamicRecord(
2✔
1683
                        invoiceAmpStateType, &i.AMPState, nil,
2✔
1684
                        ampStateEncoder, ampStateDecoder,
2✔
1685
                ),
2✔
1686
        )
2✔
1687
        if err != nil {
2✔
1688
                return i, err
×
1689
        }
×
1690

1691
        var bodyLen int64
2✔
1692
        err = binary.Read(r, byteOrder, &bodyLen)
2✔
1693
        if err != nil {
2✔
1694
                return i, err
×
1695
        }
×
1696

1697
        lr := io.LimitReader(r, bodyLen)
2✔
1698
        if err = tlvStream.Decode(lr); err != nil {
2✔
1699
                return i, err
×
1700
        }
×
1701

1702
        preimage := lntypes.Preimage(preimageBytes)
2✔
1703
        if preimage != invpkg.UnknownPreimage {
4✔
1704
                i.Terms.PaymentPreimage = &preimage
2✔
1705
        }
2✔
1706

1707
        i.Terms.Value = lnwire.MilliSatoshi(value)
2✔
1708
        i.Terms.FinalCltvDelta = int32(cltvDelta)
2✔
1709
        i.Terms.Expiry = time.Duration(expiry)
2✔
1710
        i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
2✔
1711
        i.State = invpkg.ContractState(state)
2✔
1712

2✔
1713
        if hodlInvoice != 0 {
4✔
1714
                i.HodlInvoice = true
2✔
1715
        }
2✔
1716

1717
        err = i.CreationDate.UnmarshalBinary(creationDateBytes)
2✔
1718
        if err != nil {
2✔
1719
                return i, err
×
1720
        }
×
1721

1722
        err = i.SettleDate.UnmarshalBinary(settleDateBytes)
2✔
1723
        if err != nil {
2✔
1724
                return i, err
×
1725
        }
×
1726

1727
        rawFeatures := lnwire.NewRawFeatureVector()
2✔
1728
        err = rawFeatures.DecodeBase256(
2✔
1729
                bytes.NewReader(featureBytes), len(featureBytes),
2✔
1730
        )
2✔
1731
        if err != nil {
2✔
1732
                return i, err
×
1733
        }
×
1734

1735
        i.Terms.Features = lnwire.NewFeatureVector(
2✔
1736
                rawFeatures, lnwire.Features,
2✔
1737
        )
2✔
1738

2✔
1739
        i.Htlcs, err = deserializeHtlcs(r)
2✔
1740
        return i, err
2✔
1741
}
1742

1743
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
2✔
1744
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
4✔
1745
                // We encode the set of circuit keys as a varint length prefix.
2✔
1746
                // followed by a series of fixed sized uint8 integers.
2✔
1747
                numKeys := uint64(len(*v))
2✔
1748

2✔
1749
                if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
2✔
1750
                        return err
×
1751
                }
×
1752

1753
                for key := range *v {
4✔
1754
                        scidInt := key.ChanID.ToUint64()
2✔
1755

2✔
1756
                        if err := tlv.EUint64(w, &scidInt, buf); err != nil {
2✔
1757
                                return err
×
1758
                        }
×
1759
                        if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
2✔
1760
                                return err
×
1761
                        }
×
1762
                }
1763

1764
                return nil
2✔
1765
        }
1766

1767
        return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
×
1768
}
1769

1770
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte,
1771
        l uint64) error {
2✔
1772

2✔
1773
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
4✔
1774
                // First, we'll read out the varint that encodes the number of
2✔
1775
                // circuit keys encoded.
2✔
1776
                numKeys, err := tlv.ReadVarInt(r, buf)
2✔
1777
                if err != nil {
2✔
1778
                        return err
×
1779
                }
×
1780

1781
                // Now that we know how many keys to expect, iterate reading
1782
                // each one until we're done.
1783
                for i := uint64(0); i < numKeys; i++ {
4✔
1784
                        var (
2✔
1785
                                key  models.CircuitKey
2✔
1786
                                scid uint64
2✔
1787
                        )
2✔
1788

2✔
1789
                        if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
2✔
1790
                                return err
×
1791
                        }
×
1792

1793
                        key.ChanID = lnwire.NewShortChanIDFromInt(scid)
2✔
1794

2✔
1795
                        err := tlv.DUint64(r, &key.HtlcID, buf, 8)
2✔
1796
                        if err != nil {
2✔
1797
                                return err
×
1798
                        }
×
1799

1800
                        (*v)[key] = struct{}{}
2✔
1801
                }
1802

1803
                return nil
2✔
1804
        }
1805

1806
        return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l)
×
1807
}
1808

1809
// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record.
1810
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
2✔
1811
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
4✔
1812
                // We'll encode the AMP state as a series of KV pairs on the
2✔
1813
                // wire with a length prefix.
2✔
1814
                numRecords := uint64(len(*v))
2✔
1815

2✔
1816
                // First, we'll write out the number of records as a var int.
2✔
1817
                if err := tlv.WriteVarInt(w, numRecords, buf); err != nil {
2✔
1818
                        return err
×
1819
                }
×
1820

1821
                // With that written out, we'll now encode the entries
1822
                // themselves as a sub-TLV record, which includes its _own_
1823
                // inner length prefix.
1824
                for setID, ampState := range *v {
4✔
1825
                        setID := [32]byte(setID)
2✔
1826
                        ampState := ampState
2✔
1827

2✔
1828
                        htlcState := uint8(ampState.State)
2✔
1829
                        settleDate := ampState.SettleDate
2✔
1830
                        settleDateBytes, err := settleDate.MarshalBinary()
2✔
1831
                        if err != nil {
2✔
1832
                                return err
×
1833
                        }
×
1834

1835
                        amtPaid := uint64(ampState.AmtPaid)
2✔
1836

2✔
1837
                        var ampStateTlvBytes bytes.Buffer
2✔
1838
                        tlvStream, err := tlv.NewStream(
2✔
1839
                                tlv.MakePrimitiveRecord(
2✔
1840
                                        ampStateSetIDType, &setID,
2✔
1841
                                ),
2✔
1842
                                tlv.MakePrimitiveRecord(
2✔
1843
                                        ampStateHtlcStateType, &htlcState,
2✔
1844
                                ),
2✔
1845
                                tlv.MakePrimitiveRecord(
2✔
1846
                                        ampStateSettleIndexType,
2✔
1847
                                        &ampState.SettleIndex,
2✔
1848
                                ),
2✔
1849
                                tlv.MakePrimitiveRecord(
2✔
1850
                                        ampStateSettleDateType,
2✔
1851
                                        &settleDateBytes,
2✔
1852
                                ),
2✔
1853
                                tlv.MakeDynamicRecord(
2✔
1854
                                        ampStateCircuitKeysType,
2✔
1855
                                        &ampState.InvoiceKeys,
2✔
1856
                                        func() uint64 {
4✔
1857
                                                // The record takes 8 bytes to
2✔
1858
                                                // encode the set of circuits,
2✔
1859
                                                // 8 bytes for the scid for the
2✔
1860
                                                // key, and 8 bytes for the HTLC
2✔
1861
                                                // index.
2✔
1862
                                                keys := ampState.InvoiceKeys
2✔
1863
                                                numKeys := uint64(len(keys))
2✔
1864
                                                size := tlv.VarIntSize(numKeys)
2✔
1865
                                                dataSize := (numKeys * 16)
2✔
1866

2✔
1867
                                                return size + dataSize
2✔
1868
                                        },
2✔
1869
                                        encodeCircuitKeys, decodeCircuitKeys,
1870
                                ),
1871
                                tlv.MakePrimitiveRecord(
1872
                                        ampStateAmtPaidType, &amtPaid,
1873
                                ),
1874
                        )
1875
                        if err != nil {
2✔
1876
                                return err
×
1877
                        }
×
1878

1879
                        err = tlvStream.Encode(&ampStateTlvBytes)
2✔
1880
                        if err != nil {
2✔
1881
                                return err
×
1882
                        }
×
1883

1884
                        // We encode the record with a varint length followed by
1885
                        // the _raw_ TLV bytes.
1886
                        tlvLen := uint64(len(ampStateTlvBytes.Bytes()))
2✔
1887
                        if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
2✔
1888
                                return err
×
1889
                        }
×
1890

1891
                        _, err = w.Write(ampStateTlvBytes.Bytes())
2✔
1892
                        if err != nil {
2✔
1893
                                return err
×
1894
                        }
×
1895
                }
1896

1897
                return nil
2✔
1898
        }
1899

1900
        return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
×
1901
}
1902

1903
// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record.
1904
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte,
1905
        l uint64) error {
2✔
1906

2✔
1907
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
4✔
1908
                // First, we'll decode the varint that encodes how many set IDs
2✔
1909
                // are encoded within the greater map.
2✔
1910
                numRecords, err := tlv.ReadVarInt(r, buf)
2✔
1911
                if err != nil {
2✔
1912
                        return err
×
1913
                }
×
1914

1915
                // Now that we know how many records we'll need to read, we can
1916
                // iterate and read them all out in series.
1917
                for i := uint64(0); i < numRecords; i++ {
4✔
1918
                        // Read out the varint that encodes the size of this
2✔
1919
                        // inner TLV record.
2✔
1920
                        stateRecordSize, err := tlv.ReadVarInt(r, buf)
2✔
1921
                        if err != nil {
2✔
1922
                                return err
×
1923
                        }
×
1924

1925
                        // Using this information, we'll create a new limited
1926
                        // reader that'll return an EOF once the end has been
1927
                        // reached so the stream stops consuming bytes.
1928
                        innerTlvReader := io.LimitedReader{
2✔
1929
                                R: r,
2✔
1930
                                N: int64(stateRecordSize),
2✔
1931
                        }
2✔
1932

2✔
1933
                        var (
2✔
1934
                                setID           [32]byte
2✔
1935
                                htlcState       uint8
2✔
1936
                                settleIndex     uint64
2✔
1937
                                settleDateBytes []byte
2✔
1938
                                invoiceKeys     = make(
2✔
1939
                                        map[models.CircuitKey]struct{},
2✔
1940
                                )
2✔
1941
                                amtPaid uint64
2✔
1942
                        )
2✔
1943
                        tlvStream, err := tlv.NewStream(
2✔
1944
                                tlv.MakePrimitiveRecord(
2✔
1945
                                        ampStateSetIDType, &setID,
2✔
1946
                                ),
2✔
1947
                                tlv.MakePrimitiveRecord(
2✔
1948
                                        ampStateHtlcStateType, &htlcState,
2✔
1949
                                ),
2✔
1950
                                tlv.MakePrimitiveRecord(
2✔
1951
                                        ampStateSettleIndexType, &settleIndex,
2✔
1952
                                ),
2✔
1953
                                tlv.MakePrimitiveRecord(
2✔
1954
                                        ampStateSettleDateType,
2✔
1955
                                        &settleDateBytes,
2✔
1956
                                ),
2✔
1957
                                tlv.MakeDynamicRecord(
2✔
1958
                                        ampStateCircuitKeysType,
2✔
1959
                                        &invoiceKeys, nil,
2✔
1960
                                        encodeCircuitKeys, decodeCircuitKeys,
2✔
1961
                                ),
2✔
1962
                                tlv.MakePrimitiveRecord(
2✔
1963
                                        ampStateAmtPaidType, &amtPaid,
2✔
1964
                                ),
2✔
1965
                        )
2✔
1966
                        if err != nil {
2✔
1967
                                return err
×
1968
                        }
×
1969

1970
                        err = tlvStream.Decode(&innerTlvReader)
2✔
1971
                        if err != nil {
2✔
1972
                                return err
×
1973
                        }
×
1974

1975
                        var settleDate time.Time
2✔
1976
                        err = settleDate.UnmarshalBinary(settleDateBytes)
2✔
1977
                        if err != nil {
2✔
1978
                                return err
×
1979
                        }
×
1980

1981
                        (*v)[setID] = invpkg.InvoiceStateAMP{
2✔
1982
                                State:       invpkg.HtlcState(htlcState),
2✔
1983
                                SettleIndex: settleIndex,
2✔
1984
                                SettleDate:  settleDate,
2✔
1985
                                InvoiceKeys: invoiceKeys,
2✔
1986
                                AmtPaid:     lnwire.MilliSatoshi(amtPaid),
2✔
1987
                        }
2✔
1988
                }
1989

1990
                return nil
2✔
1991
        }
1992

1993
        return tlv.NewTypeForDecodingErr(
×
1994
                val, "channeldb.AMPInvoiceState", l, l,
×
1995
        )
×
1996
}
1997

1998
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
1999
// as a map.
2000
func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
2001
        error) {
2✔
2002

2✔
2003
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
2✔
2004
        for {
4✔
2005
                // Read the length of the tlv stream for this htlc.
2✔
2006
                var streamLen int64
2✔
2007
                if err := binary.Read(r, byteOrder, &streamLen); err != nil {
4✔
2008
                        if err == io.EOF {
4✔
2009
                                break
2✔
2010
                        }
2011

2012
                        return nil, err
×
2013
                }
2014

2015
                // Limit the reader so that it stops at the end of this htlc's
2016
                // stream.
2017
                htlcReader := io.LimitReader(r, streamLen)
2✔
2018

2✔
2019
                // Decode the contents into the htlc fields.
2✔
2020
                var (
2✔
2021
                        htlc                    invpkg.InvoiceHTLC
2✔
2022
                        key                     models.CircuitKey
2✔
2023
                        chanID                  uint64
2✔
2024
                        state                   uint8
2✔
2025
                        acceptTime, resolveTime uint64
2✔
2026
                        amt, mppTotalAmt        uint64
2✔
2027
                        amp                     = &record.AMP{}
2✔
2028
                        hash32                  = &[32]byte{}
2✔
2029
                        preimage32              = &[32]byte{}
2✔
2030
                )
2✔
2031
                tlvStream, err := tlv.NewStream(
2✔
2032
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
2✔
2033
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
2✔
2034
                        tlv.MakePrimitiveRecord(amtType, &amt),
2✔
2035
                        tlv.MakePrimitiveRecord(
2✔
2036
                                acceptHeightType, &htlc.AcceptHeight,
2✔
2037
                        ),
2✔
2038
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
2✔
2039
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
2✔
2040
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
2✔
2041
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
2✔
2042
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
2✔
2043
                        tlv.MakeDynamicRecord(
2✔
2044
                                htlcAMPType, amp, amp.PayloadSize,
2✔
2045
                                record.AMPEncoder, record.AMPDecoder,
2✔
2046
                        ),
2✔
2047
                        tlv.MakePrimitiveRecord(htlcHashType, hash32),
2✔
2048
                        tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
2✔
2049
                )
2✔
2050
                if err != nil {
2✔
2051
                        return nil, err
×
2052
                }
×
2053

2054
                parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
2✔
2055
                if err != nil {
2✔
2056
                        return nil, err
×
2057
                }
×
2058

2059
                if _, ok := parsedTypes[htlcAMPType]; !ok {
4✔
2060
                        amp = nil
2✔
2061
                }
2✔
2062

2063
                var preimage *lntypes.Preimage
2✔
2064
                if _, ok := parsedTypes[htlcPreimageType]; ok {
4✔
2065
                        pimg := lntypes.Preimage(*preimage32)
2✔
2066
                        preimage = &pimg
2✔
2067
                }
2✔
2068

2069
                var hash *lntypes.Hash
2✔
2070
                if _, ok := parsedTypes[htlcHashType]; ok {
4✔
2071
                        h := lntypes.Hash(*hash32)
2✔
2072
                        hash = &h
2✔
2073
                }
2✔
2074

2075
                key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
2✔
2076
                htlc.AcceptTime = getNanoTime(acceptTime)
2✔
2077
                htlc.ResolveTime = getNanoTime(resolveTime)
2✔
2078
                htlc.State = invpkg.HtlcState(state)
2✔
2079
                htlc.Amt = lnwire.MilliSatoshi(amt)
2✔
2080
                htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
2✔
2081
                if amp != nil && hash != nil {
4✔
2082
                        htlc.AMP = &invpkg.InvoiceHtlcAMPData{
2✔
2083
                                Record:   *amp,
2✔
2084
                                Hash:     *hash,
2✔
2085
                                Preimage: preimage,
2✔
2086
                        }
2✔
2087
                }
2✔
2088

2089
                // Reconstruct the custom records fields from the parsed types
2090
                // map return from the tlv parser.
2091
                htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
2✔
2092

2✔
2093
                htlcs[key] = &htlc
2✔
2094
        }
2095

2096
        return htlcs, nil
2✔
2097
}
2098

2099
// invoiceSetIDKeyLen is the size of the composite key under which the HTLCs of
// a single AMP set are stored alongside their parent invoice: 4 bytes for the
// invoice number followed by 32 bytes for the set ID.
const invoiceSetIDKeyLen = 4 + 32

// makeInvoiceSetIDKey builds the composite key invoiceNum || setID under which
// the HTLCs belonging to setID are stored for the given invoice number.
func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte {
	var key [invoiceSetIDKeyLen]byte

	// The invoice number goes first; the set ID is placed immediately after
	// it. copy stops at the end of the array, so anything that doesn't fit
	// is dropped.
	offset := len(invoiceNum)
	copy(key[:], invoiceNum)
	copy(key[offset:], setID)

	return key
}
2✔
2116

2117
// delAMPInvoices attempts to delete all the "sub" invoices associated with a
2118
// greater AMP invoices. We do this by deleting the set of keys that share the
2119
// invoice number as a prefix.
UNCOV
2120
func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error {
×
UNCOV
2121
        // Since it isn't safe to delete using an active cursor, we'll use the
×
UNCOV
2122
        // cursor simply to collect the set of keys we need to delete, _then_
×
UNCOV
2123
        // delete them in another pass.
×
UNCOV
2124
        var keysToDel [][]byte
×
UNCOV
2125
        err := forEachAMPInvoice(
×
UNCOV
2126
                invoiceBucket, invoiceNum,
×
UNCOV
2127
                func(cursorKey, v []byte) error {
×
UNCOV
2128
                        keysToDel = append(keysToDel, cursorKey)
×
UNCOV
2129
                        return nil
×
UNCOV
2130
                },
×
2131
        )
UNCOV
2132
        if err != nil {
×
2133
                return err
×
2134
        }
×
2135

2136
        // In this next phase, we'll then delete all the relevant invoices.
UNCOV
2137
        for _, keyToDel := range keysToDel {
×
UNCOV
2138
                if err := invoiceBucket.Delete(keyToDel); err != nil {
×
2139
                        return err
×
2140
                }
×
2141
        }
2142

UNCOV
2143
        return nil
×
2144
}
2145

2146
// delAMPSettleIndex removes all the entries in the settle index associated
2147
// with a given AMP invoice.
2148
func delAMPSettleIndex(invoiceNum []byte, invoices,
UNCOV
2149
        settleIndex kvdb.RwBucket) error {
×
UNCOV
2150

×
UNCOV
2151
        // First, we need to grab the AMP invoice state to see if there's
×
UNCOV
2152
        // anything that we even need to delete.
×
UNCOV
2153
        ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices)
×
UNCOV
2154
        if err != nil {
×
2155
                return err
×
2156
        }
×
2157

2158
        // If there's no AMP state at all (non-AMP invoice), then we can return
2159
        // early.
UNCOV
2160
        if len(ampState) == 0 {
×
UNCOV
2161
                return nil
×
UNCOV
2162
        }
×
2163

2164
        // Otherwise, we'll need to iterate and delete each settle index within
2165
        // the set of returned entries.
UNCOV
2166
        var settleIndexKey [8]byte
×
UNCOV
2167
        for _, subState := range ampState {
×
UNCOV
2168
                byteOrder.PutUint64(
×
UNCOV
2169
                        settleIndexKey[:], subState.SettleIndex,
×
UNCOV
2170
                )
×
UNCOV
2171

×
UNCOV
2172
                if err := settleIndex.Delete(settleIndexKey[:]); err != nil {
×
2173
                        return err
×
2174
                }
×
2175
        }
2176

UNCOV
2177
        return nil
×
2178
}
2179

2180
// DeleteCanceledInvoices deletes all canceled invoices from the database.
UNCOV
2181
func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
×
UNCOV
2182
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
×
UNCOV
2183
                invoices := tx.ReadWriteBucket(invoiceBucket)
×
UNCOV
2184
                if invoices == nil {
×
2185
                        return nil
×
2186
                }
×
2187

UNCOV
2188
                invoiceIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2189
                        invoiceIndexBucket,
×
UNCOV
2190
                )
×
UNCOV
2191
                if invoiceIndex == nil {
×
UNCOV
2192
                        return nil
×
UNCOV
2193
                }
×
2194

UNCOV
2195
                invoiceAddIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2196
                        addIndexBucket,
×
UNCOV
2197
                )
×
UNCOV
2198
                if invoiceAddIndex == nil {
×
2199
                        return nil
×
2200
                }
×
2201

UNCOV
2202
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
×
UNCOV
2203

×
UNCOV
2204
                return invoiceIndex.ForEach(func(k, v []byte) error {
×
UNCOV
2205
                        // Skip the special numInvoicesKey as that does not
×
UNCOV
2206
                        // point to a valid invoice.
×
UNCOV
2207
                        if bytes.Equal(k, numInvoicesKey) {
×
UNCOV
2208
                                return nil
×
UNCOV
2209
                        }
×
2210

2211
                        // Skip sub-buckets.
UNCOV
2212
                        if v == nil {
×
2213
                                return nil
×
2214
                        }
×
2215

UNCOV
2216
                        invoice, err := fetchInvoice(v, invoices, nil, false)
×
UNCOV
2217
                        if err != nil {
×
2218
                                return err
×
2219
                        }
×
2220

UNCOV
2221
                        if invoice.State != invpkg.ContractCanceled {
×
UNCOV
2222
                                return nil
×
UNCOV
2223
                        }
×
2224

2225
                        // Delete the payment hash from the invoice index.
UNCOV
2226
                        err = invoiceIndex.Delete(k)
×
UNCOV
2227
                        if err != nil {
×
2228
                                return err
×
2229
                        }
×
2230

2231
                        // Delete payment address index reference if there's a
2232
                        // valid payment address.
UNCOV
2233
                        if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
×
UNCOV
2234
                                // To ensure consistency check that the already
×
UNCOV
2235
                                // fetched invoice key matches the one in the
×
UNCOV
2236
                                // payment address index.
×
UNCOV
2237
                                key := payAddrIndex.Get(
×
UNCOV
2238
                                        invoice.Terms.PaymentAddr[:],
×
UNCOV
2239
                                )
×
UNCOV
2240
                                if bytes.Equal(key, k) {
×
2241
                                        // Delete from the payment address
×
2242
                                        // index.
×
2243
                                        if err := payAddrIndex.Delete(
×
2244
                                                invoice.Terms.PaymentAddr[:],
×
2245
                                        ); err != nil {
×
2246
                                                return err
×
2247
                                        }
×
2248
                                }
2249
                        }
2250

2251
                        // Remove from the add index.
UNCOV
2252
                        var addIndexKey [8]byte
×
UNCOV
2253
                        byteOrder.PutUint64(addIndexKey[:], invoice.AddIndex)
×
UNCOV
2254
                        err = invoiceAddIndex.Delete(addIndexKey[:])
×
UNCOV
2255
                        if err != nil {
×
2256
                                return err
×
2257
                        }
×
2258

2259
                        // Note that we don't need to delete the invoice from
2260
                        // the settle index as it is not added until the
2261
                        // invoice is settled.
2262

2263
                        // Now remove all sub invoices.
UNCOV
2264
                        err = delAMPInvoices(k, invoices)
×
UNCOV
2265
                        if err != nil {
×
2266
                                return err
×
2267
                        }
×
2268

2269
                        // Finally remove the serialized invoice from the
2270
                        // invoice bucket.
UNCOV
2271
                        return invoices.Delete(k)
×
2272
                })
UNCOV
2273
        }, func() {})
×
2274
}
2275

2276
// DeleteInvoice attempts to delete the passed invoices from the database in
2277
// one transaction. The passed delete references hold all keys required to
2278
// delete the invoices without also needing to deserialize them.
2279
func (d *DB) DeleteInvoice(_ context.Context,
UNCOV
2280
        invoicesToDelete []invpkg.InvoiceDeleteRef) error {
×
UNCOV
2281

×
UNCOV
2282
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
×
UNCOV
2283
                invoices := tx.ReadWriteBucket(invoiceBucket)
×
UNCOV
2284
                if invoices == nil {
×
2285
                        return invpkg.ErrNoInvoicesCreated
×
2286
                }
×
2287

UNCOV
2288
                invoiceIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2289
                        invoiceIndexBucket,
×
UNCOV
2290
                )
×
UNCOV
2291
                if invoiceIndex == nil {
×
2292
                        return invpkg.ErrNoInvoicesCreated
×
2293
                }
×
2294

UNCOV
2295
                invoiceAddIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2296
                        addIndexBucket,
×
UNCOV
2297
                )
×
UNCOV
2298
                if invoiceAddIndex == nil {
×
2299
                        return invpkg.ErrNoInvoicesCreated
×
2300
                }
×
2301

2302
                // settleIndex can be nil, as the bucket is created lazily
2303
                // when the first invoice is settled.
UNCOV
2304
                settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket)
×
UNCOV
2305

×
UNCOV
2306
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
×
UNCOV
2307

×
UNCOV
2308
                for _, ref := range invoicesToDelete {
×
UNCOV
2309
                        // Fetch the invoice key for using it to check for
×
UNCOV
2310
                        // consistency and also to delete from the invoice
×
UNCOV
2311
                        // index.
×
UNCOV
2312
                        invoiceKey := invoiceIndex.Get(ref.PayHash[:])
×
UNCOV
2313
                        if invoiceKey == nil {
×
UNCOV
2314
                                return invpkg.ErrInvoiceNotFound
×
UNCOV
2315
                        }
×
2316

UNCOV
2317
                        err := invoiceIndex.Delete(ref.PayHash[:])
×
UNCOV
2318
                        if err != nil {
×
2319
                                return err
×
2320
                        }
×
2321

2322
                        // Delete payment address index reference if there's a
2323
                        // valid payment address passed.
UNCOV
2324
                        if ref.PayAddr != nil {
×
UNCOV
2325
                                // To ensure consistency check that the already
×
UNCOV
2326
                                // fetched invoice key matches the one in the
×
UNCOV
2327
                                // payment address index.
×
UNCOV
2328
                                key := payAddrIndex.Get(ref.PayAddr[:])
×
UNCOV
2329
                                if bytes.Equal(key, invoiceKey) {
×
UNCOV
2330
                                        // Delete from the payment address
×
UNCOV
2331
                                        // index. Note that since the payment
×
UNCOV
2332
                                        // address index has been introduced
×
UNCOV
2333
                                        // with an empty migration it may be
×
UNCOV
2334
                                        // possible that the index doesn't have
×
UNCOV
2335
                                        // an entry for this invoice.
×
UNCOV
2336
                                        // ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127
×
UNCOV
2337
                                        if err := payAddrIndex.Delete(
×
UNCOV
2338
                                                ref.PayAddr[:],
×
UNCOV
2339
                                        ); err != nil {
×
2340
                                                return err
×
2341
                                        }
×
2342
                                }
2343
                        }
2344

UNCOV
2345
                        var addIndexKey [8]byte
×
UNCOV
2346
                        byteOrder.PutUint64(addIndexKey[:], ref.AddIndex)
×
UNCOV
2347

×
UNCOV
2348
                        // To ensure consistency check that the key stored in
×
UNCOV
2349
                        // the add index also matches the previously fetched
×
UNCOV
2350
                        // invoice key.
×
UNCOV
2351
                        key := invoiceAddIndex.Get(addIndexKey[:])
×
UNCOV
2352
                        if !bytes.Equal(key, invoiceKey) {
×
UNCOV
2353
                                return fmt.Errorf("unknown invoice " +
×
UNCOV
2354
                                        "in add index")
×
UNCOV
2355
                        }
×
2356

2357
                        // Remove from the add index.
UNCOV
2358
                        err = invoiceAddIndex.Delete(addIndexKey[:])
×
UNCOV
2359
                        if err != nil {
×
2360
                                return err
×
2361
                        }
×
2362

2363
                        // Remove from the settle index if available and
2364
                        // if the invoice is settled.
UNCOV
2365
                        if settleIndex != nil && ref.SettleIndex > 0 {
×
UNCOV
2366
                                var settleIndexKey [8]byte
×
UNCOV
2367
                                byteOrder.PutUint64(
×
UNCOV
2368
                                        settleIndexKey[:], ref.SettleIndex,
×
UNCOV
2369
                                )
×
UNCOV
2370

×
UNCOV
2371
                                // To ensure consistency check that the already
×
UNCOV
2372
                                // fetched invoice key matches the one in the
×
UNCOV
2373
                                // settle index
×
UNCOV
2374
                                key := settleIndex.Get(settleIndexKey[:])
×
UNCOV
2375
                                if !bytes.Equal(key, invoiceKey) {
×
UNCOV
2376
                                        return fmt.Errorf("unknown invoice " +
×
UNCOV
2377
                                                "in settle index")
×
UNCOV
2378
                                }
×
2379

UNCOV
2380
                                err = settleIndex.Delete(settleIndexKey[:])
×
UNCOV
2381
                                if err != nil {
×
2382
                                        return err
×
2383
                                }
×
2384
                        }
2385

2386
                        // In addition to deleting the main invoice state, if
2387
                        // this is an AMP invoice, then we'll also need to
2388
                        // delete the set HTLC set stored as a key prefix. For
2389
                        // non-AMP invoices, this'll be a noop.
UNCOV
2390
                        err = delAMPSettleIndex(
×
UNCOV
2391
                                invoiceKey, invoices, settleIndex,
×
UNCOV
2392
                        )
×
UNCOV
2393
                        if err != nil {
×
2394
                                return err
×
2395
                        }
×
UNCOV
2396
                        err = delAMPInvoices(invoiceKey, invoices)
×
UNCOV
2397
                        if err != nil {
×
2398
                                return err
×
2399
                        }
×
2400

2401
                        // Finally remove the serialized invoice from the
2402
                        // invoice bucket.
UNCOV
2403
                        err = invoices.Delete(invoiceKey)
×
UNCOV
2404
                        if err != nil {
×
2405
                                return err
×
2406
                        }
×
2407
                }
2408

UNCOV
2409
                return nil
×
UNCOV
2410
        }, func() {})
×
2411

UNCOV
2412
        return err
×
2413
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc