• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 14880107727

07 May 2025 09:35AM UTC coverage: 58.446% (-10.5%) from 68.992%
14880107727

Pull #9789

github

web-flow
Merge ed3471042 into 67a40c90a
Pull Request #9789: multi: use updated TLV SizeFunc signature

54 of 95 new or added lines in 19 files covered. (56.84%)

28462 existing lines in 449 files now uncovered.

97165 of 166247 relevant lines covered (58.45%)

1.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.51
/channeldb/invoices.go
1
package channeldb
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/binary"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "maps"
11
        "time"
12

13
        "github.com/lightningnetwork/lnd/graph/db/models"
14
        "github.com/lightningnetwork/lnd/htlcswitch/hop"
15
        invpkg "github.com/lightningnetwork/lnd/invoices"
16
        "github.com/lightningnetwork/lnd/kvdb"
17
        "github.com/lightningnetwork/lnd/lntypes"
18
        "github.com/lightningnetwork/lnd/lnwire"
19
        "github.com/lightningnetwork/lnd/record"
20
        "github.com/lightningnetwork/lnd/tlv"
21
)
22

23
var (
	// invoiceBucket is the name of the bucket within the database that
	// stores all data related to invoices no matter their final state.
	// Within the invoice bucket, each invoice is keyed by its invoice ID
	// which is a monotonically increasing uint32.
	invoiceBucket = []byte("invoices")

	// invoiceIndexBucket is the name of the sub-bucket within the
	// invoiceBucket which indexes all invoices by their payment hash. The
	// payment hash is the sha256 of the invoice's payment preimage. This
	// index is used to detect duplicates, and also to provide a fast path
	// for looking up incoming HTLCs to determine if we're able to settle
	// them fully.
	//
	// maps: payHash => invoiceKey
	invoiceIndexBucket = []byte("paymenthashes")

	// payAddrIndexBucket is the name of the top-level bucket that maps
	// payment addresses to their invoice number. This can be used
	// to efficiently query or update non-legacy invoices. Note that legacy
	// invoices will not be included in this index since they all have the
	// same, all-zero payment address, however all newly generated invoices
	// will end up in this index.
	//
	// maps: payAddr => invoiceKey
	payAddrIndexBucket = []byte("pay-addr-index")

	// setIDIndexBucket is the name of the top-level bucket that maps set
	// ids to their invoice number. This can be used to efficiently query or
	// update AMP invoices. Note that legacy or MPP invoices will not be
	// included in this index, since their HTLCs do not have a set id.
	//
	// maps: setID => invoiceKey
	setIDIndexBucket = []byte("set-id-index")

	// numInvoicesKey is the name of key which houses the auto-incrementing
	// invoice ID which is essentially used as a primary key. With each
	// invoice inserted, the primary key is incremented by one. This key is
	// stored within the invoiceIndexBucket. Within the invoiceBucket
	// invoices are uniquely identified by the invoice ID.
	numInvoicesKey = []byte("nik")

	// addIndexBucket is an index bucket that we'll use to create a
	// monotonically increasing set of add indexes. Each time we add a new
	// invoice, this sequence number will be incremented and then populated
	// within the new invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   addIndexNo => invoiceKey
	addIndexBucket = []byte("invoice-add-index")

	// settleIndexBucket is an index bucket that we'll use to create a
	// monotonically increasing integer for tracking a "settle index". Each
	// time an invoice is settled, this sequence number will be incremented
	// and then populated within the newly settled invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   settleIndexNo => invoiceKey
	settleIndexBucket = []byte("invoice-settle-index")

	// invoiceBucketTombstone is a special key that indicates the invoice
	// bucket has been permanently closed. Its purpose is to prevent the
	// invoice bucket from being reopened in the future. A key use case for
	// the tombstone is to ensure users cannot switch back to the KV invoice
	// database after migrating to the native SQL database.
	invoiceBucketTombstone = []byte("invoice-tombstone")
)
92

93
const (
94
        // A set of tlv type definitions used to serialize invoice htlcs to the
95
        // database.
96
        //
97
        // NOTE: A migration should be added whenever this list changes. This
98
        // prevents against the database being rolled back to an older
99
        // format where the surrounding logic might assume a different set of
100
        // fields are known.
101
        chanIDType       tlv.Type = 1
102
        htlcIDType       tlv.Type = 3
103
        amtType          tlv.Type = 5
104
        acceptHeightType tlv.Type = 7
105
        acceptTimeType   tlv.Type = 9
106
        resolveTimeType  tlv.Type = 11
107
        expiryHeightType tlv.Type = 13
108
        htlcStateType    tlv.Type = 15
109
        mppTotalAmtType  tlv.Type = 17
110
        htlcAMPType      tlv.Type = 19
111
        htlcHashType     tlv.Type = 21
112
        htlcPreimageType tlv.Type = 23
113

114
        // A set of tlv type definitions used to serialize invoice bodiees.
115
        //
116
        // NOTE: A migration should be added whenever this list changes. This
117
        // prevents against the database being rolled back to an older
118
        // format where the surrounding logic might assume a different set of
119
        // fields are known.
120
        memoType            tlv.Type = 0
121
        payReqType          tlv.Type = 1
122
        createTimeType      tlv.Type = 2
123
        settleTimeType      tlv.Type = 3
124
        addIndexType        tlv.Type = 4
125
        settleIndexType     tlv.Type = 5
126
        preimageType        tlv.Type = 6
127
        valueType           tlv.Type = 7
128
        cltvDeltaType       tlv.Type = 8
129
        expiryType          tlv.Type = 9
130
        paymentAddrType     tlv.Type = 10
131
        featuresType        tlv.Type = 11
132
        invStateType        tlv.Type = 12
133
        amtPaidType         tlv.Type = 13
134
        hodlInvoiceType     tlv.Type = 14
135
        invoiceAmpStateType tlv.Type = 15
136

137
        // A set of tlv type definitions used to serialize the invoice AMP
138
        // state along-side the main invoice body.
139
        ampStateSetIDType       tlv.Type = 0
140
        ampStateHtlcStateType   tlv.Type = 1
141
        ampStateSettleIndexType tlv.Type = 2
142
        ampStateSettleDateType  tlv.Type = 3
143
        ampStateCircuitKeysType tlv.Type = 4
144
        ampStateAmtPaidType     tlv.Type = 5
145

146
        // invoiceProgressLogInterval is the interval we use limiting the
147
        // logging output of invoice processing.
148
        invoiceProgressLogInterval = 30 * time.Second
149
)
150

151
// AddInvoice inserts the targeted invoice into the database. If the invoice has
152
// *any* payment hashes which already exists within the database, then the
153
// insertion will be aborted and rejected due to the strict policy banning any
154
// duplicate payment hashes. A side effect of this function is that it sets
155
// AddIndex on newInvoice.
156
func (d *DB) AddInvoice(_ context.Context, newInvoice *invpkg.Invoice,
157
        paymentHash lntypes.Hash) (uint64, error) {
3✔
158

3✔
159
        if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
3✔
UNCOV
160
                return 0, err
×
UNCOV
161
        }
×
162

163
        var invoiceAddIndex uint64
3✔
164
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
6✔
165
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
3✔
166
                if err != nil {
3✔
167
                        return err
×
168
                }
×
169

170
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
3✔
171
                        invoiceIndexBucket,
3✔
172
                )
3✔
173
                if err != nil {
3✔
174
                        return err
×
175
                }
×
176
                addIndex, err := invoices.CreateBucketIfNotExists(
3✔
177
                        addIndexBucket,
3✔
178
                )
3✔
179
                if err != nil {
3✔
180
                        return err
×
181
                }
×
182

183
                // Ensure that an invoice an identical payment hash doesn't
184
                // already exist within the index.
185
                if invoiceIndex.Get(paymentHash[:]) != nil {
3✔
UNCOV
186
                        return invpkg.ErrDuplicateInvoice
×
UNCOV
187
                }
×
188

189
                // Check that we aren't inserting an invoice with a duplicate
190
                // payment address. The all-zeros payment address is
191
                // special-cased to support legacy keysend invoices which don't
192
                // assign one. This is safe since later we also will avoid
193
                // indexing them and avoid collisions.
194
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
3✔
195
                if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
6✔
196
                        paymentAddr := newInvoice.Terms.PaymentAddr[:]
3✔
197
                        if payAddrIndex.Get(paymentAddr) != nil {
6✔
198
                                return invpkg.ErrDuplicatePayAddr
3✔
199
                        }
3✔
200
                }
201

202
                // If the current running payment ID counter hasn't yet been
203
                // created, then create it now.
204
                var invoiceNum uint32
3✔
205
                invoiceCounter := invoiceIndex.Get(numInvoicesKey)
3✔
206
                if invoiceCounter == nil {
6✔
207
                        var scratch [4]byte
3✔
208
                        byteOrder.PutUint32(scratch[:], invoiceNum)
3✔
209
                        err := invoiceIndex.Put(numInvoicesKey, scratch[:])
3✔
210
                        if err != nil {
3✔
211
                                return err
×
212
                        }
×
213
                } else {
3✔
214
                        invoiceNum = byteOrder.Uint32(invoiceCounter)
3✔
215
                }
3✔
216

217
                newIndex, err := putInvoice(
3✔
218
                        invoices, invoiceIndex, payAddrIndex, addIndex,
3✔
219
                        newInvoice, invoiceNum, paymentHash,
3✔
220
                )
3✔
221
                if err != nil {
3✔
222
                        return err
×
223
                }
×
224

225
                invoiceAddIndex = newIndex
3✔
226
                return nil
3✔
227
        }, func() {
3✔
228
                invoiceAddIndex = 0
3✔
229
        })
3✔
230
        if err != nil {
6✔
231
                return 0, err
3✔
232
        }
3✔
233

234
        return invoiceAddIndex, err
3✔
235
}
236

237
// InvoicesAddedSince can be used by callers to seek into the event time series
238
// of all the invoices added in the database. The specified sinceAddIndex
239
// should be the highest add index that the caller knows of. This method will
240
// return all invoices with an add index greater than the specified
241
// sinceAddIndex.
242
//
243
// NOTE: The index starts from 1, as a result. We enforce that specifying a
244
// value below the starting index value is a noop.
245
func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
246
        []invpkg.Invoice, error) {
3✔
247

3✔
248
        var (
3✔
249
                newInvoices    []invpkg.Invoice
3✔
250
                start          = time.Now()
3✔
251
                lastLogTime    = time.Now()
3✔
252
                processedCount int
3✔
253
        )
3✔
254

3✔
255
        // If an index of zero was specified, then in order to maintain
3✔
256
        // backwards compat, we won't send out any new invoices.
3✔
257
        if sinceAddIndex == 0 {
6✔
258
                return newInvoices, nil
3✔
259
        }
3✔
260

261
        var startIndex [8]byte
3✔
262
        byteOrder.PutUint64(startIndex[:], sinceAddIndex)
3✔
263

3✔
264
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
265
                invoices := tx.ReadBucket(invoiceBucket)
3✔
266
                if invoices == nil {
3✔
267
                        return nil
×
268
                }
×
269

270
                addIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
271
                if addIndex == nil {
3✔
272
                        return nil
×
273
                }
×
274

275
                // We'll now run through each entry in the add index starting
276
                // at our starting index. We'll continue until we reach the
277
                // very end of the current key space.
278
                invoiceCursor := addIndex.ReadCursor()
3✔
279

3✔
280
                // We'll seek to the starting index, then manually advance the
3✔
281
                // cursor in order to skip the entry with the since add index.
3✔
282
                invoiceCursor.Seek(startIndex[:])
3✔
283
                addSeqNo, invoiceKey := invoiceCursor.Next()
3✔
284

3✔
285
                for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
6✔
286
                        // For each key found, we'll look up the actual
3✔
287
                        // invoice, then accumulate it into our return value.
3✔
288
                        invoice, err := fetchInvoice(
3✔
289
                                invoiceKey, invoices, nil, false,
3✔
290
                        )
3✔
291
                        if err != nil {
3✔
292
                                return err
×
293
                        }
×
294

295
                        newInvoices = append(newInvoices, invoice)
3✔
296

3✔
297
                        processedCount++
3✔
298
                        if time.Since(lastLogTime) >=
3✔
299
                                invoiceProgressLogInterval {
3✔
300

×
301
                                log.Debugf("Processed %d invoices which "+
×
302
                                        "were added since add index %v",
×
303
                                        processedCount, sinceAddIndex)
×
304

×
305
                                lastLogTime = time.Now()
×
306
                        }
×
307
                }
308

309
                return nil
3✔
310
        }, func() {
3✔
311
                newInvoices = nil
3✔
312
        })
3✔
313
        if err != nil {
3✔
314
                return nil, err
×
315
        }
×
316

317
        elapsed := time.Since(start)
3✔
318
        log.Debugf("Completed scanning for invoices added since index %v: "+
3✔
319
                "total_processed=%d, found_invoices=%d, elapsed=%v",
3✔
320
                sinceAddIndex, processedCount, len(newInvoices),
3✔
321
                elapsed.Round(time.Millisecond))
3✔
322

3✔
323
        return newInvoices, nil
3✔
324
}
325

326
// LookupInvoice attempts to look up an invoice according to its 32 byte
327
// payment hash. If an invoice which can settle the HTLC identified by the
328
// passed payment hash isn't found, then an error is returned. Otherwise, the
329
// full invoice is returned. Before setting the incoming HTLC, the values
330
// SHOULD be checked to ensure the payer meets the agreed upon contractual
331
// terms of the payment.
332
func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
333
        invpkg.Invoice, error) {
3✔
334

3✔
335
        var invoice invpkg.Invoice
3✔
336
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
337
                invoices := tx.ReadBucket(invoiceBucket)
3✔
338
                if invoices == nil {
3✔
339
                        return invpkg.ErrNoInvoicesCreated
×
340
                }
×
341
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
3✔
342
                if invoiceIndex == nil {
6✔
343
                        return invpkg.ErrNoInvoicesCreated
3✔
344
                }
3✔
345
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
3✔
346
                setIDIndex := tx.ReadBucket(setIDIndexBucket)
3✔
347

3✔
348
                // Retrieve the invoice number for this invoice using
3✔
349
                // the provided invoice reference.
3✔
350
                invoiceNum, err := fetchInvoiceNumByRef(
3✔
351
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
3✔
352
                )
3✔
353
                if err != nil {
6✔
354
                        return err
3✔
355
                }
3✔
356

357
                var setID *invpkg.SetID
3✔
358
                switch {
3✔
359
                // If this is a payment address ref, and the blank modified was
360
                // specified, then we'll use the zero set ID to indicate that
361
                // we won't want any HTLCs returned.
362
                case ref.PayAddr() != nil &&
363
                        ref.Modifier() == invpkg.HtlcSetBlankModifier:
3✔
364

3✔
365
                        var zeroSetID invpkg.SetID
3✔
366
                        setID = &zeroSetID
3✔
367

368
                // If this is a set ID ref, and the htlc set only modified was
369
                // specified, then we'll pass through the specified setID so
370
                // only that will be returned.
371
                case ref.SetID() != nil &&
372
                        ref.Modifier() == invpkg.HtlcSetOnlyModifier:
3✔
373

3✔
374
                        setID = (*invpkg.SetID)(ref.SetID())
3✔
375
                }
376

377
                // An invoice was found, retrieve the remainder of the invoice
378
                // body.
379
                i, err := fetchInvoice(
3✔
380
                        invoiceNum, invoices, []*invpkg.SetID{setID}, true,
3✔
381
                )
3✔
382
                if err != nil {
3✔
383
                        return err
×
384
                }
×
385
                invoice = i
3✔
386

3✔
387
                return nil
3✔
388
        }, func() {})
3✔
389
        if err != nil {
6✔
390
                return invoice, err
3✔
391
        }
3✔
392

393
        return invoice, nil
3✔
394
}
395

396
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
397
// reference. The payment address will be treated as the primary key, falling
398
// back to the payment hash if nothing is found for the payment address. An
399
// error is returned if the invoice is not found.
400
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
401
        ref invpkg.InvoiceRef) ([]byte, error) {
3✔
402

3✔
403
        // If the set id is present, we only consult the set id index for this
3✔
404
        // invoice. This type of query is only used to facilitate user-facing
3✔
405
        // requests to lookup, settle or cancel an AMP invoice.
3✔
406
        setID := ref.SetID()
3✔
407
        if setID != nil {
6✔
408
                invoiceNumBySetID := setIDIndex.Get(setID[:])
3✔
409
                if invoiceNumBySetID == nil {
3✔
UNCOV
410
                        return nil, invpkg.ErrInvoiceNotFound
×
UNCOV
411
                }
×
412

413
                return invoiceNumBySetID, nil
3✔
414
        }
415

416
        payHash := ref.PayHash()
3✔
417
        payAddr := ref.PayAddr()
3✔
418

3✔
419
        getInvoiceNumByHash := func() []byte {
6✔
420
                if payHash != nil {
6✔
421
                        return invoiceIndex.Get(payHash[:])
3✔
422
                }
3✔
423
                return nil
3✔
424
        }
425

426
        getInvoiceNumByAddr := func() []byte {
6✔
427
                if payAddr != nil {
6✔
428
                        // Only allow lookups for payment address if it is not a
3✔
429
                        // blank payment address, which is a special-cased value
3✔
430
                        // for legacy keysend invoices.
3✔
431
                        if *payAddr != invpkg.BlankPayAddr {
6✔
432
                                return payAddrIndex.Get(payAddr[:])
3✔
433
                        }
3✔
434
                }
435
                return nil
3✔
436
        }
437

438
        invoiceNumByHash := getInvoiceNumByHash()
3✔
439
        invoiceNumByAddr := getInvoiceNumByAddr()
3✔
440
        switch {
3✔
441
        // If payment address and payment hash both reference an existing
442
        // invoice, ensure they reference the _same_ invoice.
443
        case invoiceNumByAddr != nil && invoiceNumByHash != nil:
3✔
444
                if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
3✔
UNCOV
445
                        return nil, invpkg.ErrInvRefEquivocation
×
UNCOV
446
                }
×
447

448
                return invoiceNumByAddr, nil
3✔
449

450
        // Return invoices by payment addr only.
451
        //
452
        // NOTE: We constrain this lookup to only apply if the invoice ref does
453
        // not contain a payment hash. Legacy and MPP payments depend on the
454
        // payment hash index to enforce that the HTLCs payment hash matches the
455
        // payment hash for the invoice, without this check we would
456
        // inadvertently assume the invoice contains the correct preimage for
457
        // the HTLC, which we only enforce via the lookup by the invoice index.
458
        case invoiceNumByAddr != nil && payHash == nil:
3✔
459
                return invoiceNumByAddr, nil
3✔
460

461
        // If we were only able to reference the invoice by hash, return the
462
        // corresponding invoice number. This can happen when no payment address
463
        // was provided, or if it didn't match anything in our records.
464
        case invoiceNumByHash != nil:
3✔
465
                return invoiceNumByHash, nil
3✔
466

467
        // Otherwise we don't know of the target invoice.
468
        default:
3✔
469
                return nil, invpkg.ErrInvoiceNotFound
3✔
470
        }
471
}
472

473
// FetchPendingInvoices returns all invoices that have not yet been settled or
474
// canceled. The returned map is keyed by the payment hash of each respective
475
// invoice.
476
func (d *DB) FetchPendingInvoices(_ context.Context) (
477
        map[lntypes.Hash]invpkg.Invoice, error) {
3✔
478

3✔
479
        result := make(map[lntypes.Hash]invpkg.Invoice)
3✔
480

3✔
481
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
482
                invoices := tx.ReadBucket(invoiceBucket)
3✔
483
                if invoices == nil {
3✔
484
                        return nil
×
485
                }
×
486

487
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
3✔
488
                if invoiceIndex == nil {
6✔
489
                        // Mask the error if there's no invoice
3✔
490
                        // index as that simply means there are no
3✔
491
                        // invoices added yet to the DB. In this case
3✔
492
                        // we simply return an empty list.
3✔
493
                        return nil
3✔
494
                }
3✔
495

496
                return invoiceIndex.ForEach(func(k, v []byte) error {
6✔
497
                        // Skip the special numInvoicesKey as that does not
3✔
498
                        // point to a valid invoice.
3✔
499
                        if bytes.Equal(k, numInvoicesKey) {
6✔
500
                                return nil
3✔
501
                        }
3✔
502

503
                        // Skip sub-buckets.
504
                        if v == nil {
3✔
505
                                return nil
×
506
                        }
×
507

508
                        invoice, err := fetchInvoice(v, invoices, nil, false)
3✔
509
                        if err != nil {
3✔
510
                                return err
×
511
                        }
×
512

513
                        if invoice.IsPending() {
6✔
514
                                var paymentHash lntypes.Hash
3✔
515
                                copy(paymentHash[:], k)
3✔
516
                                result[paymentHash] = invoice
3✔
517
                        }
3✔
518

519
                        return nil
3✔
520
                })
521
        }, func() {
3✔
522
                result = make(map[lntypes.Hash]invpkg.Invoice)
3✔
523
        })
3✔
524

525
        if err != nil {
3✔
526
                return nil, err
×
527
        }
×
528

529
        return result, nil
3✔
530
}
531

532
// QueryInvoices allows a caller to query the invoice database for invoices
533
// within the specified add index range.
534
func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
535
        invpkg.InvoiceSlice, error) {
3✔
536

3✔
537
        var resp invpkg.InvoiceSlice
3✔
538

3✔
539
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
540
                // If the bucket wasn't found, then there aren't any invoices
3✔
541
                // within the database yet, so we can simply exit.
3✔
542
                invoices := tx.ReadBucket(invoiceBucket)
3✔
543
                if invoices == nil {
3✔
544
                        return invpkg.ErrNoInvoicesCreated
×
545
                }
×
546

547
                // Get the add index bucket which we will use to iterate through
548
                // our indexed invoices.
549
                invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
550
                if invoiceAddIndex == nil {
6✔
551
                        return invpkg.ErrNoInvoicesCreated
3✔
552
                }
3✔
553

554
                // Create a paginator which reads from our add index bucket with
555
                // the parameters provided by the invoice query.
556
                paginator := newPaginator(
3✔
557
                        invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
3✔
558
                        q.NumMaxInvoices,
3✔
559
                )
3✔
560

3✔
561
                // accumulateInvoices looks up an invoice based on the index we
3✔
562
                // are given, adds it to our set of invoices if it has the right
3✔
563
                // characteristics for our query and returns the number of items
3✔
564
                // we have added to our set of invoices.
3✔
565
                accumulateInvoices := func(_, indexValue []byte) (bool, error) {
6✔
566
                        invoice, err := fetchInvoice(
3✔
567
                                indexValue, invoices, nil, false,
3✔
568
                        )
3✔
569
                        if err != nil {
3✔
570
                                return false, err
×
571
                        }
×
572

573
                        // Skip any settled or canceled invoices if the caller
574
                        // is only interested in pending ones.
575
                        if q.PendingOnly && !invoice.IsPending() {
3✔
UNCOV
576
                                return false, nil
×
UNCOV
577
                        }
×
578

579
                        // Get the creation time in Unix seconds, this always
580
                        // rounds down the nanoseconds to full seconds.
581
                        createTime := invoice.CreationDate.Unix()
3✔
582

3✔
583
                        // Skip any invoices that were created before the
3✔
584
                        // specified time.
3✔
585
                        if createTime < q.CreationDateStart {
6✔
586
                                return false, nil
3✔
587
                        }
3✔
588

589
                        // Skip any invoices that were created after the
590
                        // specified time.
591
                        if q.CreationDateEnd != 0 &&
3✔
592
                                createTime > q.CreationDateEnd {
6✔
593

3✔
594
                                return false, nil
3✔
595
                        }
3✔
596

597
                        // At this point, we've exhausted the offset, so we'll
598
                        // begin collecting invoices found within the range.
599
                        resp.Invoices = append(resp.Invoices, invoice)
3✔
600

3✔
601
                        return true, nil
3✔
602
                }
603

604
                // Query our paginator using accumulateInvoices to build up a
605
                // set of invoices.
606
                if err := paginator.query(accumulateInvoices); err != nil {
3✔
607
                        return err
×
608
                }
×
609

610
                // If we iterated through the add index in reverse order, then
611
                // we'll need to reverse the slice of invoices to return them in
612
                // forward order.
613
                if q.Reversed {
3✔
UNCOV
614
                        numInvoices := len(resp.Invoices)
×
UNCOV
615
                        for i := 0; i < numInvoices/2; i++ {
×
UNCOV
616
                                reverse := numInvoices - i - 1
×
UNCOV
617
                                resp.Invoices[i], resp.Invoices[reverse] =
×
UNCOV
618
                                        resp.Invoices[reverse], resp.Invoices[i]
×
UNCOV
619
                        }
×
620
                }
621

622
                return nil
3✔
623
        }, func() {
3✔
624
                resp = invpkg.InvoiceSlice{
3✔
625
                        InvoiceQuery: q,
3✔
626
                }
3✔
627
        })
3✔
628
        if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) {
3✔
629
                return resp, err
×
630
        }
×
631

632
        // Finally, record the indexes of the first and last invoices returned
633
        // so that the caller can resume from this point later on.
634
        if len(resp.Invoices) > 0 {
6✔
635
                resp.FirstIndexOffset = resp.Invoices[0].AddIndex
3✔
636
                lastIdx := len(resp.Invoices) - 1
3✔
637
                resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex
3✔
638
        }
3✔
639

640
        return resp, nil
3✔
641
}
642

643
// UpdateInvoice attempts to update an invoice corresponding to the passed
644
// payment hash. If an invoice matching the passed payment hash doesn't exist
645
// within the database, then the action will fail with a "not found" error.
646
//
647
// The update is performed inside the same database transaction that fetches the
648
// invoice and is therefore atomic. The fields to update are controlled by the
649
// supplied callback.  When updating an invoice, the update itself happens
650
// in-memory on a copy of the invoice. Once it is written successfully to the
651
// database, the in-memory copy is returned to the caller.
652
func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
653
        setIDHint *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
654
        *invpkg.Invoice, error) {
3✔
655

3✔
656
        var updatedInvoice *invpkg.Invoice
3✔
657
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
6✔
658
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
3✔
659
                if err != nil {
3✔
660
                        return err
×
661
                }
×
662
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
3✔
663
                        invoiceIndexBucket,
3✔
664
                )
3✔
665
                if err != nil {
3✔
666
                        return err
×
667
                }
×
668
                settleIndex, err := invoices.CreateBucketIfNotExists(
3✔
669
                        settleIndexBucket,
3✔
670
                )
3✔
671
                if err != nil {
3✔
672
                        return err
×
673
                }
×
674
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
3✔
675
                setIDIndex := tx.ReadWriteBucket(setIDIndexBucket)
3✔
676

3✔
677
                // Retrieve the invoice number for this invoice using the
3✔
678
                // provided invoice reference.
3✔
679
                invoiceNum, err := fetchInvoiceNumByRef(
3✔
680
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
3✔
681
                )
3✔
682
                if err != nil {
3✔
UNCOV
683
                        return err
×
UNCOV
684
                }
×
685

686
                // setIDHint can also be nil here, which means all the HTLCs
687
                // for AMP invoices are fetched. If the blank setID is passed
688
                // in, then no HTLCs are fetched for the AMP invoice. If a
689
                // specific setID is passed in, then only the HTLCs for that
690
                // setID are fetched for a particular sub-AMP invoice.
691
                invoice, err := fetchInvoice(
3✔
692
                        invoiceNum, invoices, []*invpkg.SetID{setIDHint}, false,
3✔
693
                )
3✔
694
                if err != nil {
3✔
695
                        return err
×
696
                }
×
697

698
                now := d.clock.Now()
3✔
699
                updater := &kvInvoiceUpdater{
3✔
700
                        db:                d,
3✔
701
                        invoicesBucket:    invoices,
3✔
702
                        settleIndexBucket: settleIndex,
3✔
703
                        setIDIndexBucket:  setIDIndex,
3✔
704
                        updateTime:        now,
3✔
705
                        invoiceNum:        invoiceNum,
3✔
706
                        invoice:           &invoice,
3✔
707
                        updatedAmpHtlcs:   make(ampHTLCsMap),
3✔
708
                        settledSetIDs:     make(map[invpkg.SetID]struct{}),
3✔
709
                }
3✔
710

3✔
711
                payHash := ref.PayHash()
3✔
712
                updatedInvoice, err = invpkg.UpdateInvoice(
3✔
713
                        payHash, updater.invoice, now, callback, updater,
3✔
714
                )
3✔
715
                if err != nil {
6✔
716
                        return err
3✔
717
                }
3✔
718

719
                // If this is an AMP update, then limit the returned AMP state
720
                // to only the requested set ID.
721
                if setIDHint != nil {
6✔
722
                        filterInvoiceAMPState(updatedInvoice, setIDHint)
3✔
723
                }
3✔
724

725
                return nil
3✔
726
        }, func() {
3✔
727
                updatedInvoice = nil
3✔
728
        })
3✔
729

730
        return updatedInvoice, err
3✔
731
}
732

733
// filterInvoiceAMPState filters the AMP state of the invoice to only include
734
// state for the specified set IDs.
735
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
3✔
736
        filteredAMPState := make(invpkg.AMPInvoiceState)
3✔
737

3✔
738
        for _, setID := range setIDs {
6✔
739
                if setID == nil {
6✔
740
                        return
3✔
741
                }
3✔
742

743
                ampState, ok := invoice.AMPState[*setID]
3✔
744
                if ok {
6✔
745
                        filteredAMPState[*setID] = ampState
3✔
746
                }
3✔
747
        }
748

749
        invoice.AMPState = filteredAMPState
3✔
750
}
751

752
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update. The outer
// map is keyed by the AMP set ID the HTLCs belong to; within a set, HTLCs are
// keyed by their circuit key.
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
754

755
// kvInvoiceUpdater is an implementation of the InvoiceUpdater interface that
// is used with the kv implementation of the invoice database. Note that this
// updater is not concurrency safe and synchronization is expected to be
// handled on the DB level.
//
// UpdateInvoice creates one updater per call, with its buckets opened inside
// that call's read-write transaction.
type kvInvoiceUpdater struct {
	// db is the database being updated. invoicesBucket is the top-level
	// invoices bucket, settleIndexBucket is its settle index sub-bucket
	// and setIDIndexBucket is the top-level AMP set ID index bucket.
	db                *DB
	invoicesBucket    kvdb.RwBucket
	settleIndexBucket kvdb.RwBucket
	setIDIndexBucket  kvdb.RwBucket

	// updateTime is the timestamp for the update. It is recorded as the
	// settle date when the invoice, or one of its AMP sets, is settled.
	updateTime time.Time

	// invoiceNum is a legacy key similar to the add index that is used
	// only in the kv implementation. The invoice is stored under this key
	// in the invoices bucket.
	invoiceNum []byte

	// invoice is the invoice that we're updating. As a side effect of the
	// update this invoice will be mutated.
	invoice *invpkg.Invoice

	// updatedAmpHtlcs holds, per AMP set ID, the HTLCs that were added,
	// cancelled or settled as part of this update and therefore need to
	// be written back to disk.
	updatedAmpHtlcs ampHTLCsMap

	// settledSetIDs holds the set IDs that are settled with this update.
	// Each of them receives its own entry in the settle index.
	settledSetIDs map[invpkg.SetID]struct{}
}
783

784
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
785
func (k *kvInvoiceUpdater) AddHtlc(_ models.CircuitKey,
786
        _ *invpkg.InvoiceHTLC) error {
3✔
787

3✔
788
        return nil
3✔
789
}
3✔
790

791
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
792
func (k *kvInvoiceUpdater) ResolveHtlc(_ models.CircuitKey, _ invpkg.HtlcState,
793
        _ time.Time) error {
3✔
794

3✔
795
        return nil
3✔
796
}
3✔
797

798
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
799
func (k *kvInvoiceUpdater) AddAmpHtlcPreimage(_ [32]byte, _ models.CircuitKey,
800
        _ lntypes.Preimage) error {
3✔
801

3✔
802
        return nil
3✔
803
}
3✔
804

805
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
806
func (k *kvInvoiceUpdater) UpdateInvoiceState(_ invpkg.ContractState,
807
        _ *lntypes.Preimage) error {
3✔
808

3✔
809
        return nil
3✔
810
}
3✔
811

812
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
813
func (k *kvInvoiceUpdater) UpdateInvoiceAmtPaid(_ lnwire.MilliSatoshi) error {
3✔
814
        return nil
3✔
815
}
3✔
816

817
// UpdateAmpState updates the state of the AMP invoice identified by the setID.
818
func (k *kvInvoiceUpdater) UpdateAmpState(setID [32]byte,
819
        state invpkg.InvoiceStateAMP, circuitKey models.CircuitKey) error {
3✔
820

3✔
821
        if _, ok := k.updatedAmpHtlcs[setID]; !ok {
6✔
822
                switch state.State {
3✔
823
                case invpkg.HtlcStateAccepted:
3✔
824
                        // If we're just now creating the HTLCs for this set
3✔
825
                        // then we'll also pull in the existing HTLCs that are
3✔
826
                        // part of this set, so we can write them all to disk
3✔
827
                        // together (same value)
3✔
828
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
3✔
829
                                &setID, invpkg.HtlcStateAccepted,
3✔
830
                        )
3✔
831

UNCOV
832
                case invpkg.HtlcStateCanceled:
×
UNCOV
833
                        // Only HTLCs in the accepted state, can be cancelled,
×
UNCOV
834
                        // but we also want to merge that with HTLCs that may be
×
UNCOV
835
                        // canceled as well since it can be cancelled one by
×
UNCOV
836
                        // one.
×
UNCOV
837
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
×
UNCOV
838
                                &setID, invpkg.HtlcStateAccepted,
×
UNCOV
839
                        )
×
UNCOV
840

×
UNCOV
841
                        cancelledHtlcs := k.invoice.HTLCSet(
×
UNCOV
842
                                &setID, invpkg.HtlcStateCanceled,
×
UNCOV
843
                        )
×
UNCOV
844
                        maps.Copy(k.updatedAmpHtlcs[setID], cancelledHtlcs)
×
845

UNCOV
846
                case invpkg.HtlcStateSettled:
×
UNCOV
847
                        k.updatedAmpHtlcs[setID] = make(
×
UNCOV
848
                                map[models.CircuitKey]*invpkg.InvoiceHTLC,
×
UNCOV
849
                        )
×
850
                }
851
        }
852

853
        if state.State == invpkg.HtlcStateSettled {
6✔
854
                // Add the set ID to the set that was settled in this invoice
3✔
855
                // update. We'll use this later to update the settle index.
3✔
856
                k.settledSetIDs[setID] = struct{}{}
3✔
857
        }
3✔
858

859
        k.updatedAmpHtlcs[setID][circuitKey] = k.invoice.Htlcs[circuitKey]
3✔
860

3✔
861
        return nil
3✔
862
}
863

864
// Finalize finalizes the update before it is written to the database.
865
func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error {
3✔
866
        switch updateType {
3✔
867
        case invpkg.AddHTLCsUpdate:
3✔
868
                return k.storeAddHtlcsUpdate()
3✔
869

870
        case invpkg.CancelHTLCsUpdate:
3✔
871
                return k.storeCancelHtlcsUpdate()
3✔
872

873
        case invpkg.SettleHodlInvoiceUpdate:
3✔
874
                return k.storeSettleHodlInvoiceUpdate()
3✔
875

876
        case invpkg.CancelInvoiceUpdate:
3✔
877
                // Persist all changes which where made when cancelling the
3✔
878
                // invoice. All HTLCs which were accepted are now canceled, so
3✔
879
                // we persist this state.
3✔
880
                return k.storeCancelHtlcsUpdate()
3✔
881
        }
882

883
        return fmt.Errorf("unknown update type: %v", updateType)
×
884
}
885

886
// storeCancelHtlcsUpdate updates the invoice in the database after cancelling a
887
// set of HTLCs.
888
func (k *kvInvoiceUpdater) storeCancelHtlcsUpdate() error {
3✔
889
        err := k.serializeAndStoreInvoice()
3✔
890
        if err != nil {
3✔
891
                return err
×
892
        }
×
893

894
        // If this is an AMP invoice, then we'll actually store the rest
895
        // of the HTLCs in-line with the invoice, using the invoice ID
896
        // as a prefix, and the AMP key as a suffix: invoiceNum ||
897
        // setID.
898
        if k.invoice.IsAMP() {
3✔
UNCOV
899
                return k.updateAMPInvoices()
×
UNCOV
900
        }
×
901

902
        return nil
3✔
903
}
904

905
// storeAddHtlcsUpdate updates the invoice in the database after adding a set of
906
// HTLCs.
907
func (k *kvInvoiceUpdater) storeAddHtlcsUpdate() error {
3✔
908
        invoiceIsAMP := k.invoice.IsAMP()
3✔
909

3✔
910
        for htlcSetID := range k.updatedAmpHtlcs {
6✔
911
                // Check if this SetID already exist.
3✔
912
                setIDInvNum := k.setIDIndexBucket.Get(htlcSetID[:])
3✔
913

3✔
914
                if setIDInvNum == nil {
6✔
915
                        err := k.setIDIndexBucket.Put(
3✔
916
                                htlcSetID[:], k.invoiceNum,
3✔
917
                        )
3✔
918
                        if err != nil {
3✔
919
                                return err
×
920
                        }
×
921
                } else if !bytes.Equal(setIDInvNum, k.invoiceNum) {
3✔
UNCOV
922
                        return invpkg.ErrDuplicateSetID{
×
UNCOV
923
                                SetID: htlcSetID,
×
UNCOV
924
                        }
×
UNCOV
925
                }
×
926
        }
927

928
        // If this is a non-AMP invoice, then the state can eventually go to
929
        // ContractSettled, so we pass in nil value as part of
930
        // setSettleMetaFields.
931
        if !invoiceIsAMP && k.invoice.State == invpkg.ContractSettled {
6✔
932
                err := k.setSettleMetaFields(nil)
3✔
933
                if err != nil {
3✔
934
                        return err
×
935
                }
×
936
        }
937

938
        // As we don't update the settle index above for AMP invoices, we'll do
939
        // it here for each sub-AMP invoice that was settled.
940
        for settledSetID := range k.settledSetIDs {
6✔
941
                settledSetID := settledSetID
3✔
942
                err := k.setSettleMetaFields(&settledSetID)
3✔
943
                if err != nil {
3✔
944
                        return err
×
945
                }
×
946
        }
947

948
        err := k.serializeAndStoreInvoice()
3✔
949
        if err != nil {
3✔
950
                return err
×
951
        }
×
952

953
        // If this is an AMP invoice, then we'll actually store the rest of the
954
        // HTLCs in-line with the invoice, using the invoice ID as a prefix,
955
        // and the AMP key as a suffix: invoiceNum || setID.
956
        if invoiceIsAMP {
6✔
957
                return k.updateAMPInvoices()
3✔
958
        }
3✔
959

960
        return nil
3✔
961
}
962

963
// storeSettleHodlInvoiceUpdate updates the invoice in the database after
964
// settling a hodl invoice.
965
func (k *kvInvoiceUpdater) storeSettleHodlInvoiceUpdate() error {
3✔
966
        err := k.setSettleMetaFields(nil)
3✔
967
        if err != nil {
3✔
968
                return err
×
969
        }
×
970

971
        return k.serializeAndStoreInvoice()
3✔
972
}
973

974
// setSettleMetaFields updates the metadata associated with settlement of an
975
// invoice. If a non-nil setID is passed in, then the value will be append to
976
// the invoice number as well, in order to allow us to detect repeated payments
977
// to the same AMP invoices "across time".
978
func (k *kvInvoiceUpdater) setSettleMetaFields(setID *invpkg.SetID) error {
3✔
979
        // Now that we know the invoice hasn't already been settled, we'll
3✔
980
        // update the settle index so we can place this settle event in the
3✔
981
        // proper location within our time series.
3✔
982
        nextSettleSeqNo, err := k.settleIndexBucket.NextSequence()
3✔
983
        if err != nil {
3✔
984
                return err
×
985
        }
×
986

987
        // Make a new byte array on the stack that can potentially store the 4
988
        // byte invoice number along w/ the 32 byte set ID. We capture valueLen
989
        // here which is the number of bytes copied so we can only store the 4
990
        // bytes if this is a non-AMP invoice.
991
        var indexKey [invoiceSetIDKeyLen]byte
3✔
992
        valueLen := copy(indexKey[:], k.invoiceNum)
3✔
993

3✔
994
        if setID != nil {
6✔
995
                valueLen += copy(indexKey[valueLen:], setID[:])
3✔
996
        }
3✔
997

998
        var seqNoBytes [8]byte
3✔
999
        byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
3✔
1000
        err = k.settleIndexBucket.Put(seqNoBytes[:], indexKey[:valueLen])
3✔
1001
        if err != nil {
3✔
1002
                return err
×
1003
        }
×
1004

1005
        // If the setID is nil, then this means that this is a non-AMP settle,
1006
        // so we'll update the invoice settle index directly.
1007
        if setID == nil {
6✔
1008
                k.invoice.SettleDate = k.updateTime
3✔
1009
                k.invoice.SettleIndex = nextSettleSeqNo
3✔
1010
        } else {
6✔
1011
                // If the set ID isn't blank, we'll update the AMP state map
3✔
1012
                // which tracks when each of the setIDs associated with a given
3✔
1013
                // AMP invoice are settled.
3✔
1014
                ampState := k.invoice.AMPState[*setID]
3✔
1015

3✔
1016
                ampState.SettleDate = k.updateTime
3✔
1017
                ampState.SettleIndex = nextSettleSeqNo
3✔
1018

3✔
1019
                k.invoice.AMPState[*setID] = ampState
3✔
1020
        }
3✔
1021

1022
        return nil
3✔
1023
}
1024

1025
// updateAMPInvoices updates the set of AMP invoices in-place. For AMP, rather
1026
// then continually write the invoices to the end of the invoice value, we
1027
// instead write the invoices into a new key preifx that follows the main
1028
// invoice number. This ensures that we don't need to continually decode a
1029
// potentially massive HTLC set, and also allows us to quickly find the HLTCs
1030
// associated with a particular HTLC set.
1031
func (k *kvInvoiceUpdater) updateAMPInvoices() error {
3✔
1032
        for setID, htlcSet := range k.updatedAmpHtlcs {
6✔
1033
                // First write out the set of HTLCs including all the relevant
3✔
1034
                // TLV values.
3✔
1035
                var b bytes.Buffer
3✔
1036
                if err := serializeHtlcs(&b, htlcSet); err != nil {
3✔
1037
                        return err
×
1038
                }
×
1039

1040
                // Next store each HTLC in-line, using a prefix based off the
1041
                // invoice number.
1042
                invoiceSetIDKey := makeInvoiceSetIDKey(k.invoiceNum, setID[:])
3✔
1043

3✔
1044
                err := k.invoicesBucket.Put(invoiceSetIDKey[:], b.Bytes())
3✔
1045
                if err != nil {
3✔
1046
                        return err
×
1047
                }
×
1048
        }
1049

1050
        return nil
3✔
1051
}
1052

1053
// serializeAndStoreInvoice is a helper function used to store invoices.
1054
func (k *kvInvoiceUpdater) serializeAndStoreInvoice() error {
3✔
1055
        var buf bytes.Buffer
3✔
1056
        if err := serializeInvoice(&buf, k.invoice); err != nil {
3✔
1057
                return err
×
1058
        }
×
1059

1060
        return k.invoicesBucket.Put(k.invoiceNum, buf.Bytes())
3✔
1061
}
1062

1063
// InvoicesSettledSince can be used by callers to catch up any settled invoices
1064
// they missed within the settled invoice time series. We'll return all known
1065
// settled invoice that have a settle index higher than the passed
1066
// sinceSettleIndex.
1067
//
1068
// NOTE: The index starts from 1, as a result. We enforce that specifying a
1069
// value below the starting index value is a noop.
1070
func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
1071
        []invpkg.Invoice, error) {
3✔
1072

3✔
1073
        var (
3✔
1074
                settledInvoices []invpkg.Invoice
3✔
1075
                start           = time.Now()
3✔
1076
                lastLogTime     = time.Now()
3✔
1077
                processedCount  int
3✔
1078
        )
3✔
1079

3✔
1080
        // If an index of zero was specified, then in order to maintain
3✔
1081
        // backwards compat, we won't send out any new invoices.
3✔
1082
        if sinceSettleIndex == 0 {
6✔
1083
                return settledInvoices, nil
3✔
1084
        }
3✔
1085

1086
        var startIndex [8]byte
3✔
1087
        byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
3✔
1088

3✔
1089
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
1090
                invoices := tx.ReadBucket(invoiceBucket)
3✔
1091
                if invoices == nil {
3✔
1092
                        return nil
×
1093
                }
×
1094

1095
                settleIndex := invoices.NestedReadBucket(settleIndexBucket)
3✔
1096
                if settleIndex == nil {
3✔
1097
                        return nil
×
1098
                }
×
1099

1100
                // We'll now run through each entry in the add index starting
1101
                // at our starting index. We'll continue until we reach the
1102
                // very end of the current key space.
1103
                invoiceCursor := settleIndex.ReadCursor()
3✔
1104

3✔
1105
                // We'll seek to the starting index, then manually advance the
3✔
1106
                // cursor in order to skip the entry with the since add index.
3✔
1107
                invoiceCursor.Seek(startIndex[:])
3✔
1108
                seqNo, indexValue := invoiceCursor.Next()
3✔
1109

3✔
1110
                for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, indexValue = invoiceCursor.Next() {
6✔
1111
                        // Depending on the length of the index value, this may
3✔
1112
                        // or may not be an AMP invoice, so we'll extract the
3✔
1113
                        // invoice value into two components: the invoice num,
3✔
1114
                        // and the setID (may not be there).
3✔
1115
                        var (
3✔
1116
                                invoiceKey [4]byte
3✔
1117
                                setID      *invpkg.SetID
3✔
1118
                        )
3✔
1119

3✔
1120
                        valueLen := copy(invoiceKey[:], indexValue)
3✔
1121
                        if len(indexValue) == invoiceSetIDKeyLen {
6✔
1122
                                setID = new(invpkg.SetID)
3✔
1123
                                copy(setID[:], indexValue[valueLen:])
3✔
1124
                        }
3✔
1125

1126
                        // For each key found, we'll look up the actual
1127
                        // invoice, then accumulate it into our return value.
1128
                        invoice, err := fetchInvoice(
3✔
1129
                                invoiceKey[:], invoices, []*invpkg.SetID{setID},
3✔
1130
                                true,
3✔
1131
                        )
3✔
1132
                        if err != nil {
3✔
1133
                                return err
×
1134
                        }
×
1135

1136
                        settledInvoices = append(settledInvoices, invoice)
3✔
1137

3✔
1138
                        processedCount++
3✔
1139
                        if time.Since(lastLogTime) >=
3✔
1140
                                invoiceProgressLogInterval {
3✔
1141

×
1142
                                log.Debugf("Processed %d settled invoices "+
×
1143
                                        "which have a settle index greater "+
×
1144
                                        "than %v", processedCount,
×
1145
                                        sinceSettleIndex)
×
1146

×
1147
                                lastLogTime = time.Now()
×
1148
                        }
×
1149
                }
1150

1151
                return nil
3✔
1152
        }, func() {
3✔
1153
                settledInvoices = nil
3✔
1154
        })
3✔
1155
        if err != nil {
3✔
1156
                return nil, err
×
1157
        }
×
1158

1159
        elapsed := time.Since(start)
3✔
1160
        log.Debugf("Completed scanning for settled invoices starting at "+
3✔
1161
                "index %v: total_processed=%d, found_invoices=%d, elapsed=%v",
3✔
1162
                sinceSettleIndex, processedCount, len(settledInvoices),
3✔
1163
                elapsed.Round(time.Millisecond))
3✔
1164

3✔
1165
        return settledInvoices, nil
3✔
1166
}
1167

1168
func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
1169
        i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
1170
        uint64, error) {
3✔
1171

3✔
1172
        // Create the invoice key which is just the big-endian representation
3✔
1173
        // of the invoice number.
3✔
1174
        var invoiceKey [4]byte
3✔
1175
        byteOrder.PutUint32(invoiceKey[:], invoiceNum)
3✔
1176

3✔
1177
        // Increment the num invoice counter index so the next invoice bares
3✔
1178
        // the proper ID.
3✔
1179
        var scratch [4]byte
3✔
1180
        invoiceCounter := invoiceNum + 1
3✔
1181
        byteOrder.PutUint32(scratch[:], invoiceCounter)
3✔
1182
        if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
3✔
1183
                return 0, err
×
1184
        }
×
1185

1186
        // Add the payment hash to the invoice index. This will let us quickly
1187
        // identify if we can settle an incoming payment, and also to possibly
1188
        // allow a single invoice to have multiple payment installations.
1189
        err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
3✔
1190
        if err != nil {
3✔
1191
                return 0, err
×
1192
        }
×
1193

1194
        // Add the invoice to the payment address index, but only if the invoice
1195
        // has a non-zero payment address. The all-zero payment address is still
1196
        // in use by legacy keysend, so we special-case here to avoid
1197
        // collisions.
1198
        if i.Terms.PaymentAddr != invpkg.BlankPayAddr {
6✔
1199
                err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
3✔
1200
                if err != nil {
3✔
1201
                        return 0, err
×
1202
                }
×
1203
        }
1204

1205
        // Next, we'll obtain the next add invoice index (sequence
1206
        // number), so we can properly place this invoice within this
1207
        // event stream.
1208
        nextAddSeqNo, err := addIndex.NextSequence()
3✔
1209
        if err != nil {
3✔
1210
                return 0, err
×
1211
        }
×
1212

1213
        // With the next sequence obtained, we'll updating the event series in
1214
        // the add index bucket to map this current add counter to the index of
1215
        // this new invoice.
1216
        var seqNoBytes [8]byte
3✔
1217
        byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
3✔
1218
        if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
3✔
1219
                return 0, err
×
1220
        }
×
1221

1222
        i.AddIndex = nextAddSeqNo
3✔
1223

3✔
1224
        // Finally, serialize the invoice itself to be written to the disk.
3✔
1225
        var buf bytes.Buffer
3✔
1226
        if err := serializeInvoice(&buf, i); err != nil {
3✔
1227
                return 0, err
×
1228
        }
×
1229

1230
        if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
3✔
1231
                return 0, err
×
1232
        }
×
1233

1234
        return nextAddSeqNo, nil
3✔
1235
}
1236

1237
// recordSize returns the amount of bytes this TLV record will occupy when
1238
// encoded.
1239
func ampRecordSize(a *invpkg.AMPInvoiceState) func() (uint64, error) {
3✔
1240
        var (
3✔
1241
                b   bytes.Buffer
3✔
1242
                buf [8]byte
3✔
1243
        )
3✔
1244

3✔
1245
        // We know that encoding works since the tests pass in the build this
3✔
1246
        // file is checked into, so we'll simplify things and simply encode it
3✔
1247
        // ourselves then report the total amount of bytes used.
3✔
1248
        if err := ampStateEncoder(&b, a, &buf); err != nil {
3✔
NEW
1249
                return func() (uint64, error) {
×
NEW
1250
                        return 0, err
×
NEW
1251
                }
×
1252
        }
1253

1254
        return func() (uint64, error) {
6✔
1255
                return uint64(len(b.Bytes())), nil
3✔
1256
        }
3✔
1257
}
1258

1259
// serializeInvoice serializes an invoice to a writer.
1260
//
1261
// Note: this function is in use for a migration. Before making changes that
1262
// would modify the on disk format, make a copy of the original code and store
1263
// it with the migration.
1264
func serializeInvoice(w io.Writer, i *invpkg.Invoice) error {
3✔
1265
        creationDateBytes, err := i.CreationDate.MarshalBinary()
3✔
1266
        if err != nil {
3✔
1267
                return err
×
1268
        }
×
1269

1270
        settleDateBytes, err := i.SettleDate.MarshalBinary()
3✔
1271
        if err != nil {
3✔
1272
                return err
×
1273
        }
×
1274

1275
        var fb bytes.Buffer
3✔
1276
        err = i.Terms.Features.EncodeBase256(&fb)
3✔
1277
        if err != nil {
3✔
1278
                return err
×
1279
        }
×
1280
        featureBytes := fb.Bytes()
3✔
1281

3✔
1282
        preimage := [32]byte(invpkg.UnknownPreimage)
3✔
1283
        if i.Terms.PaymentPreimage != nil {
6✔
1284
                preimage = *i.Terms.PaymentPreimage
3✔
1285
                if preimage == invpkg.UnknownPreimage {
3✔
1286
                        return errors.New("cannot use all-zeroes preimage")
×
1287
                }
×
1288
        }
1289
        value := uint64(i.Terms.Value)
3✔
1290
        cltvDelta := uint32(i.Terms.FinalCltvDelta)
3✔
1291
        expiry := uint64(i.Terms.Expiry)
3✔
1292

3✔
1293
        amtPaid := uint64(i.AmtPaid)
3✔
1294
        state := uint8(i.State)
3✔
1295

3✔
1296
        var hodlInvoice uint8
3✔
1297
        if i.HodlInvoice {
6✔
1298
                hodlInvoice = 1
3✔
1299
        }
3✔
1300

1301
        tlvStream, err := tlv.NewStream(
3✔
1302
                // Memo and payreq.
3✔
1303
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
3✔
1304
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
3✔
1305

3✔
1306
                // Add/settle metadata.
3✔
1307
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
3✔
1308
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
3✔
1309
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
3✔
1310
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
3✔
1311

3✔
1312
                // Terms.
3✔
1313
                tlv.MakePrimitiveRecord(preimageType, &preimage),
3✔
1314
                tlv.MakePrimitiveRecord(valueType, &value),
3✔
1315
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
3✔
1316
                tlv.MakePrimitiveRecord(expiryType, &expiry),
3✔
1317
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
3✔
1318
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
3✔
1319

3✔
1320
                // Invoice state.
3✔
1321
                tlv.MakePrimitiveRecord(invStateType, &state),
3✔
1322
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
3✔
1323

3✔
1324
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
3✔
1325

3✔
1326
                // Invoice AMP state.
3✔
1327
                tlv.MakeDynamicRecord(
3✔
1328
                        invoiceAmpStateType, &i.AMPState,
3✔
1329
                        ampRecordSize(&i.AMPState),
3✔
1330
                        ampStateEncoder, ampStateDecoder,
3✔
1331
                ),
3✔
1332
        )
3✔
1333
        if err != nil {
3✔
1334
                return err
×
1335
        }
×
1336

1337
        var b bytes.Buffer
3✔
1338
        if err = tlvStream.Encode(&b); err != nil {
3✔
1339
                return err
×
1340
        }
×
1341

1342
        err = binary.Write(w, byteOrder, uint64(b.Len()))
3✔
1343
        if err != nil {
3✔
1344
                return err
×
1345
        }
×
1346

1347
        if _, err = w.Write(b.Bytes()); err != nil {
3✔
1348
                return err
×
1349
        }
×
1350

1351
        // Only if this is a _non_ AMP invoice do we serialize the HTLCs
1352
        // in-line with the rest of the invoice.
1353
        if i.IsAMP() {
6✔
1354
                return nil
3✔
1355
        }
3✔
1356

1357
        return serializeHtlcs(w, i.Htlcs)
3✔
1358
}
1359

1360
// serializeHtlcs writes every HTLC in htlcs to w. Each HTLC becomes its own
// TLV stream, preceded by the stream's byte length written as a uint64 in
// byteOrder. HTLCs are emitted in map iteration order, so no ordering between
// them is guaranteed.
func serializeHtlcs(w io.Writer,
	htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error {

	for circuit, h := range htlcs {
		// Convert the HTLC fields into the primitive forms the TLV
		// records encode.
		scid := circuit.ChanID.ToUint64()
		amount := uint64(h.Amt)
		totalAmt := uint64(h.MppTotalAmt)
		acceptedAt := putNanoTime(h.AcceptTime)
		resolvedAt := putNanoTime(h.ResolveTime)
		htlcState := uint8(h.State)

		// Record order matters for the TLV stream and must be kept as
		// is.
		records := []tlv.Record{
			tlv.MakePrimitiveRecord(chanIDType, &scid),
			tlv.MakePrimitiveRecord(htlcIDType, &circuit.HtlcID),
			tlv.MakePrimitiveRecord(amtType, &amount),
			tlv.MakePrimitiveRecord(
				acceptHeightType, &h.AcceptHeight,
			),
			tlv.MakePrimitiveRecord(acceptTimeType, &acceptedAt),
			tlv.MakePrimitiveRecord(resolveTimeType, &resolvedAt),
			tlv.MakePrimitiveRecord(expiryHeightType, &h.Expiry),
			tlv.MakePrimitiveRecord(htlcStateType, &htlcState),
			tlv.MakePrimitiveRecord(mppTotalAmtType, &totalAmt),
		}

		// AMP HTLCs additionally carry the AMP record and hash, plus
		// the preimage once it is known.
		if amp := h.AMP; amp != nil {
			ampRecord := tlv.MakeDynamicRecord(
				htlcAMPType, &amp.Record,
				func() (uint64, error) {
					return amp.Record.PayloadSize(), nil
				},
				record.AMPEncoder, record.AMPDecoder,
			)

			hash := [32]byte(amp.Hash)
			records = append(
				records, ampRecord,
				tlv.MakePrimitiveRecord(htlcHashType, &hash),
			)

			if amp.Preimage != nil {
				preimage := [32]byte(*amp.Preimage)
				records = append(records, tlv.MakePrimitiveRecord(
					htlcPreimageType, &preimage,
				))
			}
		}

		// Custom records use the experimental type range and come back
		// from MapToRecords already sorted, so they go last unchanged.
		records = append(records, tlv.MapToRecords(h.CustomRecords)...)

		stream, err := tlv.NewStream(records...)
		if err != nil {
			return err
		}

		var encoded bytes.Buffer
		if err := stream.Encode(&encoded); err != nil {
			return err
		}

		// Length prefix first, then the raw stream bytes.
		streamLen := uint64(encoded.Len())
		if err := binary.Write(w, byteOrder, streamLen); err != nil {
			return err
		}
		if _, err := w.Write(encoded.Bytes()); err != nil {
			return err
		}
	}

	return nil
}
1447

1448
// putNanoTime converts t to Unix nanoseconds stored as a uint64. The zero
// time.Time is mapped to 0, because UnixNano is undefined for it (its year is
// far outside the int64 nanosecond range). Instants before 1970 produce a
// negative UnixNano, which the conversion reinterprets as a large uint64.
func putNanoTime(t time.Time) uint64 {
	var ns uint64
	if !t.IsZero() {
		ns = uint64(t.UnixNano())
	}

	return ns
}
1457

1458
// getNanoTime returns a timestamp for the given number of nano seconds. If zero
1459
// is provided, an zero-value time stamp is returned.
1460
func getNanoTime(ns uint64) time.Time {
3✔
1461
        if ns == 0 {
6✔
1462
                return time.Time{}
3✔
1463
        }
3✔
1464
        return time.Unix(0, int64(ns))
3✔
1465
}
1466

1467
// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices
1468
// identified by the setID value.
1469
func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1470
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1471
        error) {
3✔
1472

3✔
1473
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
1474
        for _, setID := range setIDs {
6✔
1475
                invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:])
3✔
1476

3✔
1477
                htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:])
3✔
1478
                if htlcSetBytes == nil {
6✔
1479
                        // A set ID was passed in, but we don't have this
3✔
1480
                        // stored yet, meaning that the setID is being added
3✔
1481
                        // for the first time.
3✔
1482
                        return htlcs, invpkg.ErrInvoiceNotFound
3✔
1483
                }
3✔
1484

1485
                htlcSetReader := bytes.NewReader(htlcSetBytes)
3✔
1486
                htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
3✔
1487
                if err != nil {
3✔
1488
                        return nil, err
×
1489
                }
×
1490

1491
                maps.Copy(htlcs, htlcsBySetID)
3✔
1492
        }
1493

1494
        return htlcs, nil
3✔
1495
}
1496

1497
// forEachAMPInvoice is a helper function that attempts to iterate over each of
1498
// the HTLC sets (based on their set ID) for the given AMP invoice identified
1499
// by its invoiceNum. The callback closure is called for each key within the
1500
// prefix range.
1501
func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1502
        callback func(key, htlcSet []byte) error) error {
3✔
1503

3✔
1504
        invoiceCursor := invoiceBucket.ReadCursor()
3✔
1505

3✔
1506
        // Seek to the first key that includes the invoice data itself.
3✔
1507
        invoiceCursor.Seek(invoiceNum)
3✔
1508

3✔
1509
        // Advance to the very first key _after_ the invoice data, as this is
3✔
1510
        // where we'll encounter our first HTLC (if any are present).
3✔
1511
        cursorKey, htlcSet := invoiceCursor.Next()
3✔
1512

3✔
1513
        // If at this point, the cursor key doesn't match the invoice num
3✔
1514
        // prefix, then we know that this HTLC doesn't have any set ID HTLCs
3✔
1515
        // associated with it.
3✔
1516
        if !bytes.HasPrefix(cursorKey, invoiceNum) {
6✔
1517
                return nil
3✔
1518
        }
3✔
1519

1520
        // Otherwise continue to iterate until we no longer match the prefix,
1521
        // executing the call back at each step.
1522
        for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() {
6✔
1523
                err := callback(cursorKey, htlcSet)
3✔
1524
                if err != nil {
3✔
1525
                        return err
×
1526
                }
×
1527
        }
1528

1529
        return nil
3✔
1530
}
1531

1532
// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix  within the
1533
// AMP bucket to find all the individual HTLCs (by setID) associated with a
1534
// given invoice. If a list of set IDs are specified, then only HTLCs
1535
// associated with that setID will be retrieved.
1536
func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1537
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1538
        error) {
3✔
1539

3✔
1540
        // If a set of setIDs was specified, then we can skip the cursor and
3✔
1541
        // just read out exactly what we need.
3✔
1542
        if len(setIDs) != 0 && setIDs[0] != nil {
6✔
1543
                return fetchFilteredAmpInvoices(
3✔
1544
                        invoiceBucket, invoiceNum, setIDs...,
3✔
1545
                )
3✔
1546
        }
3✔
1547

1548
        // Otherwise, iterate over all the htlc sets that are prefixed beside
1549
        // this invoice in the main invoice bucket.
1550
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
1551
        err := forEachAMPInvoice(invoiceBucket, invoiceNum,
3✔
1552
                func(key, htlcSet []byte) error {
6✔
1553
                        htlcSetReader := bytes.NewReader(htlcSet)
3✔
1554
                        htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
3✔
1555
                        if err != nil {
3✔
1556
                                return err
×
1557
                        }
×
1558

1559
                        maps.Copy(htlcs, htlcsBySetID)
3✔
1560

3✔
1561
                        return nil
3✔
1562
                },
1563
        )
1564

1565
        if err != nil {
3✔
1566
                return nil, err
×
1567
        }
×
1568

1569
        return htlcs, nil
3✔
1570
}
1571

1572
// fetchInvoice attempts to read out the relevant state for the invoice as
1573
// specified by the invoice number. If the setID fields are set, then only the
1574
// HTLC information pertaining to those set IDs is returned.
1575
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1576
        setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
3✔
1577

3✔
1578
        invoiceBytes := invoices.Get(invoiceNum)
3✔
1579
        if invoiceBytes == nil {
3✔
1580
                return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound
×
1581
        }
×
1582

1583
        invoiceReader := bytes.NewReader(invoiceBytes)
3✔
1584

3✔
1585
        invoice, err := deserializeInvoice(invoiceReader)
3✔
1586
        if err != nil {
3✔
1587
                return invpkg.Invoice{}, err
×
1588
        }
×
1589

1590
        // If this is an AMP invoice we'll also attempt to read out the set of
1591
        // HTLCs that were paid to prior set IDs, if needed.
1592
        if !invoice.IsAMP() {
6✔
1593
                return invoice, nil
3✔
1594
        }
3✔
1595

1596
        if shouldFetchAMPHTLCs(invoice, setIDs) {
6✔
1597
                invoice.Htlcs, err = fetchAmpSubInvoices(
3✔
1598
                        invoices, invoiceNum, setIDs...,
3✔
1599
                )
3✔
1600
                // TODO(positiveblue): we should fail when we are not able to
3✔
1601
                // fetch all the HTLCs for an AMP invoice. Multiple tests in
3✔
1602
                // the invoice and channeldb package break if we return this
3✔
1603
                // error. We need to update them when we migrate this logic to
3✔
1604
                // the sql implementation.
3✔
1605
                if err != nil {
6✔
1606
                        log.Errorf("unable to fetch amp htlcs for inv "+
3✔
1607
                                "%v and setIDs %v: %w", invoiceNum, setIDs, err)
3✔
1608
                }
3✔
1609

1610
                if filterAMPState {
6✔
1611
                        filterInvoiceAMPState(&invoice, setIDs...)
3✔
1612
                }
3✔
1613
        }
1614

1615
        return invoice, nil
3✔
1616
}
1617

1618
// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that
1619
// were paid to the relevant set IDs.
1620
func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool {
3✔
1621
        // For AMP invoice that already have HTLCs populated (created before
3✔
1622
        // recurring invoices), then we don't need to read from the prefix
3✔
1623
        // keyed section of the bucket.
3✔
1624
        if len(invoice.Htlcs) != 0 {
3✔
1625
                return false
×
1626
        }
×
1627

1628
        // If the "zero" setID was specified, then this means that no HTLC data
1629
        // should be returned alongside of it.
1630
        if len(setIDs) != 0 && setIDs[0] != nil &&
3✔
1631
                *setIDs[0] == invpkg.BlankPayAddr {
6✔
1632

3✔
1633
                return false
3✔
1634
        }
3✔
1635

1636
        return true
3✔
1637
}
1638

1639
// fetchInvoiceStateAMP decodes only the AMP state record of the invoice stored
// under invoiceNum, rather than the entire invoice. It returns
// invpkg.ErrInvoiceNotFound if nothing is stored under that key.
func fetchInvoiceStateAMP(invoiceNum []byte,
	invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) {

	raw := invoices.Get(invoiceNum)
	if raw == nil {
		return nil, invpkg.ErrInvoiceNotFound
	}
	reader := bytes.NewReader(raw)

	// The TLV body is prefixed with its length in byteOrder.
	var bodyLen int64
	if err := binary.Read(reader, byteOrder, &bodyLen); err != nil {
		return nil, err
	}

	// This stream knows only the AMP state record, so that is the only
	// part of the body decoded into a value.
	ampState := make(invpkg.AMPInvoiceState)
	stream, err := tlv.NewStream(
		tlv.MakeDynamicRecord(
			invoiceAmpStateType, &ampState, nil,
			ampStateEncoder, ampStateDecoder,
		),
	)
	if err != nil {
		return nil, err
	}

	err = stream.Decode(io.LimitReader(reader, bodyLen))
	if err != nil {
		return nil, err
	}

	return ampState, nil
}
1680

1681
// deserializeInvoice reads an invoice from r. The expected layout is:
//   - the TLV body length, read as an int64 in byteOrder;
//   - a TLV stream of at most that many bytes;
//   - whatever remains in r, which is handed to deserializeHtlcs.
//
// On any error, the partially populated invoice is returned together with
// the error.
func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) {
	var inv invpkg.Invoice
	inv.AMPState = make(invpkg.AMPInvoiceState)

	// Scratch values for fields whose encoded form differs from the type
	// stored on the invoice.
	var (
		preimageRaw [32]byte
		valueMsat   uint64
		cltvDelta   uint32
		expiry      uint64
		amtPaid     uint64
		state       uint8
		hodl        uint8
		createdRaw  []byte
		settledRaw  []byte
		featureRaw  []byte
	)

	stream, err := tlv.NewStream(
		// Memo and payreq.
		tlv.MakePrimitiveRecord(memoType, &inv.Memo),
		tlv.MakePrimitiveRecord(payReqType, &inv.PaymentRequest),

		// Add/settle metadata.
		tlv.MakePrimitiveRecord(createTimeType, &createdRaw),
		tlv.MakePrimitiveRecord(settleTimeType, &settledRaw),
		tlv.MakePrimitiveRecord(addIndexType, &inv.AddIndex),
		tlv.MakePrimitiveRecord(settleIndexType, &inv.SettleIndex),

		// Terms.
		tlv.MakePrimitiveRecord(preimageType, &preimageRaw),
		tlv.MakePrimitiveRecord(valueType, &valueMsat),
		tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
		tlv.MakePrimitiveRecord(expiryType, &expiry),
		tlv.MakePrimitiveRecord(
			paymentAddrType, &inv.Terms.PaymentAddr,
		),
		tlv.MakePrimitiveRecord(featuresType, &featureRaw),

		// Invoice state.
		tlv.MakePrimitiveRecord(invStateType, &state),
		tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
		tlv.MakePrimitiveRecord(hodlInvoiceType, &hodl),

		// Invoice AMP state.
		tlv.MakeDynamicRecord(
			invoiceAmpStateType, &inv.AMPState, nil,
			ampStateEncoder, ampStateDecoder,
		),
	)
	if err != nil {
		return inv, err
	}

	var bodyLen int64
	if err := binary.Read(r, byteOrder, &bodyLen); err != nil {
		return inv, err
	}

	if err := stream.Decode(io.LimitReader(r, bodyLen)); err != nil {
		return inv, err
	}

	// An all-zero preimage means it is unknown, so the field stays nil.
	if p := lntypes.Preimage(preimageRaw); p != invpkg.UnknownPreimage {
		inv.Terms.PaymentPreimage = &p
	}

	inv.Terms.Value = lnwire.MilliSatoshi(valueMsat)
	inv.Terms.FinalCltvDelta = int32(cltvDelta)
	inv.Terms.Expiry = time.Duration(expiry)
	inv.AmtPaid = lnwire.MilliSatoshi(amtPaid)
	inv.State = invpkg.ContractState(state)
	inv.HodlInvoice = hodl != 0

	if err := inv.CreationDate.UnmarshalBinary(createdRaw); err != nil {
		return inv, err
	}
	if err := inv.SettleDate.UnmarshalBinary(settledRaw); err != nil {
		return inv, err
	}

	rawFeatures := lnwire.NewRawFeatureVector()
	err = rawFeatures.DecodeBase256(
		bytes.NewReader(featureRaw), len(featureRaw),
	)
	if err != nil {
		return inv, err
	}
	inv.Terms.Features = lnwire.NewFeatureVector(
		rawFeatures, lnwire.Features,
	)

	// Any bytes left after the TLV body are the inline HTLCs.
	inv.Htlcs, err = deserializeHtlcs(r)

	return inv, err
}
1784

1785
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
3✔
1786
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
6✔
1787
                // We encode the set of circuit keys as a varint length prefix.
3✔
1788
                // followed by a series of fixed sized uint8 integers.
3✔
1789
                numKeys := uint64(len(*v))
3✔
1790

3✔
1791
                if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
3✔
1792
                        return err
×
1793
                }
×
1794

1795
                for key := range *v {
6✔
1796
                        scidInt := key.ChanID.ToUint64()
3✔
1797

3✔
1798
                        if err := tlv.EUint64(w, &scidInt, buf); err != nil {
3✔
1799
                                return err
×
1800
                        }
×
1801
                        if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
3✔
1802
                                return err
×
1803
                        }
×
1804
                }
1805

1806
                return nil
3✔
1807
        }
1808

1809
        return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
×
1810
}
1811

1812
// decodeCircuitKeys is the TLV decoder matching encodeCircuitKeys.
//
// It reads a varint key count, then that many pairs of 8-byte values
// (channel ID followed by HTLC ID). Each reconstructed models.CircuitKey is
// inserted into the map val points to, which must already be allocated. Any
// other value type yields a tlv type-for-decoding error.
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte,
	l uint64) error {

	keys, ok := val.(*map[models.CircuitKey]struct{})
	if !ok {
		return tlv.NewTypeForDecodingErr(
			val, "*map[CircuitKey]struct{}", l, l,
		)
	}

	count, err := tlv.ReadVarInt(r, buf)
	if err != nil {
		return err
	}

	for remaining := count; remaining > 0; remaining-- {
		var scid, htlcID uint64

		if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
			return err
		}
		if err := tlv.DUint64(r, &htlcID, buf, 8); err != nil {
			return err
		}

		(*keys)[models.CircuitKey{
			ChanID: lnwire.NewShortChanIDFromInt(scid),
			HtlcID: htlcID,
		}] = struct{}{}
	}

	return nil
}
1850

1851
// ampStateEncoder is a custom TLV encoder for *invpkg.AMPInvoiceState.
//
// It writes the number of entries as a varint. Then, for each entry in map
// iteration order, it writes a varint length followed by that many bytes of
// an inner TLV stream. The inner stream holds the set ID, HTLC state, settle
// index, settle date (time.Time.MarshalBinary form), invoice circuit keys and
// amount paid. Any other value type yields a tlv type-for-encoding error.
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
	states, ok := val.(*invpkg.AMPInvoiceState)
	if !ok {
		return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
	}

	if err := tlv.WriteVarInt(w, uint64(len(*states)), buf); err != nil {
		return err
	}

	for id, st := range *states {
		setID := [32]byte(id)
		htlcState := uint8(st.State)
		settleIndex := st.SettleIndex
		circuitKeys := st.InvoiceKeys

		settleDateBytes, err := st.SettleDate.MarshalBinary()
		if err != nil {
			return err
		}

		amtPaid := uint64(st.AmtPaid)

		// The circuit key record is a varint key count followed by
		// 8 bytes of SCID and 8 bytes of HTLC index per key.
		sizeOfKeys := func() (uint64, error) {
			n := uint64(len(circuitKeys))
			return tlv.VarIntSize(n) + n*16, nil
		}

		stream, err := tlv.NewStream(
			tlv.MakePrimitiveRecord(ampStateSetIDType, &setID),
			tlv.MakePrimitiveRecord(
				ampStateHtlcStateType, &htlcState,
			),
			tlv.MakePrimitiveRecord(
				ampStateSettleIndexType, &settleIndex,
			),
			tlv.MakePrimitiveRecord(
				ampStateSettleDateType, &settleDateBytes,
			),
			tlv.MakeDynamicRecord(
				ampStateCircuitKeysType,
				&circuitKeys,
				sizeOfKeys,
				encodeCircuitKeys, decodeCircuitKeys,
			),
			tlv.MakePrimitiveRecord(ampStateAmtPaidType, &amtPaid),
		)
		if err != nil {
			return err
		}

		var inner bytes.Buffer
		if err := stream.Encode(&inner); err != nil {
			return err
		}

		// Varint length first, then the raw inner TLV bytes.
		innerLen := uint64(inner.Len())
		if err := tlv.WriteVarInt(w, innerLen, buf); err != nil {
			return err
		}
		if _, err := w.Write(inner.Bytes()); err != nil {
			return err
		}
	}

	return nil
}
1944

1945
// ampStateDecoder is the TLV decoder matching ampStateEncoder for
// *invpkg.AMPInvoiceState.
//
// It reads a varint entry count. For each entry it reads a varint length and
// decodes an inner TLV stream from at most that many bytes, holding the set
// ID, HTLC state, settle index, settle date, circuit keys and amount paid.
// Each result is stored in the map val points to, keyed by set ID; that map
// must already be allocated. Any other value type yields a tlv
// type-for-decoding error.
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte,
	l uint64) error {

	states, ok := val.(*invpkg.AMPInvoiceState)
	if !ok {
		return tlv.NewTypeForDecodingErr(
			val, "channeldb.AMPInvoiceState", l, l,
		)
	}

	count, err := tlv.ReadVarInt(r, buf)
	if err != nil {
		return err
	}

	for remaining := count; remaining > 0; remaining-- {
		innerLen, err := tlv.ReadVarInt(r, buf)
		if err != nil {
			return err
		}

		var (
			setID       [32]byte
			htlcState   uint8
			settleIndex uint64
			settleRaw   []byte
			amtPaid     uint64
		)
		circuitKeys := make(map[models.CircuitKey]struct{})

		stream, err := tlv.NewStream(
			tlv.MakePrimitiveRecord(ampStateSetIDType, &setID),
			tlv.MakePrimitiveRecord(
				ampStateHtlcStateType, &htlcState,
			),
			tlv.MakePrimitiveRecord(
				ampStateSettleIndexType, &settleIndex,
			),
			tlv.MakePrimitiveRecord(
				ampStateSettleDateType, &settleRaw,
			),
			tlv.MakeDynamicRecord(
				ampStateCircuitKeysType, &circuitKeys, nil,
				encodeCircuitKeys, decodeCircuitKeys,
			),
			tlv.MakePrimitiveRecord(ampStateAmtPaidType, &amtPaid),
		)
		if err != nil {
			return err
		}

		// Bound the inner stream to its declared length so decoding
		// stops at the end of this entry.
		limited := &io.LimitedReader{R: r, N: int64(innerLen)}
		if err := stream.Decode(limited); err != nil {
			return err
		}

		var settleDate time.Time
		if err := settleDate.UnmarshalBinary(settleRaw); err != nil {
			return err
		}

		(*states)[setID] = invpkg.InvoiceStateAMP{
			State:       invpkg.HtlcState(htlcState),
			SettleIndex: settleIndex,
			SettleDate:  settleDate,
			InvoiceKeys: circuitKeys,
			AmtPaid:     lnwire.MilliSatoshi(amtPaid),
		}
	}

	return nil
}
2039

2040
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
2041
// as a map.
2042
func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
2043
        error) {
3✔
2044

3✔
2045
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
2046
        for {
6✔
2047
                // Read the length of the tlv stream for this htlc.
3✔
2048
                var streamLen int64
3✔
2049
                if err := binary.Read(r, byteOrder, &streamLen); err != nil {
6✔
2050
                        if err == io.EOF {
6✔
2051
                                break
3✔
2052
                        }
2053

2054
                        return nil, err
×
2055
                }
2056

2057
                // Limit the reader so that it stops at the end of this htlc's
2058
                // stream.
2059
                htlcReader := io.LimitReader(r, streamLen)
3✔
2060

3✔
2061
                // Decode the contents into the htlc fields.
3✔
2062
                var (
3✔
2063
                        htlc                    invpkg.InvoiceHTLC
3✔
2064
                        key                     models.CircuitKey
3✔
2065
                        chanID                  uint64
3✔
2066
                        state                   uint8
3✔
2067
                        acceptTime, resolveTime uint64
3✔
2068
                        amt, mppTotalAmt        uint64
3✔
2069
                        amp                     = &record.AMP{}
3✔
2070
                        hash32                  = &[32]byte{}
3✔
2071
                        preimage32              = &[32]byte{}
3✔
2072
                )
3✔
2073
                tlvStream, err := tlv.NewStream(
3✔
2074
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
3✔
2075
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
3✔
2076
                        tlv.MakePrimitiveRecord(amtType, &amt),
3✔
2077
                        tlv.MakePrimitiveRecord(
3✔
2078
                                acceptHeightType, &htlc.AcceptHeight,
3✔
2079
                        ),
3✔
2080
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
3✔
2081
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
3✔
2082
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
3✔
2083
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
3✔
2084
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
3✔
2085
                        tlv.MakeDynamicRecord(
3✔
2086
                                htlcAMPType, amp, func() (uint64, error) {
3✔
NEW
2087
                                        return amp.PayloadSize(), nil
×
NEW
2088
                                },
×
2089
                                record.AMPEncoder, record.AMPDecoder,
2090
                        ),
2091
                        tlv.MakePrimitiveRecord(htlcHashType, hash32),
2092
                        tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
2093
                )
2094
                if err != nil {
3✔
2095
                        return nil, err
×
2096
                }
×
2097

2098
                parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
3✔
2099
                if err != nil {
3✔
2100
                        return nil, err
×
2101
                }
×
2102

2103
                if _, ok := parsedTypes[htlcAMPType]; !ok {
6✔
2104
                        amp = nil
3✔
2105
                }
3✔
2106

2107
                var preimage *lntypes.Preimage
3✔
2108
                if _, ok := parsedTypes[htlcPreimageType]; ok {
6✔
2109
                        pimg := lntypes.Preimage(*preimage32)
3✔
2110
                        preimage = &pimg
3✔
2111
                }
3✔
2112

2113
                var hash *lntypes.Hash
3✔
2114
                if _, ok := parsedTypes[htlcHashType]; ok {
6✔
2115
                        h := lntypes.Hash(*hash32)
3✔
2116
                        hash = &h
3✔
2117
                }
3✔
2118

2119
                key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
3✔
2120
                htlc.AcceptTime = getNanoTime(acceptTime)
3✔
2121
                htlc.ResolveTime = getNanoTime(resolveTime)
3✔
2122
                htlc.State = invpkg.HtlcState(state)
3✔
2123
                htlc.Amt = lnwire.MilliSatoshi(amt)
3✔
2124
                htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
3✔
2125
                if amp != nil && hash != nil {
6✔
2126
                        htlc.AMP = &invpkg.InvoiceHtlcAMPData{
3✔
2127
                                Record:   *amp,
3✔
2128
                                Hash:     *hash,
3✔
2129
                                Preimage: preimage,
3✔
2130
                        }
3✔
2131
                }
3✔
2132

2133
                // Reconstruct the custom records fields from the parsed types
2134
                // map return from the tlv parser.
2135
                htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
3✔
2136

3✔
2137
                htlcs[key] = &htlc
3✔
2138
        }
2139

2140
        return htlcs, nil
3✔
2141
}
2142

2143
// invoiceSetIDKeyLen is the length of the keys under which the individual
// HTLCs of an AMP set are stored alongside the main invoice: a 4-byte invoice
// number followed by a 32-byte set ID.
const invoiceSetIDKeyLen = 4 + 32

// makeInvoiceSetIDKey builds the invoiceNum || setID prefix key under which the
// HTLCs for setID are stored for the given invoice. The invoice number is
// written at the start of the key and the set ID directly after it. Bytes not
// covered by either input stay zero, and input beyond the key length is
// truncated.
func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte {
	var key [invoiceSetIDKeyLen]byte

	// Prefix: the invoice number.
	copy(key[:], invoiceNum)

	// Suffix: the set ID, placed right after however many bytes the
	// invoice number occupies.
	suffix := key[len(invoiceNum):]
	copy(suffix, setID)

	return key
}
3✔
2160

2161
// delAMPInvoices attempts to delete all the "sub" invoices associated with a
2162
// greater AMP invoices. We do this by deleting the set of keys that share the
2163
// invoice number as a prefix.
UNCOV
2164
func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error {
×
UNCOV
2165
        // Since it isn't safe to delete using an active cursor, we'll use the
×
UNCOV
2166
        // cursor simply to collect the set of keys we need to delete, _then_
×
UNCOV
2167
        // delete them in another pass.
×
UNCOV
2168
        var keysToDel [][]byte
×
UNCOV
2169
        err := forEachAMPInvoice(
×
UNCOV
2170
                invoiceBucket, invoiceNum,
×
UNCOV
2171
                func(cursorKey, v []byte) error {
×
UNCOV
2172
                        keysToDel = append(keysToDel, cursorKey)
×
UNCOV
2173
                        return nil
×
UNCOV
2174
                },
×
2175
        )
UNCOV
2176
        if err != nil {
×
2177
                return err
×
2178
        }
×
2179

2180
        // In this next phase, we'll then delete all the relevant invoices.
UNCOV
2181
        for _, keyToDel := range keysToDel {
×
UNCOV
2182
                if err := invoiceBucket.Delete(keyToDel); err != nil {
×
2183
                        return err
×
2184
                }
×
2185
        }
2186

UNCOV
2187
        return nil
×
2188
}
2189

2190
// delAMPSettleIndex removes all the entries in the settle index associated
2191
// with a given AMP invoice.
2192
func delAMPSettleIndex(invoiceNum []byte, invoices,
UNCOV
2193
        settleIndex kvdb.RwBucket) error {
×
UNCOV
2194

×
UNCOV
2195
        // First, we need to grab the AMP invoice state to see if there's
×
UNCOV
2196
        // anything that we even need to delete.
×
UNCOV
2197
        ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices)
×
UNCOV
2198
        if err != nil {
×
2199
                return err
×
2200
        }
×
2201

2202
        // If there's no AMP state at all (non-AMP invoice), then we can return
2203
        // early.
UNCOV
2204
        if len(ampState) == 0 {
×
UNCOV
2205
                return nil
×
UNCOV
2206
        }
×
2207

2208
        // Otherwise, we'll need to iterate and delete each settle index within
2209
        // the set of returned entries.
UNCOV
2210
        var settleIndexKey [8]byte
×
UNCOV
2211
        for _, subState := range ampState {
×
UNCOV
2212
                byteOrder.PutUint64(
×
UNCOV
2213
                        settleIndexKey[:], subState.SettleIndex,
×
UNCOV
2214
                )
×
UNCOV
2215

×
UNCOV
2216
                if err := settleIndex.Delete(settleIndexKey[:]); err != nil {
×
2217
                        return err
×
2218
                }
×
2219
        }
2220

UNCOV
2221
        return nil
×
2222
}
2223

2224
// DeleteCanceledInvoices deletes all canceled invoices from the database.
UNCOV
2225
func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
×
UNCOV
2226
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
×
UNCOV
2227
                invoices := tx.ReadWriteBucket(invoiceBucket)
×
UNCOV
2228
                if invoices == nil {
×
2229
                        return nil
×
2230
                }
×
2231

UNCOV
2232
                invoiceIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2233
                        invoiceIndexBucket,
×
UNCOV
2234
                )
×
UNCOV
2235
                if invoiceIndex == nil {
×
UNCOV
2236
                        return nil
×
UNCOV
2237
                }
×
2238

UNCOV
2239
                invoiceAddIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2240
                        addIndexBucket,
×
UNCOV
2241
                )
×
UNCOV
2242
                if invoiceAddIndex == nil {
×
2243
                        return nil
×
2244
                }
×
2245

UNCOV
2246
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
×
UNCOV
2247

×
UNCOV
2248
                return invoiceIndex.ForEach(func(k, v []byte) error {
×
UNCOV
2249
                        // Skip the special numInvoicesKey as that does not
×
UNCOV
2250
                        // point to a valid invoice.
×
UNCOV
2251
                        if bytes.Equal(k, numInvoicesKey) {
×
UNCOV
2252
                                return nil
×
UNCOV
2253
                        }
×
2254

2255
                        // Skip sub-buckets.
UNCOV
2256
                        if v == nil {
×
2257
                                return nil
×
2258
                        }
×
2259

UNCOV
2260
                        invoice, err := fetchInvoice(v, invoices, nil, false)
×
UNCOV
2261
                        if err != nil {
×
2262
                                return err
×
2263
                        }
×
2264

UNCOV
2265
                        if invoice.State != invpkg.ContractCanceled {
×
UNCOV
2266
                                return nil
×
UNCOV
2267
                        }
×
2268

2269
                        // Delete the payment hash from the invoice index.
UNCOV
2270
                        err = invoiceIndex.Delete(k)
×
UNCOV
2271
                        if err != nil {
×
2272
                                return err
×
2273
                        }
×
2274

2275
                        // Delete payment address index reference if there's a
2276
                        // valid payment address.
UNCOV
2277
                        if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
×
UNCOV
2278
                                // To ensure consistency check that the already
×
UNCOV
2279
                                // fetched invoice key matches the one in the
×
UNCOV
2280
                                // payment address index.
×
UNCOV
2281
                                key := payAddrIndex.Get(
×
UNCOV
2282
                                        invoice.Terms.PaymentAddr[:],
×
UNCOV
2283
                                )
×
UNCOV
2284
                                if bytes.Equal(key, k) {
×
2285
                                        // Delete from the payment address
×
2286
                                        // index.
×
2287
                                        if err := payAddrIndex.Delete(
×
2288
                                                invoice.Terms.PaymentAddr[:],
×
2289
                                        ); err != nil {
×
2290
                                                return err
×
2291
                                        }
×
2292
                                }
2293
                        }
2294

2295
                        // Remove from the add index.
UNCOV
2296
                        var addIndexKey [8]byte
×
UNCOV
2297
                        byteOrder.PutUint64(addIndexKey[:], invoice.AddIndex)
×
UNCOV
2298
                        err = invoiceAddIndex.Delete(addIndexKey[:])
×
UNCOV
2299
                        if err != nil {
×
2300
                                return err
×
2301
                        }
×
2302

2303
                        // Note that we don't need to delete the invoice from
2304
                        // the settle index as it is not added until the
2305
                        // invoice is settled.
2306

2307
                        // Now remove all sub invoices.
UNCOV
2308
                        err = delAMPInvoices(k, invoices)
×
UNCOV
2309
                        if err != nil {
×
2310
                                return err
×
2311
                        }
×
2312

2313
                        // Finally remove the serialized invoice from the
2314
                        // invoice bucket.
UNCOV
2315
                        return invoices.Delete(k)
×
2316
                })
UNCOV
2317
        }, func() {})
×
2318
}
2319

2320
// DeleteInvoice attempts to delete the passed invoices from the database in
2321
// one transaction. The passed delete references hold all keys required to
2322
// delete the invoices without also needing to deserialize them.
2323
func (d *DB) DeleteInvoice(_ context.Context,
UNCOV
2324
        invoicesToDelete []invpkg.InvoiceDeleteRef) error {
×
UNCOV
2325

×
UNCOV
2326
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
×
UNCOV
2327
                invoices := tx.ReadWriteBucket(invoiceBucket)
×
UNCOV
2328
                if invoices == nil {
×
2329
                        return invpkg.ErrNoInvoicesCreated
×
2330
                }
×
2331

UNCOV
2332
                invoiceIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2333
                        invoiceIndexBucket,
×
UNCOV
2334
                )
×
UNCOV
2335
                if invoiceIndex == nil {
×
2336
                        return invpkg.ErrNoInvoicesCreated
×
2337
                }
×
2338

UNCOV
2339
                invoiceAddIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2340
                        addIndexBucket,
×
UNCOV
2341
                )
×
UNCOV
2342
                if invoiceAddIndex == nil {
×
2343
                        return invpkg.ErrNoInvoicesCreated
×
2344
                }
×
2345

2346
                // settleIndex can be nil, as the bucket is created lazily
2347
                // when the first invoice is settled.
UNCOV
2348
                settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket)
×
UNCOV
2349

×
UNCOV
2350
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
×
UNCOV
2351

×
UNCOV
2352
                for _, ref := range invoicesToDelete {
×
UNCOV
2353
                        // Fetch the invoice key for using it to check for
×
UNCOV
2354
                        // consistency and also to delete from the invoice
×
UNCOV
2355
                        // index.
×
UNCOV
2356
                        invoiceKey := invoiceIndex.Get(ref.PayHash[:])
×
UNCOV
2357
                        if invoiceKey == nil {
×
UNCOV
2358
                                return invpkg.ErrInvoiceNotFound
×
UNCOV
2359
                        }
×
2360

UNCOV
2361
                        err := invoiceIndex.Delete(ref.PayHash[:])
×
UNCOV
2362
                        if err != nil {
×
2363
                                return err
×
2364
                        }
×
2365

2366
                        // Delete payment address index reference if there's a
2367
                        // valid payment address passed.
UNCOV
2368
                        if ref.PayAddr != nil {
×
UNCOV
2369
                                // To ensure consistency check that the already
×
UNCOV
2370
                                // fetched invoice key matches the one in the
×
UNCOV
2371
                                // payment address index.
×
UNCOV
2372
                                key := payAddrIndex.Get(ref.PayAddr[:])
×
UNCOV
2373
                                if bytes.Equal(key, invoiceKey) {
×
UNCOV
2374
                                        // Delete from the payment address
×
UNCOV
2375
                                        // index. Note that since the payment
×
UNCOV
2376
                                        // address index has been introduced
×
UNCOV
2377
                                        // with an empty migration it may be
×
UNCOV
2378
                                        // possible that the index doesn't have
×
UNCOV
2379
                                        // an entry for this invoice.
×
UNCOV
2380
                                        // ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127
×
UNCOV
2381
                                        if err := payAddrIndex.Delete(
×
UNCOV
2382
                                                ref.PayAddr[:],
×
UNCOV
2383
                                        ); err != nil {
×
2384
                                                return err
×
2385
                                        }
×
2386
                                }
2387
                        }
2388

UNCOV
2389
                        var addIndexKey [8]byte
×
UNCOV
2390
                        byteOrder.PutUint64(addIndexKey[:], ref.AddIndex)
×
UNCOV
2391

×
UNCOV
2392
                        // To ensure consistency check that the key stored in
×
UNCOV
2393
                        // the add index also matches the previously fetched
×
UNCOV
2394
                        // invoice key.
×
UNCOV
2395
                        key := invoiceAddIndex.Get(addIndexKey[:])
×
UNCOV
2396
                        if !bytes.Equal(key, invoiceKey) {
×
UNCOV
2397
                                return fmt.Errorf("unknown invoice " +
×
UNCOV
2398
                                        "in add index")
×
UNCOV
2399
                        }
×
2400

2401
                        // Remove from the add index.
UNCOV
2402
                        err = invoiceAddIndex.Delete(addIndexKey[:])
×
UNCOV
2403
                        if err != nil {
×
2404
                                return err
×
2405
                        }
×
2406

2407
                        // Remove from the settle index if available and
2408
                        // if the invoice is settled.
UNCOV
2409
                        if settleIndex != nil && ref.SettleIndex > 0 {
×
UNCOV
2410
                                var settleIndexKey [8]byte
×
UNCOV
2411
                                byteOrder.PutUint64(
×
UNCOV
2412
                                        settleIndexKey[:], ref.SettleIndex,
×
UNCOV
2413
                                )
×
UNCOV
2414

×
UNCOV
2415
                                // To ensure consistency check that the already
×
UNCOV
2416
                                // fetched invoice key matches the one in the
×
UNCOV
2417
                                // settle index
×
UNCOV
2418
                                key := settleIndex.Get(settleIndexKey[:])
×
UNCOV
2419
                                if !bytes.Equal(key, invoiceKey) {
×
UNCOV
2420
                                        return fmt.Errorf("unknown invoice " +
×
UNCOV
2421
                                                "in settle index")
×
UNCOV
2422
                                }
×
2423

UNCOV
2424
                                err = settleIndex.Delete(settleIndexKey[:])
×
UNCOV
2425
                                if err != nil {
×
2426
                                        return err
×
2427
                                }
×
2428
                        }
2429

2430
                        // In addition to deleting the main invoice state, if
2431
                        // this is an AMP invoice, then we'll also need to
2432
                        // delete the set HTLC set stored as a key prefix. For
2433
                        // non-AMP invoices, this'll be a noop.
UNCOV
2434
                        err = delAMPSettleIndex(
×
UNCOV
2435
                                invoiceKey, invoices, settleIndex,
×
UNCOV
2436
                        )
×
UNCOV
2437
                        if err != nil {
×
2438
                                return err
×
2439
                        }
×
UNCOV
2440
                        err = delAMPInvoices(invoiceKey, invoices)
×
UNCOV
2441
                        if err != nil {
×
2442
                                return err
×
2443
                        }
×
2444

2445
                        // Finally remove the serialized invoice from the
2446
                        // invoice bucket.
UNCOV
2447
                        err = invoices.Delete(invoiceKey)
×
UNCOV
2448
                        if err != nil {
×
2449
                                return err
×
2450
                        }
×
2451
                }
2452

UNCOV
2453
                return nil
×
UNCOV
2454
        }, func() {})
×
2455

UNCOV
2456
        return err
×
2457
}
2458

2459
// SetInvoiceBucketTombstone sets the tombstone key in the invoice bucket to
2460
// mark the bucket as permanently closed. This prevents it from being reopened
2461
// in the future.
UNCOV
2462
func (d *DB) SetInvoiceBucketTombstone() error {
×
UNCOV
2463
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
×
UNCOV
2464
                // Access the top-level invoice bucket.
×
UNCOV
2465
                invoices := tx.ReadWriteBucket(invoiceBucket)
×
UNCOV
2466
                if invoices == nil {
×
2467
                        return fmt.Errorf("invoice bucket does not exist")
×
2468
                }
×
2469

2470
                // Add the tombstone key to the invoice bucket.
UNCOV
2471
                err := invoices.Put(invoiceBucketTombstone, []byte("1"))
×
UNCOV
2472
                if err != nil {
×
2473
                        return fmt.Errorf("failed to set tombstone: %w", err)
×
2474
                }
×
2475

UNCOV
2476
                return nil
×
UNCOV
2477
        }, func() {})
×
2478
}
2479

2480
// GetInvoiceBucketTombstone checks if the tombstone key exists in the invoice
2481
// bucket. It returns true if the tombstone is present and false otherwise.
2482
func (d *DB) GetInvoiceBucketTombstone() (bool, error) {
3✔
2483
        var tombstoneExists bool
3✔
2484

3✔
2485
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
2486
                // Access the top-level invoice bucket.
3✔
2487
                invoices := tx.ReadBucket(invoiceBucket)
3✔
2488
                if invoices == nil {
3✔
2489
                        return fmt.Errorf("invoice bucket does not exist")
×
2490
                }
×
2491

2492
                // Check if the tombstone key exists.
2493
                tombstone := invoices.Get(invoiceBucketTombstone)
3✔
2494
                tombstoneExists = tombstone != nil
3✔
2495

3✔
2496
                return nil
3✔
2497
        }, func() {})
3✔
2498
        if err != nil {
3✔
2499
                return false, err
×
2500
        }
×
2501

2502
        return tombstoneExists, nil
3✔
2503
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc