• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13153965138

05 Feb 2025 09:13AM UTC coverage: 49.301% (-9.5%) from 58.798%
13153965138

Pull #9477

github

ellemouton
docs: update release notes
Pull Request #9477: discovery+graph: various preparations for moving funding tx validation to the gossiper

0 of 11 new or added lines in 2 files covered. (0.0%)

27263 existing lines in 435 files now uncovered.

100706 of 204266 relevant lines covered (49.3%)

1.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.7
/channeldb/invoices.go
1
package channeldb
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/binary"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "time"
11

12
        "github.com/lightningnetwork/lnd/graph/db/models"
13
        "github.com/lightningnetwork/lnd/htlcswitch/hop"
14
        invpkg "github.com/lightningnetwork/lnd/invoices"
15
        "github.com/lightningnetwork/lnd/kvdb"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/tlv"
20
)
21

22
var (
23
        // invoiceBucket is the name of the bucket within the database that
24
        // stores all data related to invoices no matter their final state.
25
        // Within the invoice bucket, each invoice is keyed by its invoice ID
26
        // which is a monotonically increasing uint32.
27
        invoiceBucket = []byte("invoices")
28

29
        // invoiceIndexBucket is the name of the sub-bucket within the
30
        // invoiceBucket which indexes all invoices by their payment hash. The
31
        // payment hash is the sha256 of the invoice's payment preimage. This
32
        // index is used to detect duplicates, and also to provide a fast path
33
        // for looking up incoming HTLCs to determine if we're able to settle
34
        // them fully.
35
        //
36
        // maps: payHash => invoiceKey
37
        invoiceIndexBucket = []byte("paymenthashes")
38

39
        // payAddrIndexBucket is the name of the top-level bucket that maps
40
        // payment addresses to their invoice number. This can be used
41
        // to efficiently query or update non-legacy invoices. Note that legacy
42
        // invoices will not be included in this index since they all have the
43
        // same, all-zero payment address, however all newly generated invoices
44
        // will end up in this index.
45
        //
46
        // maps: payAddr => invoiceKey
47
        payAddrIndexBucket = []byte("pay-addr-index")
48

49
        // setIDIndexBucket is the name of the top-level bucket that maps set
50
        // ids to their invoice number. This can be used to efficiently query or
51
        // update AMP invoices. Note that legacy or MPP invoices will not be
52
        // included in this index, since their HTLCs do not have a set id.
53
        //
54
        // maps: setID => invoiceKey
55
        setIDIndexBucket = []byte("set-id-index")
56

57
        // numInvoicesKey is the name of key which houses the auto-incrementing
58
        // invoice ID which is essentially used as a primary key. With each
59
        // invoice inserted, the primary key is incremented by one. This key is
60
        // stored within the invoiceIndexBucket. Within the invoiceBucket
61
        // invoices are uniquely identified by the invoice ID.
62
        numInvoicesKey = []byte("nik")
63

64
        // addIndexBucket is an index bucket that we'll use to create a
65
        // monotonically increasing set of add indexes. Each time we add a new
66
        // invoice, this sequence number will be incremented and then populated
67
        // within the new invoice.
68
        //
69
        // In addition to this sequence number, we map:
70
        //
71
        //   addIndexNo => invoiceKey
72
        addIndexBucket = []byte("invoice-add-index")
73

74
        // settleIndexBucket is an index bucket that we'll use to create a
75
        // monotonically increasing integer for tracking a "settle index". Each
76
        // time an invoice is settled, this sequence number will be incremented
77
        // and populated within the newly settled invoice.
78
        //
79
        // In addition to this sequence number, we map:
80
        //
81
        //   settleIndexNo => invoiceKey
82
        settleIndexBucket = []byte("invoice-settle-index")
83
)
84

85
const (
86
        // A set of tlv type definitions used to serialize invoice htlcs to the
87
        // database.
88
        //
89
        // NOTE: A migration should be added whenever this list changes. This
90
        // protects against the database being rolled back to an older
91
        // format where the surrounding logic might assume a different set of
92
        // fields are known.
93
        chanIDType       tlv.Type = 1
94
        htlcIDType       tlv.Type = 3
95
        amtType          tlv.Type = 5
96
        acceptHeightType tlv.Type = 7
97
        acceptTimeType   tlv.Type = 9
98
        resolveTimeType  tlv.Type = 11
99
        expiryHeightType tlv.Type = 13
100
        htlcStateType    tlv.Type = 15
101
        mppTotalAmtType  tlv.Type = 17
102
        htlcAMPType      tlv.Type = 19
103
        htlcHashType     tlv.Type = 21
104
        htlcPreimageType tlv.Type = 23
105

106
        // A set of tlv type definitions used to serialize invoice bodies.
107
        //
108
        // NOTE: A migration should be added whenever this list changes. This
109
        // protects against the database being rolled back to an older
110
        // format where the surrounding logic might assume a different set of
111
        // fields are known.
112
        memoType            tlv.Type = 0
113
        payReqType          tlv.Type = 1
114
        createTimeType      tlv.Type = 2
115
        settleTimeType      tlv.Type = 3
116
        addIndexType        tlv.Type = 4
117
        settleIndexType     tlv.Type = 5
118
        preimageType        tlv.Type = 6
119
        valueType           tlv.Type = 7
120
        cltvDeltaType       tlv.Type = 8
121
        expiryType          tlv.Type = 9
122
        paymentAddrType     tlv.Type = 10
123
        featuresType        tlv.Type = 11
124
        invStateType        tlv.Type = 12
125
        amtPaidType         tlv.Type = 13
126
        hodlInvoiceType     tlv.Type = 14
127
        invoiceAmpStateType tlv.Type = 15
128

129
        // A set of tlv type definitions used to serialize the invoice AMP
130
        // state along-side the main invoice body.
131
        ampStateSetIDType       tlv.Type = 0
132
        ampStateHtlcStateType   tlv.Type = 1
133
        ampStateSettleIndexType tlv.Type = 2
134
        ampStateSettleDateType  tlv.Type = 3
135
        ampStateCircuitKeysType tlv.Type = 4
136
        ampStateAmtPaidType     tlv.Type = 5
137
)
138

139
// AddInvoice inserts the targeted invoice into the database. If the invoice has
140
// *any* payment hashes which already exists within the database, then the
141
// insertion will be aborted and rejected due to the strict policy banning any
142
// duplicate payment hashes. A side effect of this function is that it sets
143
// AddIndex on newInvoice.
144
func (d *DB) AddInvoice(_ context.Context, newInvoice *invpkg.Invoice,
145
        paymentHash lntypes.Hash) (uint64, error) {
146

147
        if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
148
                return 0, err
149
        }
150

151
        var invoiceAddIndex uint64
152
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
3✔
153
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
3✔
154
                if err != nil {
3✔
UNCOV
155
                        return err
×
UNCOV
156
                }
×
157

158
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
3✔
159
                        invoiceIndexBucket,
6✔
160
                )
3✔
161
                if err != nil {
3✔
162
                        return err
×
163
                }
×
164
                addIndex, err := invoices.CreateBucketIfNotExists(
165
                        addIndexBucket,
3✔
166
                )
3✔
167
                if err != nil {
3✔
168
                        return err
3✔
169
                }
×
170

×
171
                // Ensure that an invoice an identical payment hash doesn't
3✔
172
                // already exist within the index.
3✔
173
                if invoiceIndex.Get(paymentHash[:]) != nil {
3✔
174
                        return invpkg.ErrDuplicateInvoice
3✔
175
                }
×
176

×
177
                // Check that we aren't inserting an invoice with a duplicate
178
                // payment address. The all-zeros payment address is
179
                // special-cased to support legacy keysend invoices which don't
180
                // assign one. This is safe since later we also will avoid
3✔
UNCOV
181
                // indexing them and avoid collisions.
×
UNCOV
182
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
×
183
                if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
184
                        paymentAddr := newInvoice.Terms.PaymentAddr[:]
185
                        if payAddrIndex.Get(paymentAddr) != nil {
186
                                return invpkg.ErrDuplicatePayAddr
187
                        }
188
                }
189

3✔
190
                // If the current running payment ID counter hasn't yet been
6✔
191
                // created, then create it now.
3✔
192
                var invoiceNum uint32
6✔
193
                invoiceCounter := invoiceIndex.Get(numInvoicesKey)
3✔
194
                if invoiceCounter == nil {
3✔
195
                        var scratch [4]byte
196
                        byteOrder.PutUint32(scratch[:], invoiceNum)
197
                        err := invoiceIndex.Put(numInvoicesKey, scratch[:])
198
                        if err != nil {
199
                                return err
3✔
200
                        }
3✔
201
                } else {
6✔
202
                        invoiceNum = byteOrder.Uint32(invoiceCounter)
3✔
203
                }
3✔
204

3✔
205
                newIndex, err := putInvoice(
3✔
206
                        invoices, invoiceIndex, payAddrIndex, addIndex,
×
207
                        newInvoice, invoiceNum, paymentHash,
×
208
                )
3✔
209
                if err != nil {
3✔
210
                        return err
3✔
211
                }
212

3✔
213
                invoiceAddIndex = newIndex
3✔
214
                return nil
3✔
215
        }, func() {
3✔
216
                invoiceAddIndex = 0
3✔
217
        })
×
218
        if err != nil {
×
219
                return 0, err
220
        }
3✔
221

3✔
222
        return invoiceAddIndex, err
3✔
223
}
3✔
224

3✔
225
// InvoicesAddedSince can be used by callers to seek into the event time series
6✔
226
// of all the invoices added in the database. The specified sinceAddIndex
3✔
227
// should be the highest add index that the caller knows of. This method will
3✔
228
// return all invoices with an add index greater than the specified
229
// sinceAddIndex.
3✔
230
//
231
// NOTE: The index starts from 1, as a result. We enforce that specifying a
232
// value below the starting index value is a noop.
233
func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
234
        []invpkg.Invoice, error) {
235

236
        var newInvoices []invpkg.Invoice
237

238
        // If an index of zero was specified, then in order to maintain
239
        // backwards compat, we won't send out any new invoices.
240
        if sinceAddIndex == 0 {
241
                return newInvoices, nil
3✔
242
        }
3✔
243

3✔
244
        var startIndex [8]byte
3✔
245
        byteOrder.PutUint64(startIndex[:], sinceAddIndex)
3✔
246

3✔
247
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
248
                invoices := tx.ReadBucket(invoiceBucket)
3✔
249
                if invoices == nil {
3✔
250
                        return nil
251
                }
3✔
252

3✔
253
                addIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
254
                if addIndex == nil {
6✔
255
                        return nil
3✔
256
                }
3✔
257

×
258
                // We'll now run through each entry in the add index starting
×
259
                // at our starting index. We'll continue until we reach the
260
                // very end of the current key space.
3✔
261
                invoiceCursor := addIndex.ReadCursor()
3✔
262

×
263
                // We'll seek to the starting index, then manually advance the
×
264
                // cursor in order to skip the entry with the since add index.
265
                invoiceCursor.Seek(startIndex[:])
266
                addSeqNo, invoiceKey := invoiceCursor.Next()
267

268
                for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
3✔
269

3✔
270
                        // For each key found, we'll look up the actual
3✔
271
                        // invoice, then accumulate it into our return value.
3✔
272
                        invoice, err := fetchInvoice(
3✔
273
                                invoiceKey, invoices, nil, false,
3✔
274
                        )
3✔
275
                        if err != nil {
6✔
276
                                return err
3✔
277
                        }
3✔
278

3✔
279
                        newInvoices = append(newInvoices, invoice)
3✔
280
                }
3✔
281

3✔
282
                return nil
3✔
283
        }, func() {
×
284
                newInvoices = nil
×
285
        })
286
        if err != nil {
3✔
287
                return nil, err
288
        }
289

3✔
290
        return newInvoices, nil
3✔
291
}
3✔
292

3✔
293
// LookupInvoice attempts to look up an invoice according to its 32 byte
3✔
294
// payment hash. If an invoice which can settle the HTLC identified by the
×
295
// passed payment hash isn't found, then an error is returned. Otherwise, the
×
296
// full invoice is returned. Before setting the incoming HTLC, the values
297
// SHOULD be checked to ensure the payer meets the agreed upon contractual
3✔
298
// terms of the payment.
299
func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
300
        invpkg.Invoice, error) {
301

302
        var invoice invpkg.Invoice
303
        err := kvdb.View(d, func(tx kvdb.RTx) error {
304
                invoices := tx.ReadBucket(invoiceBucket)
305
                if invoices == nil {
306
                        return invpkg.ErrNoInvoicesCreated
307
                }
3✔
308
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
3✔
309
                if invoiceIndex == nil {
3✔
310
                        return invpkg.ErrNoInvoicesCreated
6✔
311
                }
3✔
312
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
3✔
313
                setIDIndex := tx.ReadBucket(setIDIndexBucket)
×
314

×
315
                // Retrieve the invoice number for this invoice using
3✔
316
                // the provided invoice reference.
6✔
317
                invoiceNum, err := fetchInvoiceNumByRef(
3✔
318
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
3✔
319
                )
3✔
320
                if err != nil {
3✔
321
                        return err
3✔
322
                }
3✔
323

3✔
324
                var setID *invpkg.SetID
3✔
325
                switch {
3✔
326
                // If this is a payment address ref, and the blank modified was
3✔
327
                // specified, then we'll use the zero set ID to indicate that
6✔
328
                // we won't want any HTLCs returned.
3✔
329
                case ref.PayAddr() != nil &&
3✔
330
                        ref.Modifier() == invpkg.HtlcSetBlankModifier:
331

3✔
332
                        var zeroSetID invpkg.SetID
3✔
333
                        setID = &zeroSetID
334

335
                // If this is a set ID ref, and the htlc set only modified was
336
                // specified, then we'll pass through the specified setID so
337
                // only that will be returned.
3✔
338
                case ref.SetID() != nil &&
3✔
339
                        ref.Modifier() == invpkg.HtlcSetOnlyModifier:
3✔
340

3✔
341
                        setID = (*invpkg.SetID)(ref.SetID())
342
                }
343

344
                // An invoice was found, retrieve the remainder of the invoice
345
                // body.
346
                i, err := fetchInvoice(
3✔
347
                        invoiceNum, invoices, []*invpkg.SetID{setID}, true,
3✔
348
                )
3✔
349
                if err != nil {
350
                        return err
351
                }
352
                invoice = i
353

3✔
354
                return nil
3✔
355
        }, func() {})
3✔
356
        if err != nil {
3✔
357
                return invoice, err
×
358
        }
×
359

3✔
360
        return invoice, nil
3✔
361
}
3✔
362

3✔
363
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
6✔
364
// reference. The payment address will be treated as the primary key, falling
3✔
365
// back to the payment hash if nothing is found for the payment address. An
3✔
366
// error is returned if the invoice is not found.
367
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
3✔
368
        ref invpkg.InvoiceRef) ([]byte, error) {
369

370
        // If the set id is present, we only consult the set id index for this
371
        // invoice. This type of query is only used to facilitate user-facing
372
        // requests to lookup, settle or cancel an AMP invoice.
373
        setID := ref.SetID()
374
        if setID != nil {
375
                invoiceNumBySetID := setIDIndex.Get(setID[:])
3✔
376
                if invoiceNumBySetID == nil {
3✔
377
                        return nil, invpkg.ErrInvoiceNotFound
3✔
378
                }
3✔
379

3✔
380
                return invoiceNumBySetID, nil
3✔
381
        }
6✔
382

3✔
383
        payHash := ref.PayHash()
3✔
UNCOV
384
        payAddr := ref.PayAddr()
×
UNCOV
385

×
386
        getInvoiceNumByHash := func() []byte {
387
                if payHash != nil {
3✔
388
                        return invoiceIndex.Get(payHash[:])
389
                }
390
                return nil
3✔
391
        }
3✔
392

3✔
393
        getInvoiceNumByAddr := func() []byte {
6✔
394
                if payAddr != nil {
6✔
395
                        // Only allow lookups for payment address if it is not a
3✔
396
                        // blank payment address, which is a special-cased value
3✔
397
                        // for legacy keysend invoices.
3✔
398
                        if *payAddr != invpkg.BlankPayAddr {
399
                                return payAddrIndex.Get(payAddr[:])
400
                        }
6✔
401
                }
6✔
402
                return nil
3✔
403
        }
3✔
404

3✔
405
        invoiceNumByHash := getInvoiceNumByHash()
6✔
406
        invoiceNumByAddr := getInvoiceNumByAddr()
3✔
407
        switch {
3✔
408
        // If payment address and payment hash both reference an existing
409
        // invoice, ensure they reference the _same_ invoice.
3✔
410
        case invoiceNumByAddr != nil && invoiceNumByHash != nil:
411
                if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
412
                        return nil, invpkg.ErrInvRefEquivocation
3✔
413
                }
3✔
414

3✔
415
                return invoiceNumByAddr, nil
416

417
        // Return invoices by payment addr only.
3✔
418
        //
3✔
UNCOV
419
        // NOTE: We constrain this lookup to only apply if the invoice ref does
×
UNCOV
420
        // not contain a payment hash. Legacy and MPP payments depend on the
×
421
        // payment hash index to enforce that the HTLCs payment hash matches the
422
        // payment hash for the invoice, without this check we would
3✔
423
        // inadvertently assume the invoice contains the correct preimage for
424
        // the HTLC, which we only enforce via the lookup by the invoice index.
425
        case invoiceNumByAddr != nil && payHash == nil:
426
                return invoiceNumByAddr, nil
427

428
        // If we were only able to reference the invoice by hash, return the
429
        // corresponding invoice number. This can happen when no payment address
430
        // was provided, or if it didn't match anything in our records.
431
        case invoiceNumByHash != nil:
432
                return invoiceNumByHash, nil
3✔
433

3✔
434
        // Otherwise we don't know of the target invoice.
435
        default:
436
                return nil, invpkg.ErrInvoiceNotFound
437
        }
438
}
3✔
439

3✔
440
// FetchPendingInvoices returns all invoices that have not yet been settled or
441
// canceled. The returned map is keyed by the payment hash of each respective
442
// invoice.
3✔
443
func (d *DB) FetchPendingInvoices(_ context.Context) (
3✔
444
        map[lntypes.Hash]invpkg.Invoice, error) {
445

446
        result := make(map[lntypes.Hash]invpkg.Invoice)
447

448
        err := kvdb.View(d, func(tx kvdb.RTx) error {
449
                invoices := tx.ReadBucket(invoiceBucket)
450
                if invoices == nil {
451
                        return nil
3✔
452
                }
3✔
453

3✔
454
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
3✔
455
                if invoiceIndex == nil {
6✔
456
                        // Mask the error if there's no invoice
3✔
457
                        // index as that simply means there are no
3✔
458
                        // invoices added yet to the DB. In this case
×
459
                        // we simply return an empty list.
×
460
                        return nil
461
                }
3✔
462

6✔
463
                return invoiceIndex.ForEach(func(k, v []byte) error {
3✔
464
                        // Skip the special numInvoicesKey as that does not
3✔
465
                        // point to a valid invoice.
3✔
466
                        if bytes.Equal(k, numInvoicesKey) {
3✔
467
                                return nil
3✔
468
                        }
3✔
469

470
                        // Skip sub-buckets.
6✔
471
                        if v == nil {
3✔
472
                                return nil
3✔
473
                        }
6✔
474

3✔
475
                        invoice, err := fetchInvoice(v, invoices, nil, false)
3✔
476
                        if err != nil {
477
                                return err
478
                        }
3✔
479

×
480
                        if invoice.IsPending() {
×
481
                                var paymentHash lntypes.Hash
482
                                copy(paymentHash[:], k)
3✔
483
                                result[paymentHash] = invoice
3✔
484
                        }
×
485

×
486
                        return nil
487
                })
6✔
488
        }, func() {
3✔
489
                result = make(map[lntypes.Hash]invpkg.Invoice)
3✔
490
        })
3✔
491

3✔
492
        if err != nil {
493
                return nil, err
3✔
494
        }
495

3✔
496
        return result, nil
3✔
497
}
3✔
498

499
// QueryInvoices allows a caller to query the invoice database for invoices
3✔
500
// within the specified add index range.
×
501
func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
×
502
        invpkg.InvoiceSlice, error) {
503

3✔
504
        var resp invpkg.InvoiceSlice
505

506
        err := kvdb.View(d, func(tx kvdb.RTx) error {
507
                // If the bucket wasn't found, then there aren't any invoices
508
                // within the database yet, so we can simply exit.
509
                invoices := tx.ReadBucket(invoiceBucket)
3✔
510
                if invoices == nil {
3✔
511
                        return invpkg.ErrNoInvoicesCreated
3✔
512
                }
3✔
513

6✔
514
                // Get the add index bucket which we will use to iterate through
3✔
515
                // our indexed invoices.
3✔
516
                invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
517
                if invoiceAddIndex == nil {
3✔
518
                        return invpkg.ErrNoInvoicesCreated
×
519
                }
×
520

521
                // Create a paginator which reads from our add index bucket with
522
                // the parameters provided by the invoice query.
523
                paginator := newPaginator(
3✔
524
                        invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
6✔
525
                        q.NumMaxInvoices,
3✔
526
                )
3✔
527

528
                // accumulateInvoices looks up an invoice based on the index we
529
                // are given, adds it to our set of invoices if it has the right
530
                // characteristics for our query and returns the number of items
3✔
531
                // we have added to our set of invoices.
3✔
532
                accumulateInvoices := func(_, indexValue []byte) (bool, error) {
3✔
533
                        invoice, err := fetchInvoice(
3✔
534
                                indexValue, invoices, nil, false,
3✔
535
                        )
3✔
536
                        if err != nil {
3✔
537
                                return false, err
3✔
538
                        }
3✔
539

6✔
540
                        // Skip any settled or canceled invoices if the caller
3✔
541
                        // is only interested in pending ones.
3✔
542
                        if q.PendingOnly && !invoice.IsPending() {
3✔
543
                                return false, nil
3✔
544
                        }
×
545

×
546
                        // Get the creation time in Unix seconds, this always
547
                        // rounds down the nanoseconds to full seconds.
548
                        createTime := invoice.CreationDate.Unix()
549

3✔
UNCOV
550
                        // Skip any invoices that were created before the
×
UNCOV
551
                        // specified time.
×
552
                        if createTime < q.CreationDateStart {
553
                                return false, nil
554
                        }
555

3✔
556
                        // Skip any invoices that were created after the
3✔
557
                        // specified time.
3✔
558
                        if q.CreationDateEnd != 0 &&
3✔
559
                                createTime > q.CreationDateEnd {
6✔
560

3✔
561
                                return false, nil
3✔
562
                        }
563

564
                        // At this point, we've exhausted the offset, so we'll
565
                        // begin collecting invoices found within the range.
3✔
566
                        resp.Invoices = append(resp.Invoices, invoice)
6✔
567

3✔
568
                        return true, nil
3✔
569
                }
3✔
570

571
                // Query our paginator using accumulateInvoices to build up a
572
                // set of invoices.
573
                if err := paginator.query(accumulateInvoices); err != nil {
3✔
574
                        return err
3✔
575
                }
3✔
576

577
                // If we iterated through the add index in reverse order, then
578
                // we'll need to reverse the slice of invoices to return them in
579
                // forward order.
580
                if q.Reversed {
3✔
581
                        numInvoices := len(resp.Invoices)
×
582
                        for i := 0; i < numInvoices/2; i++ {
×
583
                                reverse := numInvoices - i - 1
584
                                resp.Invoices[i], resp.Invoices[reverse] =
585
                                        resp.Invoices[reverse], resp.Invoices[i]
586
                        }
587
                }
3✔
UNCOV
588

×
UNCOV
589
                return nil
×
UNCOV
590
        }, func() {
×
UNCOV
591
                resp = invpkg.InvoiceSlice{
×
UNCOV
592
                        InvoiceQuery: q,
×
UNCOV
593
                }
×
594
        })
595
        if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) {
596
                return resp, err
3✔
597
        }
3✔
598

3✔
599
        // Finally, record the indexes of the first and last invoices returned
3✔
600
        // so that the caller can resume from this point later on.
3✔
601
        if len(resp.Invoices) > 0 {
3✔
602
                resp.FirstIndexOffset = resp.Invoices[0].AddIndex
3✔
603
                lastIdx := len(resp.Invoices) - 1
×
604
                resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex
×
605
        }
606

607
        return resp, nil
608
}
6✔
609

3✔
610
// UpdateInvoice attempts to update an invoice corresponding to the passed
3✔
611
// payment hash. If an invoice matching the passed payment hash doesn't exist
3✔
612
// within the database, then the action will fail with a "not found" error.
3✔
613
//
614
// The update is performed inside the same database transaction that fetches the
3✔
615
// invoice and is therefore atomic. The fields to update are controlled by the
616
// supplied callback.  When updating an invoice, the update itself happens
617
// in-memory on a copy of the invoice. Once it is written successfully to the
618
// database, the in-memory copy is returned to the caller.
619
func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
620
        setIDHint *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
621
        *invpkg.Invoice, error) {
622

623
        var updatedInvoice *invpkg.Invoice
624
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
625
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
626
                if err != nil {
627
                        return err
628
                }
3✔
629
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
3✔
630
                        invoiceIndexBucket,
3✔
631
                )
6✔
632
                if err != nil {
3✔
633
                        return err
3✔
634
                }
×
635
                settleIndex, err := invoices.CreateBucketIfNotExists(
×
636
                        settleIndexBucket,
3✔
637
                )
3✔
638
                if err != nil {
3✔
639
                        return err
3✔
640
                }
×
641
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
×
642
                setIDIndex := tx.ReadWriteBucket(setIDIndexBucket)
3✔
643

3✔
644
                // Retrieve the invoice number for this invoice using the
3✔
645
                // provided invoice reference.
3✔
646
                invoiceNum, err := fetchInvoiceNumByRef(
×
647
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
×
648
                )
3✔
649
                if err != nil {
3✔
650
                        return err
3✔
651
                }
3✔
652

3✔
653
                // setIDHint can also be nil here, which means all the HTLCs
3✔
654
                // for AMP invoices are fetched. If the blank setID is passed
3✔
655
                // in, then no HTLCs are fetched for the AMP invoice. If a
3✔
656
                // specific setID is passed in, then only the HTLCs for that
3✔
UNCOV
657
                // setID are fetched for a particular sub-AMP invoice.
×
UNCOV
658
                invoice, err := fetchInvoice(
×
659
                        invoiceNum, invoices, []*invpkg.SetID{setIDHint}, false,
660
                )
661
                if err != nil {
662
                        return err
663
                }
664

665
                now := d.clock.Now()
3✔
666
                updater := &kvInvoiceUpdater{
3✔
667
                        db:                d,
3✔
668
                        invoicesBucket:    invoices,
3✔
669
                        settleIndexBucket: settleIndex,
×
670
                        setIDIndexBucket:  setIDIndex,
×
671
                        updateTime:        now,
672
                        invoiceNum:        invoiceNum,
3✔
673
                        invoice:           &invoice,
3✔
674
                        updatedAmpHtlcs:   make(ampHTLCsMap),
3✔
675
                        settledSetIDs:     make(map[invpkg.SetID]struct{}),
3✔
676
                }
3✔
677

3✔
678
                payHash := ref.PayHash()
3✔
679
                updatedInvoice, err = invpkg.UpdateInvoice(
3✔
680
                        payHash, updater.invoice, now, callback, updater,
3✔
681
                )
3✔
682
                if err != nil {
3✔
683
                        return err
3✔
684
                }
3✔
685

3✔
686
                // If this is an AMP update, then limit the returned AMP state
3✔
687
                // to only the requested set ID.
3✔
688
                if setIDHint != nil {
3✔
689
                        filterInvoiceAMPState(updatedInvoice, setIDHint)
6✔
690
                }
3✔
691

3✔
692
                return nil
693
        }, func() {
694
                updatedInvoice = nil
695
        })
6✔
696

3✔
697
        return updatedInvoice, err
3✔
698
}
699

3✔
700
// filterInvoiceAMPState filters the AMP state of the invoice to only include
3✔
701
// state for the specified set IDs.
3✔
702
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
3✔
703
        filteredAMPState := make(invpkg.AMPInvoiceState)
704

3✔
705
        for _, setID := range setIDs {
706
                if setID == nil {
707
                        return
708
                }
709

3✔
710
                ampState, ok := invoice.AMPState[*setID]
3✔
711
                if ok {
3✔
712
                        filteredAMPState[*setID] = ampState
6✔
713
                }
6✔
714
        }
3✔
715

3✔
716
        invoice.AMPState = filteredAMPState
717
}
3✔
718

6✔
719
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
3✔
720
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
3✔
721

722
// kvInvoiceUpdater is an implementation of the InvoiceUpdater interface that
723
// is used with the kv implementation of the invoice database. Note that this
3✔
724
// updater is not concurrency safe and synchronizaton is expected to be handled
725
// on the DB level.
726
type kvInvoiceUpdater struct {
727
        db                *DB
728
        invoicesBucket    kvdb.RwBucket
729
        settleIndexBucket kvdb.RwBucket
730
        setIDIndexBucket  kvdb.RwBucket
731

732
        // updateTime is the timestamp for the update.
733
        updateTime time.Time
734

735
        // invoiceNum is a legacy key similar to the add index that is used
736
        // only in the kv implementation.
737
        invoiceNum []byte
738

739
        // invoice is the invoice that we're updating. As a side effect of the
740
        // update this invoice will be mutated.
741
        invoice *invpkg.Invoice
742

743
        // updatedAmpHtlcs holds the set of AMP HTLCs that were added or
744
        // cancelled as part of this update.
745
        updatedAmpHtlcs ampHTLCsMap
746

747
        // settledSetIDs holds the set IDs that are settled with this update.
748
        settledSetIDs map[invpkg.SetID]struct{}
749
}
750

751
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
752
func (k *kvInvoiceUpdater) AddHtlc(_ models.CircuitKey,
753
        _ *invpkg.InvoiceHTLC) error {
754

755
        return nil
756
}
757

758
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
759
func (k *kvInvoiceUpdater) ResolveHtlc(_ models.CircuitKey, _ invpkg.HtlcState,
760
        _ time.Time) error {
3✔
761

3✔
762
        return nil
3✔
763
}
3✔
764

765
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
766
func (k *kvInvoiceUpdater) AddAmpHtlcPreimage(_ [32]byte, _ models.CircuitKey,
767
        _ lntypes.Preimage) error {
3✔
768

3✔
769
        return nil
3✔
770
}
3✔
771

772
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
773
func (k *kvInvoiceUpdater) UpdateInvoiceState(_ invpkg.ContractState,
774
        _ *lntypes.Preimage) error {
3✔
775

3✔
776
        return nil
3✔
777
}
3✔
778

779
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
780
func (k *kvInvoiceUpdater) UpdateInvoiceAmtPaid(_ lnwire.MilliSatoshi) error {
781
        return nil
3✔
782
}
3✔
783

3✔
784
// UpdateAmpState updates the state of the AMP invoice identified by the setID.
3✔
785
func (k *kvInvoiceUpdater) UpdateAmpState(setID [32]byte,
786
        state invpkg.InvoiceStateAMP, circuitKey models.CircuitKey) error {
787

3✔
788
        if _, ok := k.updatedAmpHtlcs[setID]; !ok {
3✔
789
                switch state.State {
3✔
790
                case invpkg.HtlcStateAccepted:
791
                        // If we're just now creating the HTLCs for this set
792
                        // then we'll also pull in the existing HTLCs that are
793
                        // part of this set, so we can write them all to disk
3✔
794
                        // together (same value)
3✔
795
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
6✔
796
                                &setID, invpkg.HtlcStateAccepted,
3✔
797
                        )
3✔
798

3✔
799
                case invpkg.HtlcStateCanceled:
3✔
800
                        // Only HTLCs in the accepted state, can be cancelled,
3✔
801
                        // but we also want to merge that with HTLCs that may be
3✔
802
                        // canceled as well since it can be cancelled one by
3✔
803
                        // one.
3✔
804
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
3✔
805
                                &setID, invpkg.HtlcStateAccepted,
UNCOV
806
                        )
×
UNCOV
807

×
UNCOV
808
                        cancelledHtlcs := k.invoice.HTLCSet(
×
UNCOV
809
                                &setID, invpkg.HtlcStateCanceled,
×
UNCOV
810
                        )
×
UNCOV
811
                        for htlcKey, htlc := range cancelledHtlcs {
×
UNCOV
812
                                k.updatedAmpHtlcs[setID][htlcKey] = htlc
×
UNCOV
813
                        }
×
UNCOV
814

×
UNCOV
815
                case invpkg.HtlcStateSettled:
×
UNCOV
816
                        k.updatedAmpHtlcs[setID] = make(
×
UNCOV
817
                                map[models.CircuitKey]*invpkg.InvoiceHTLC,
×
UNCOV
818
                        )
×
UNCOV
819
                }
×
UNCOV
820
        }
×
821

UNCOV
822
        if state.State == invpkg.HtlcStateSettled {
×
UNCOV
823
                // Add the set ID to the set that was settled in this invoice
×
UNCOV
824
                // update. We'll use this later to update the settle index.
×
UNCOV
825
                k.settledSetIDs[setID] = struct{}{}
×
826
        }
827

828
        k.updatedAmpHtlcs[setID][circuitKey] = k.invoice.Htlcs[circuitKey]
829

6✔
830
        return nil
3✔
831
}
3✔
832

3✔
833
// Finalize finalizes the update before it is written to the database.
3✔
834
func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error {
835
        switch updateType {
3✔
836
        case invpkg.AddHTLCsUpdate:
3✔
837
                return k.storeAddHtlcsUpdate()
3✔
838

839
        case invpkg.CancelHTLCsUpdate:
840
                return k.storeCancelHtlcsUpdate()
841

3✔
842
        case invpkg.SettleHodlInvoiceUpdate:
3✔
843
                return k.storeSettleHodlInvoiceUpdate()
3✔
844

3✔
845
        case invpkg.CancelInvoiceUpdate:
846
                // Persist all changes which where made when cancelling the
3✔
847
                // invoice. All HTLCs which were accepted are now canceled, so
3✔
848
                // we persist this state.
849
                return k.storeCancelHtlcsUpdate()
3✔
850
        }
3✔
851

852
        return fmt.Errorf("unknown update type: %v", updateType)
3✔
853
}
3✔
854

3✔
855
// storeCancelHtlcsUpdate updates the invoice in the database after cancelling a
3✔
856
// set of HTLCs.
3✔
857
func (k *kvInvoiceUpdater) storeCancelHtlcsUpdate() error {
858
        err := k.serializeAndStoreInvoice()
859
        if err != nil {
×
860
                return err
861
        }
862

863
        // If this is an AMP invoice, then we'll actually store the rest
864
        // of the HTLCs in-line with the invoice, using the invoice ID
3✔
865
        // as a prefix, and the AMP key as a suffix: invoiceNum ||
3✔
866
        // setID.
3✔
867
        if k.invoice.IsAMP() {
×
868
                return k.updateAMPInvoices()
×
869
        }
870

871
        return nil
872
}
873

874
// storeAddHtlcsUpdate updates the invoice in the database after adding a set of
3✔
UNCOV
875
// HTLCs.
×
UNCOV
876
func (k *kvInvoiceUpdater) storeAddHtlcsUpdate() error {
×
877
        invoiceIsAMP := k.invoice.IsAMP()
878

3✔
879
        for htlcSetID := range k.updatedAmpHtlcs {
880
                // Check if this SetID already exist.
881
                setIDInvNum := k.setIDIndexBucket.Get(htlcSetID[:])
882

883
                if setIDInvNum == nil {
3✔
884
                        err := k.setIDIndexBucket.Put(
3✔
885
                                htlcSetID[:], k.invoiceNum,
3✔
886
                        )
6✔
887
                        if err != nil {
3✔
888
                                return err
3✔
889
                        }
3✔
890
                } else if !bytes.Equal(setIDInvNum, k.invoiceNum) {
6✔
891
                        return invpkg.ErrDuplicateSetID{
3✔
892
                                SetID: htlcSetID,
3✔
893
                        }
3✔
894
                }
3✔
895
        }
×
896

×
897
        // If this is a non-AMP invoice, then the state can eventually go to
3✔
UNCOV
898
        // ContractSettled, so we pass in nil value as part of
×
UNCOV
899
        // setSettleMetaFields.
×
UNCOV
900
        if !invoiceIsAMP && k.invoice.State == invpkg.ContractSettled {
×
UNCOV
901
                err := k.setSettleMetaFields(nil)
×
902
                if err != nil {
903
                        return err
904
                }
905
        }
906

907
        // As we don't update the settle index above for AMP invoices, we'll do
6✔
908
        // it here for each sub-AMP invoice that was settled.
3✔
909
        for settledSetID := range k.settledSetIDs {
3✔
910
                settledSetID := settledSetID
×
911
                err := k.setSettleMetaFields(&settledSetID)
×
912
                if err != nil {
913
                        return err
914
                }
915
        }
916

6✔
917
        err := k.serializeAndStoreInvoice()
3✔
918
        if err != nil {
3✔
919
                return err
3✔
920
        }
×
921

×
922
        // If this is an AMP invoice, then we'll actually store the rest of the
923
        // HTLCs in-line with the invoice, using the invoice ID as a prefix,
924
        // and the AMP key as a suffix: invoiceNum || setID.
3✔
925
        if invoiceIsAMP {
3✔
926
                return k.updateAMPInvoices()
×
927
        }
×
928

929
        return nil
930
}
931

932
// storeSettleHodlInvoiceUpdate updates the invoice in the database after
6✔
933
// settling a hodl invoice.
3✔
934
func (k *kvInvoiceUpdater) storeSettleHodlInvoiceUpdate() error {
3✔
935
        err := k.setSettleMetaFields(nil)
936
        if err != nil {
3✔
937
                return err
938
        }
939

940
        return k.serializeAndStoreInvoice()
941
}
3✔
942

3✔
943
// setSettleMetaFields updates the metadata associated with settlement of an
3✔
944
// invoice. If a non-nil setID is passed in, then the value will be append to
×
945
// the invoice number as well, in order to allow us to detect repeated payments
×
946
// to the same AMP invoices "across time".
947
func (k *kvInvoiceUpdater) setSettleMetaFields(setID *invpkg.SetID) error {
3✔
948
        // Now that we know the invoice hasn't already been settled, we'll
949
        // update the settle index so we can place this settle event in the
950
        // proper location within our time series.
951
        nextSettleSeqNo, err := k.settleIndexBucket.NextSequence()
952
        if err != nil {
953
                return err
954
        }
3✔
955

3✔
956
        // Make a new byte array on the stack that can potentially store the 4
3✔
957
        // byte invoice number along w/ the 32 byte set ID. We capture valueLen
3✔
958
        // here which is the number of bytes copied so we can only store the 4
3✔
959
        // bytes if this is a non-AMP invoice.
3✔
960
        var indexKey [invoiceSetIDKeyLen]byte
×
961
        valueLen := copy(indexKey[:], k.invoiceNum)
×
962

963
        if setID != nil {
964
                valueLen += copy(indexKey[valueLen:], setID[:])
965
        }
966

967
        var seqNoBytes [8]byte
3✔
968
        byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
3✔
969
        err = k.settleIndexBucket.Put(seqNoBytes[:], indexKey[:valueLen])
3✔
970
        if err != nil {
6✔
971
                return err
3✔
972
        }
3✔
973

974
        // If the setID is nil, then this means that this is a non-AMP settle,
3✔
975
        // so we'll update the invoice settle index directly.
3✔
976
        if setID == nil {
3✔
977
                k.invoice.SettleDate = k.updateTime
3✔
978
                k.invoice.SettleIndex = nextSettleSeqNo
×
979
        } else {
×
980
                // If the set ID isn't blank, we'll update the AMP state map
981
                // which tracks when each of the setIDs associated with a given
982
                // AMP invoice are settled.
983
                ampState := k.invoice.AMPState[*setID]
6✔
984

3✔
985
                ampState.SettleDate = k.updateTime
3✔
986
                ampState.SettleIndex = nextSettleSeqNo
6✔
987

3✔
988
                k.invoice.AMPState[*setID] = ampState
3✔
989
        }
3✔
990

3✔
991
        return nil
3✔
992
}
3✔
993

3✔
994
// updateAMPInvoices updates the set of AMP invoices in-place. For AMP, rather
3✔
995
// then continually write the invoices to the end of the invoice value, we
3✔
996
// instead write the invoices into a new key preifx that follows the main
3✔
997
// invoice number. This ensures that we don't need to continually decode a
998
// potentially massive HTLC set, and also allows us to quickly find the HLTCs
3✔
999
// associated with a particular HTLC set.
1000
func (k *kvInvoiceUpdater) updateAMPInvoices() error {
1001
        for setID, htlcSet := range k.updatedAmpHtlcs {
1002
                // First write out the set of HTLCs including all the relevant
1003
                // TLV values.
1004
                var b bytes.Buffer
1005
                if err := serializeHtlcs(&b, htlcSet); err != nil {
1006
                        return err
1007
                }
3✔
1008

6✔
1009
                // Next store each HTLC in-line, using a prefix based off the
3✔
1010
                // invoice number.
3✔
1011
                invoiceSetIDKey := makeInvoiceSetIDKey(k.invoiceNum, setID[:])
3✔
1012

3✔
1013
                err := k.invoicesBucket.Put(invoiceSetIDKey[:], b.Bytes())
×
1014
                if err != nil {
×
1015
                        return err
1016
                }
1017
        }
1018

3✔
1019
        return nil
3✔
1020
}
3✔
1021

3✔
1022
// serializeAndStoreInvoice is a helper function used to store invoices.
×
1023
func (k *kvInvoiceUpdater) serializeAndStoreInvoice() error {
×
1024
        var buf bytes.Buffer
1025
        if err := serializeInvoice(&buf, k.invoice); err != nil {
1026
                return err
3✔
1027
        }
1028

1029
        return k.invoicesBucket.Put(k.invoiceNum, buf.Bytes())
1030
}
3✔
1031

3✔
1032
// InvoicesSettledSince can be used by callers to catch up any settled invoices
3✔
1033
// they missed within the settled invoice time series. We'll return all known
×
1034
// settled invoice that have a settle index higher than the passed
×
1035
// sinceSettleIndex.
1036
//
3✔
1037
// NOTE: The index starts from 1, as a result. We enforce that specifying a
1038
// value below the starting index value is a noop.
1039
func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
1040
        []invpkg.Invoice, error) {
1041

1042
        var settledInvoices []invpkg.Invoice
1043

1044
        // If an index of zero was specified, then in order to maintain
1045
        // backwards compat, we won't send out any new invoices.
1046
        if sinceSettleIndex == 0 {
1047
                return settledInvoices, nil
3✔
1048
        }
3✔
1049

3✔
1050
        var startIndex [8]byte
3✔
1051
        byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
3✔
1052

3✔
1053
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
1054
                invoices := tx.ReadBucket(invoiceBucket)
3✔
1055
                if invoices == nil {
3✔
1056
                        return nil
1057
                }
3✔
1058

3✔
1059
                settleIndex := invoices.NestedReadBucket(settleIndexBucket)
3✔
1060
                if settleIndex == nil {
6✔
1061
                        return nil
3✔
1062
                }
3✔
1063

×
1064
                // We'll now run through each entry in the add index starting
×
1065
                // at our starting index. We'll continue until we reach the
1066
                // very end of the current key space.
3✔
1067
                invoiceCursor := settleIndex.ReadCursor()
3✔
1068

×
1069
                // We'll seek to the starting index, then manually advance the
×
1070
                // cursor in order to skip the entry with the since add index.
1071
                invoiceCursor.Seek(startIndex[:])
1072
                seqNo, indexValue := invoiceCursor.Next()
1073

1074
                for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, indexValue = invoiceCursor.Next() {
3✔
1075
                        // Depending on the length of the index value, this may
3✔
1076
                        // or may not be an AMP invoice, so we'll extract the
3✔
1077
                        // invoice value into two components: the invoice num,
3✔
1078
                        // and the setID (may not be there).
3✔
1079
                        var (
3✔
1080
                                invoiceKey [4]byte
3✔
1081
                                setID      *invpkg.SetID
6✔
1082
                        )
3✔
1083

3✔
1084
                        valueLen := copy(invoiceKey[:], indexValue)
3✔
1085
                        if len(indexValue) == invoiceSetIDKeyLen {
3✔
1086
                                setID = new(invpkg.SetID)
3✔
1087
                                copy(setID[:], indexValue[valueLen:])
3✔
1088
                        }
3✔
1089

3✔
1090
                        // For each key found, we'll look up the actual
3✔
1091
                        // invoice, then accumulate it into our return value.
3✔
1092
                        invoice, err := fetchInvoice(
6✔
1093
                                invoiceKey[:], invoices, []*invpkg.SetID{setID},
3✔
1094
                                true,
3✔
1095
                        )
3✔
1096
                        if err != nil {
1097
                                return err
1098
                        }
1099

3✔
1100
                        settledInvoices = append(settledInvoices, invoice)
3✔
1101
                }
3✔
1102

3✔
1103
                return nil
3✔
1104
        }, func() {
×
1105
                settledInvoices = nil
×
1106
        })
1107
        if err != nil {
3✔
1108
                return nil, err
1109
        }
1110

3✔
1111
        return settledInvoices, nil
3✔
1112
}
3✔
1113

3✔
1114
func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
3✔
1115
        i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
×
1116
        uint64, error) {
×
1117

1118
        // Create the invoice key which is just the big-endian representation
3✔
1119
        // of the invoice number.
1120
        var invoiceKey [4]byte
1121
        byteOrder.PutUint32(invoiceKey[:], invoiceNum)
1122

1123
        // Increment the num invoice counter index so the next invoice bares
3✔
1124
        // the proper ID.
3✔
1125
        var scratch [4]byte
3✔
1126
        invoiceCounter := invoiceNum + 1
3✔
1127
        byteOrder.PutUint32(scratch[:], invoiceCounter)
3✔
1128
        if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
3✔
1129
                return 0, err
3✔
1130
        }
3✔
1131

3✔
1132
        // Add the payment hash to the invoice index. This will let us quickly
3✔
1133
        // identify if we can settle an incoming payment, and also to possibly
3✔
1134
        // allow a single invoice to have multiple payment installations.
3✔
1135
        err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
3✔
1136
        if err != nil {
×
1137
                return 0, err
×
1138
        }
1139

1140
        // Add the invoice to the payment address index, but only if the invoice
1141
        // has a non-zero payment address. The all-zero payment address is still
1142
        // in use by legacy keysend, so we special-case here to avoid
3✔
1143
        // collisions.
3✔
1144
        if i.Terms.PaymentAddr != invpkg.BlankPayAddr {
×
1145
                err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
×
1146
                if err != nil {
1147
                        return 0, err
1148
                }
1149
        }
1150

1151
        // Next, we'll obtain the next add invoice index (sequence
6✔
1152
        // number), so we can properly place this invoice within this
3✔
1153
        // event stream.
3✔
1154
        nextAddSeqNo, err := addIndex.NextSequence()
×
1155
        if err != nil {
×
1156
                return 0, err
1157
        }
1158

1159
        // With the next sequence obtained, we'll updating the event series in
1160
        // the add index bucket to map this current add counter to the index of
1161
        // this new invoice.
3✔
1162
        var seqNoBytes [8]byte
3✔
1163
        byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
×
1164
        if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
×
1165
                return 0, err
1166
        }
1167

1168
        i.AddIndex = nextAddSeqNo
1169

3✔
1170
        // Finally, serialize the invoice itself to be written to the disk.
3✔
1171
        var buf bytes.Buffer
3✔
1172
        if err := serializeInvoice(&buf, i); err != nil {
×
1173
                return 0, err
×
1174
        }
1175

3✔
1176
        if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
3✔
1177
                return 0, err
3✔
1178
        }
3✔
1179

3✔
1180
        return nextAddSeqNo, nil
×
1181
}
×
1182

1183
// recordSize returns the amount of bytes this TLV record will occupy when
3✔
1184
// encoded.
×
1185
func ampRecordSize(a *invpkg.AMPInvoiceState) func() uint64 {
×
1186
        var (
1187
                b   bytes.Buffer
3✔
1188
                buf [8]byte
1189
        )
1190

1191
        // We know that encoding works since the tests pass in the build this
1192
        // file is checked into, so we'll simplify things and simply encode it
3✔
1193
        // ourselves then report the total amount of bytes used.
3✔
1194
        if err := ampStateEncoder(&b, a, &buf); err != nil {
3✔
1195
                // This should never error out, but we log it just in case it
3✔
1196
                // does.
3✔
1197
                log.Errorf("encoding the amp invoice state failed: %v", err)
3✔
1198
        }
3✔
1199

3✔
1200
        return func() uint64 {
3✔
1201
                return uint64(len(b.Bytes()))
3✔
1202
        }
×
1203
}
×
1204

×
1205
// serializeInvoice serializes an invoice to a writer.
×
1206
//
1207
// Note: this function is in use for a migration. Before making changes that
6✔
1208
// would modify the on disk format, make a copy of the original code and store
3✔
1209
// it with the migration.
3✔
1210
func serializeInvoice(w io.Writer, i *invpkg.Invoice) error {
1211
        creationDateBytes, err := i.CreationDate.MarshalBinary()
1212
        if err != nil {
1213
                return err
1214
        }
1215

1216
        settleDateBytes, err := i.SettleDate.MarshalBinary()
1217
        if err != nil {
3✔
1218
                return err
3✔
1219
        }
3✔
1220

×
1221
        var fb bytes.Buffer
×
1222
        err = i.Terms.Features.EncodeBase256(&fb)
1223
        if err != nil {
3✔
1224
                return err
3✔
1225
        }
×
1226
        featureBytes := fb.Bytes()
×
1227

1228
        preimage := [32]byte(invpkg.UnknownPreimage)
3✔
1229
        if i.Terms.PaymentPreimage != nil {
3✔
1230
                preimage = *i.Terms.PaymentPreimage
3✔
1231
                if preimage == invpkg.UnknownPreimage {
×
1232
                        return errors.New("cannot use all-zeroes preimage")
×
1233
                }
3✔
1234
        }
3✔
1235
        value := uint64(i.Terms.Value)
3✔
1236
        cltvDelta := uint32(i.Terms.FinalCltvDelta)
6✔
1237
        expiry := uint64(i.Terms.Expiry)
3✔
1238

3✔
1239
        amtPaid := uint64(i.AmtPaid)
×
1240
        state := uint8(i.State)
×
1241

1242
        var hodlInvoice uint8
3✔
1243
        if i.HodlInvoice {
3✔
1244
                hodlInvoice = 1
3✔
1245
        }
3✔
1246

3✔
1247
        tlvStream, err := tlv.NewStream(
3✔
1248
                // Memo and payreq.
3✔
1249
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
3✔
1250
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
6✔
1251

3✔
1252
                // Add/settle metadata.
3✔
1253
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
1254
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
3✔
1255
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
3✔
1256
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
3✔
1257

3✔
1258
                // Terms.
3✔
1259
                tlv.MakePrimitiveRecord(preimageType, &preimage),
3✔
1260
                tlv.MakePrimitiveRecord(valueType, &value),
3✔
1261
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
3✔
1262
                tlv.MakePrimitiveRecord(expiryType, &expiry),
3✔
1263
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
3✔
1264
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
3✔
1265

3✔
1266
                // Invoice state.
3✔
1267
                tlv.MakePrimitiveRecord(invStateType, &state),
3✔
1268
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
3✔
1269

3✔
1270
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
3✔
1271

3✔
1272
                // Invoice AMP state.
3✔
1273
                tlv.MakeDynamicRecord(
3✔
1274
                        invoiceAmpStateType, &i.AMPState,
3✔
1275
                        ampRecordSize(&i.AMPState),
3✔
1276
                        ampStateEncoder, ampStateDecoder,
3✔
1277
                ),
3✔
1278
        )
3✔
1279
        if err != nil {
3✔
1280
                return err
3✔
1281
        }
3✔
1282

3✔
1283
        var b bytes.Buffer
3✔
1284
        if err = tlvStream.Encode(&b); err != nil {
3✔
1285
                return err
3✔
1286
        }
3✔
1287

×
1288
        err = binary.Write(w, byteOrder, uint64(b.Len()))
×
1289
        if err != nil {
1290
                return err
3✔
1291
        }
3✔
1292

×
1293
        if _, err = w.Write(b.Bytes()); err != nil {
×
1294
                return err
1295
        }
3✔
1296

3✔
1297
        // Only if this is a _non_ AMP invoice do we serialize the HTLCs
×
1298
        // in-line with the rest of the invoice.
×
1299
        if i.IsAMP() {
1300
                return nil
3✔
1301
        }
×
1302

×
1303
        return serializeHtlcs(w, i.Htlcs)
1304
}
1305

1306
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
6✔
1307
// a writer.
3✔
1308
func serializeHtlcs(w io.Writer,
3✔
1309
        htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error {
1310

3✔
1311
        for key, htlc := range htlcs {
1312
                // Encode the htlc in a tlv stream.
1313
                chanID := key.ChanID.ToUint64()
1314
                amt := uint64(htlc.Amt)
1315
                mppTotalAmt := uint64(htlc.MppTotalAmt)
1316
                acceptTime := putNanoTime(htlc.AcceptTime)
3✔
1317
                resolveTime := putNanoTime(htlc.ResolveTime)
3✔
1318
                state := uint8(htlc.State)
6✔
1319

3✔
1320
                var records []tlv.Record
3✔
1321
                records = append(records,
3✔
1322
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
3✔
1323
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
3✔
1324
                        tlv.MakePrimitiveRecord(amtType, &amt),
3✔
1325
                        tlv.MakePrimitiveRecord(
3✔
1326
                                acceptHeightType, &htlc.AcceptHeight,
3✔
1327
                        ),
3✔
1328
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
3✔
1329
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
3✔
1330
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
3✔
1331
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
3✔
1332
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
3✔
1333
                )
3✔
1334

3✔
1335
                if htlc.AMP != nil {
3✔
1336
                        setIDRecord := tlv.MakeDynamicRecord(
3✔
1337
                                htlcAMPType, &htlc.AMP.Record,
3✔
1338
                                htlc.AMP.Record.PayloadSize,
3✔
1339
                                record.AMPEncoder, record.AMPDecoder,
3✔
1340
                        )
3✔
1341
                        records = append(records, setIDRecord)
3✔
1342

6✔
1343
                        hash32 := [32]byte(htlc.AMP.Hash)
3✔
1344
                        hashRecord := tlv.MakePrimitiveRecord(
3✔
1345
                                htlcHashType, &hash32,
3✔
1346
                        )
3✔
1347
                        records = append(records, hashRecord)
3✔
1348

3✔
1349
                        if htlc.AMP.Preimage != nil {
3✔
1350
                                preimage32 := [32]byte(*htlc.AMP.Preimage)
3✔
1351
                                preimageRecord := tlv.MakePrimitiveRecord(
3✔
1352
                                        htlcPreimageType, &preimage32,
3✔
1353
                                )
3✔
1354
                                records = append(records, preimageRecord)
3✔
1355
                        }
3✔
1356
                }
6✔
1357

3✔
1358
                // Convert the custom records to tlv.Record types that are ready
3✔
1359
                // for serialization.
3✔
1360
                customRecords := tlv.MapToRecords(htlc.CustomRecords)
3✔
1361

3✔
1362
                // Append the custom records. Their ids are in the experimental
3✔
1363
                // range and sorted, so there is no need to sort again.
1364
                records = append(records, customRecords...)
1365

1366
                tlvStream, err := tlv.NewStream(records...)
1367
                if err != nil {
3✔
1368
                        return err
3✔
1369
                }
3✔
1370

3✔
1371
                var b bytes.Buffer
3✔
1372
                if err := tlvStream.Encode(&b); err != nil {
3✔
1373
                        return err
3✔
1374
                }
3✔
1375

×
1376
                // Write the length of the tlv stream followed by the stream
×
1377
                // bytes.
1378
                err = binary.Write(w, byteOrder, uint64(b.Len()))
3✔
1379
                if err != nil {
3✔
1380
                        return err
×
1381
                }
×
1382

1383
                if _, err := w.Write(b.Bytes()); err != nil {
1384
                        return err
1385
                }
3✔
1386
        }
3✔
1387

×
1388
        return nil
×
1389
}
1390

3✔
1391
// putNanoTime converts the passed timestamp into unix nano seconds. The zero
// value of time.Time is mapped to 0, as the result of calling UnixNano on it
// is undefined.
func putNanoTime(t time.Time) uint64 {
	if !t.IsZero() {
		return uint64(t.UnixNano())
	}

	return 0
}
1400

1401
// getNanoTime converts the passed number of unix nano seconds into a
// timestamp. A value of zero is mapped to the zero value of time.Time.
func getNanoTime(ns uint64) time.Time {
	var ts time.Time
	if ns != 0 {
		ts = time.Unix(0, int64(ns))
	}

	return ts
}
1409

1410
// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices
3✔
1411
// identified by the setID value.
6✔
1412
func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
3✔
1413
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
3✔
1414
        error) {
3✔
1415

1416
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
1417
        for _, setID := range setIDs {
1418
                invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:])
1419

1420
                htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:])
1421
                if htlcSetBytes == nil {
3✔
1422
                        // A set ID was passed in, but we don't have this
3✔
1423
                        // stored yet, meaning that the setID is being added
3✔
1424
                        // for the first time.
6✔
1425
                        return htlcs, invpkg.ErrInvoiceNotFound
3✔
1426
                }
3✔
1427

3✔
1428
                htlcSetReader := bytes.NewReader(htlcSetBytes)
6✔
1429
                htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
3✔
1430
                if err != nil {
3✔
1431
                        return nil, err
3✔
1432
                }
3✔
1433

3✔
1434
                for key, htlc := range htlcsBySetID {
1435
                        htlcs[key] = htlc
3✔
1436
                }
3✔
1437
        }
3✔
1438

×
1439
        return htlcs, nil
×
1440
}
1441

6✔
1442
// forEachAMPInvoice is a helper function that attempts to iterate over each of
3✔
1443
// the HTLC sets (based on their set ID) for the given AMP invoice identified
3✔
1444
// by its invoiceNum. The callback closure is called for each key within the
1445
// prefix range.
1446
func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte,
3✔
1447
        callback func(key, htlcSet []byte) error) error {
1448

1449
        invoiceCursor := invoiceBucket.ReadCursor()
1450

1451
        // Seek to the first key that includes the invoice data itself.
1452
        invoiceCursor.Seek(invoiceNum)
1453

1454
        // Advance to the very first key _after_ the invoice data, as this is
3✔
1455
        // where we'll encounter our first HTLC (if any are present).
3✔
1456
        cursorKey, htlcSet := invoiceCursor.Next()
3✔
1457

3✔
1458
        // If at this point, the cursor key doesn't match the invoice num
3✔
1459
        // prefix, then we know that this HTLC doesn't have any set ID HTLCs
3✔
1460
        // associated with it.
3✔
1461
        if !bytes.HasPrefix(cursorKey, invoiceNum) {
3✔
1462
                return nil
3✔
1463
        }
3✔
1464

3✔
1465
        // Otherwise continue to iterate until we no longer match the prefix,
3✔
1466
        // executing the call back at each step.
3✔
1467
        for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() {
3✔
1468
                err := callback(cursorKey, htlcSet)
6✔
1469
                if err != nil {
3✔
1470
                        return err
3✔
1471
                }
1472
        }
1473

1474
        return nil
6✔
1475
}
3✔
1476

3✔
1477
// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix  within the
×
1478
// AMP bucket to find all the individual HTLCs (by setID) associated with a
×
1479
// given invoice. If a list of set IDs are specified, then only HTLCs
1480
// associated with that setID will be retrieved.
1481
func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
3✔
1482
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1483
        error) {
1484

1485
        // If a set of setIDs was specified, then we can skip the cursor and
1486
        // just read out exactly what we need.
1487
        if len(setIDs) != 0 && setIDs[0] != nil {
1488
                return fetchFilteredAmpInvoices(
1489
                        invoiceBucket, invoiceNum, setIDs...,
1490
                )
3✔
1491
        }
3✔
1492

3✔
1493
        // Otherwise, iterate over all the htlc sets that are prefixed beside
3✔
1494
        // this invoice in the main invoice bucket.
6✔
1495
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
1496
        err := forEachAMPInvoice(invoiceBucket, invoiceNum,
3✔
1497
                func(key, htlcSet []byte) error {
3✔
1498
                        htlcSetReader := bytes.NewReader(htlcSet)
3✔
1499
                        htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
1500
                        if err != nil {
1501
                                return err
1502
                        }
3✔
1503

3✔
1504
                        for key, htlc := range htlcsBySetID {
6✔
1505
                                htlcs[key] = htlc
3✔
1506
                        }
3✔
1507

3✔
1508
                        return nil
×
1509
                },
×
1510
        )
1511

6✔
1512
        if err != nil {
3✔
1513
                return nil, err
3✔
1514
        }
1515

3✔
1516
        return htlcs, nil
1517
}
1518

1519
// fetchInvoice attempts to read out the relevant state for the invoice as
3✔
1520
// specified by the invoice number. If the setID fields are set, then only the
×
1521
// HTLC information pertaining to those set IDs is returned.
×
1522
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1523
        setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
3✔
1524

1525
        invoiceBytes := invoices.Get(invoiceNum)
1526
        if invoiceBytes == nil {
1527
                return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound
1528
        }
1529

1530
        invoiceReader := bytes.NewReader(invoiceBytes)
3✔
1531

3✔
1532
        invoice, err := deserializeInvoice(invoiceReader)
3✔
1533
        if err != nil {
3✔
1534
                return invpkg.Invoice{}, err
×
1535
        }
×
1536

1537
        // If this is an AMP invoice we'll also attempt to read out the set of
3✔
1538
        // HTLCs that were paid to prior set IDs, if needed.
3✔
1539
        if !invoice.IsAMP() {
3✔
1540
                return invoice, nil
3✔
1541
        }
×
1542

×
1543
        if shouldFetchAMPHTLCs(invoice, setIDs) {
1544
                invoice.Htlcs, err = fetchAmpSubInvoices(
1545
                        invoices, invoiceNum, setIDs...,
1546
                )
6✔
1547
                // TODO(positiveblue): we should fail when we are not able to
3✔
1548
                // fetch all the HTLCs for an AMP invoice. Multiple tests in
3✔
1549
                // the invoice and channeldb package break if we return this
1550
                // error. We need to update them when we migrate this logic to
6✔
1551
                // the sql implementation.
3✔
1552
                if err != nil {
3✔
1553
                        log.Errorf("unable to fetch amp htlcs for inv "+
3✔
1554
                                "%v and setIDs %v: %w", invoiceNum, setIDs, err)
3✔
1555
                }
3✔
1556

3✔
1557
                if filterAMPState {
3✔
1558
                        filterInvoiceAMPState(&invoice, setIDs...)
3✔
1559
                }
6✔
1560
        }
3✔
1561

3✔
1562
        return invoice, nil
3✔
1563
}
1564

6✔
1565
// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that
3✔
1566
// were paid to the relevant set IDs.
3✔
1567
func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool {
1568
        // For AMP invoice that already have HTLCs populated (created before
1569
        // recurring invoices), then we don't need to read from the prefix
3✔
1570
        // keyed section of the bucket.
1571
        if len(invoice.Htlcs) != 0 {
1572
                return false
1573
        }
1574

3✔
1575
        // If the "zero" setID was specified, then this means that no HTLC data
3✔
1576
        // should be returned alongside of it.
3✔
1577
        if len(setIDs) != 0 && setIDs[0] != nil &&
3✔
1578
                *setIDs[0] == invpkg.BlankPayAddr {
3✔
1579

×
1580
                return false
×
1581
        }
1582

1583
        return true
1584
}
3✔
1585

6✔
1586
// fetchInvoiceStateAMP retrieves the state of all the relevant sub-invoice for
3✔
1587
// an AMP invoice. This methods only decode the relevant state vs the entire
3✔
1588
// invoice.
3✔
1589
func fetchInvoiceStateAMP(invoiceNum []byte,
1590
        invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) {
3✔
1591

1592
        // Fetch the raw invoice bytes.
1593
        invoiceBytes := invoices.Get(invoiceNum)
1594
        if invoiceBytes == nil {
1595
                return nil, invpkg.ErrInvoiceNotFound
1596
        }
UNCOV
1597

×
UNCOV
1598
        r := bytes.NewReader(invoiceBytes)
×
UNCOV
1599

×
UNCOV
1600
        var bodyLen int64
×
UNCOV
1601
        err := binary.Read(r, byteOrder, &bodyLen)
×
1602
        if err != nil {
×
1603
                return nil, err
×
1604
        }
UNCOV
1605

×
UNCOV
1606
        // Next, we'll make a new TLV stream that only attempts to decode the
×
UNCOV
1607
        // bytes we actually need.
×
UNCOV
1608
        ampState := make(invpkg.AMPInvoiceState)
×
UNCOV
1609
        tlvStream, err := tlv.NewStream(
×
1610
                // Invoice AMP state.
×
1611
                tlv.MakeDynamicRecord(
×
1612
                        invoiceAmpStateType, &ampState, nil,
1613
                        ampStateEncoder, ampStateDecoder,
1614
                ),
UNCOV
1615
        )
×
UNCOV
1616
        if err != nil {
×
UNCOV
1617
                return nil, err
×
UNCOV
1618
        }
×
UNCOV
1619

×
UNCOV
1620
        invoiceReader := io.LimitReader(r, bodyLen)
×
UNCOV
1621
        if err = tlvStream.Decode(invoiceReader); err != nil {
×
UNCOV
1622
                return nil, err
×
UNCOV
1623
        }
×
1624

×
1625
        return ampState, nil
×
1626
}
UNCOV
1627

×
UNCOV
1628
func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) {
×
1629
        var (
×
1630
                preimageBytes [32]byte
×
1631
                value         uint64
UNCOV
1632
                cltvDelta     uint32
×
1633
                expiry        uint64
1634
                amtPaid       uint64
1635
                state         uint8
3✔
1636
                hodlInvoice   uint8
3✔
1637

3✔
1638
                creationDateBytes []byte
3✔
1639
                settleDateBytes   []byte
3✔
1640
                featureBytes      []byte
3✔
1641
        )
3✔
1642

3✔
1643
        var i invpkg.Invoice
3✔
1644
        i.AMPState = make(invpkg.AMPInvoiceState)
3✔
1645
        tlvStream, err := tlv.NewStream(
3✔
1646
                // Memo and payreq.
3✔
1647
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
3✔
1648
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
3✔
1649

3✔
1650
                // Add/settle metadata.
3✔
1651
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
3✔
1652
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
3✔
1653
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
3✔
1654
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
3✔
1655

3✔
1656
                // Terms.
3✔
1657
                tlv.MakePrimitiveRecord(preimageType, &preimageBytes),
3✔
1658
                tlv.MakePrimitiveRecord(valueType, &value),
3✔
1659
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
3✔
1660
                tlv.MakePrimitiveRecord(expiryType, &expiry),
3✔
1661
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
3✔
1662
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
3✔
1663

3✔
1664
                // Invoice state.
3✔
1665
                tlv.MakePrimitiveRecord(invStateType, &state),
3✔
1666
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
3✔
1667

3✔
1668
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
3✔
1669

3✔
1670
                // Invoice AMP state.
3✔
1671
                tlv.MakeDynamicRecord(
3✔
1672
                        invoiceAmpStateType, &i.AMPState, nil,
3✔
1673
                        ampStateEncoder, ampStateDecoder,
3✔
1674
                ),
3✔
1675
        )
3✔
1676
        if err != nil {
3✔
1677
                return i, err
3✔
1678
        }
3✔
1679

3✔
1680
        var bodyLen int64
3✔
1681
        err = binary.Read(r, byteOrder, &bodyLen)
3✔
1682
        if err != nil {
3✔
1683
                return i, err
3✔
1684
        }
×
1685

×
1686
        lr := io.LimitReader(r, bodyLen)
1687
        if err = tlvStream.Decode(lr); err != nil {
3✔
1688
                return i, err
3✔
1689
        }
3✔
1690

×
1691
        preimage := lntypes.Preimage(preimageBytes)
×
1692
        if preimage != invpkg.UnknownPreimage {
1693
                i.Terms.PaymentPreimage = &preimage
3✔
1694
        }
3✔
1695

×
1696
        i.Terms.Value = lnwire.MilliSatoshi(value)
×
1697
        i.Terms.FinalCltvDelta = int32(cltvDelta)
1698
        i.Terms.Expiry = time.Duration(expiry)
3✔
1699
        i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
6✔
1700
        i.State = invpkg.ContractState(state)
3✔
1701

3✔
1702
        if hodlInvoice != 0 {
1703
                i.HodlInvoice = true
3✔
1704
        }
3✔
1705

3✔
1706
        err = i.CreationDate.UnmarshalBinary(creationDateBytes)
3✔
1707
        if err != nil {
3✔
1708
                return i, err
3✔
1709
        }
6✔
1710

3✔
1711
        err = i.SettleDate.UnmarshalBinary(settleDateBytes)
3✔
1712
        if err != nil {
1713
                return i, err
3✔
1714
        }
3✔
1715

×
1716
        rawFeatures := lnwire.NewRawFeatureVector()
×
1717
        err = rawFeatures.DecodeBase256(
1718
                bytes.NewReader(featureBytes), len(featureBytes),
3✔
1719
        )
3✔
1720
        if err != nil {
×
1721
                return i, err
×
1722
        }
1723

3✔
1724
        i.Terms.Features = lnwire.NewFeatureVector(
3✔
1725
                rawFeatures, lnwire.Features,
3✔
1726
        )
3✔
1727

3✔
1728
        i.Htlcs, err = deserializeHtlcs(r)
×
1729
        return i, err
×
1730
}
1731

3✔
1732
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
3✔
1733
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
3✔
1734
                // We encode the set of circuit keys as a varint length prefix.
3✔
1735
                // followed by a series of fixed sized uint8 integers.
3✔
1736
                numKeys := uint64(len(*v))
3✔
1737

1738
                if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
1739
                        return err
3✔
1740
                }
6✔
1741

3✔
1742
                for key := range *v {
3✔
1743
                        scidInt := key.ChanID.ToUint64()
3✔
1744

3✔
1745
                        if err := tlv.EUint64(w, &scidInt, buf); err != nil {
3✔
1746
                                return err
×
1747
                        }
×
1748
                        if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
1749
                                return err
6✔
1750
                        }
3✔
1751
                }
3✔
1752

3✔
1753
                return nil
×
1754
        }
×
1755

3✔
1756
        return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
×
1757
}
×
1758

1759
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte,
1760
        l uint64) error {
3✔
1761

1762
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
1763
                // First, we'll read out the varint that encodes the number of
×
1764
                // circuit keys encoded.
1765
                numKeys, err := tlv.ReadVarInt(r, buf)
1766
                if err != nil {
1767
                        return err
3✔
1768
                }
3✔
1769

6✔
1770
                // Now that we know how many keys to expect, iterate reading
3✔
1771
                // each one until we're done.
3✔
1772
                for i := uint64(0); i < numKeys; i++ {
3✔
1773
                        var (
3✔
1774
                                key  models.CircuitKey
×
1775
                                scid uint64
×
1776
                        )
1777

1778
                        if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
1779
                                return err
6✔
1780
                        }
3✔
1781

3✔
1782
                        key.ChanID = lnwire.NewShortChanIDFromInt(scid)
3✔
1783

3✔
1784
                        err := tlv.DUint64(r, &key.HtlcID, buf, 8)
3✔
1785
                        if err != nil {
3✔
1786
                                return err
×
1787
                        }
×
1788

1789
                        (*v)[key] = struct{}{}
3✔
1790
                }
3✔
1791

3✔
1792
                return nil
3✔
1793
        }
×
1794

×
1795
        return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l)
1796
}
3✔
1797

1798
// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record.
1799
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
3✔
1800
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
1801
                // We'll encode the AMP state as a series of KV pairs on the
1802
                // wire with a length prefix.
×
1803
                numRecords := uint64(len(*v))
1804

1805
                // First, we'll write out the number of records as a var int.
1806
                if err := tlv.WriteVarInt(w, numRecords, buf); err != nil {
3✔
1807
                        return err
6✔
1808
                }
3✔
1809

3✔
1810
                // With that written out, we'll now encode the entries
3✔
1811
                // themselves as a sub-TLV record, which includes its _own_
3✔
1812
                // inner length prefix.
3✔
1813
                for setID, ampState := range *v {
3✔
1814
                        setID := [32]byte(setID)
×
1815
                        ampState := ampState
×
1816

1817
                        htlcState := uint8(ampState.State)
1818
                        settleDate := ampState.SettleDate
1819
                        settleDateBytes, err := settleDate.MarshalBinary()
1820
                        if err != nil {
6✔
1821
                                return err
3✔
1822
                        }
3✔
1823

3✔
1824
                        amtPaid := uint64(ampState.AmtPaid)
3✔
1825

3✔
1826
                        var ampStateTlvBytes bytes.Buffer
3✔
1827
                        tlvStream, err := tlv.NewStream(
3✔
1828
                                tlv.MakePrimitiveRecord(
×
1829
                                        ampStateSetIDType, &setID,
×
1830
                                ),
1831
                                tlv.MakePrimitiveRecord(
3✔
1832
                                        ampStateHtlcStateType, &htlcState,
3✔
1833
                                ),
3✔
1834
                                tlv.MakePrimitiveRecord(
3✔
1835
                                        ampStateSettleIndexType,
3✔
1836
                                        &ampState.SettleIndex,
3✔
1837
                                ),
3✔
1838
                                tlv.MakePrimitiveRecord(
3✔
1839
                                        ampStateSettleDateType,
3✔
1840
                                        &settleDateBytes,
3✔
1841
                                ),
3✔
1842
                                tlv.MakeDynamicRecord(
3✔
1843
                                        ampStateCircuitKeysType,
3✔
1844
                                        &ampState.InvoiceKeys,
3✔
1845
                                        func() uint64 {
3✔
1846
                                                // The record takes 8 bytes to
3✔
1847
                                                // encode the set of circuits,
3✔
1848
                                                // 8 bytes for the scid for the
3✔
1849
                                                // key, and 8 bytes for the HTLC
3✔
1850
                                                // index.
3✔
1851
                                                keys := ampState.InvoiceKeys
3✔
1852
                                                numKeys := uint64(len(keys))
6✔
1853
                                                size := tlv.VarIntSize(numKeys)
3✔
1854
                                                dataSize := (numKeys * 16)
3✔
1855

3✔
1856
                                                return size + dataSize
3✔
1857
                                        },
3✔
1858
                                        encodeCircuitKeys, decodeCircuitKeys,
3✔
1859
                                ),
3✔
1860
                                tlv.MakePrimitiveRecord(
3✔
1861
                                        ampStateAmtPaidType, &amtPaid,
3✔
1862
                                ),
3✔
1863
                        )
3✔
1864
                        if err != nil {
3✔
1865
                                return err
1866
                        }
1867

1868
                        err = tlvStream.Encode(&ampStateTlvBytes)
1869
                        if err != nil {
1870
                                return err
1871
                        }
3✔
1872

×
1873
                        // We encode the record with a varint length followed by
×
1874
                        // the _raw_ TLV bytes.
1875
                        tlvLen := uint64(len(ampStateTlvBytes.Bytes()))
3✔
1876
                        if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
3✔
1877
                                return err
×
1878
                        }
×
1879

1880
                        _, err = w.Write(ampStateTlvBytes.Bytes())
1881
                        if err != nil {
1882
                                return err
3✔
1883
                        }
3✔
1884
                }
×
1885

×
1886
                return nil
1887
        }
3✔
1888

3✔
1889
        return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
×
1890
}
×
1891

1892
// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record.
1893
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte,
3✔
1894
        l uint64) error {
1895

1896
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
×
1897
                // First, we'll decode the varint that encodes how many set IDs
1898
                // are encoded within the greater map.
1899
                numRecords, err := tlv.ReadVarInt(r, buf)
1900
                if err != nil {
1901
                        return err
3✔
1902
                }
3✔
1903

6✔
1904
                // Now that we know how many records we'll need to read, we can
3✔
1905
                // iterate and read them all out in series.
3✔
1906
                for i := uint64(0); i < numRecords; i++ {
3✔
1907
                        // Read out the varint that encodes the size of this
3✔
1908
                        // inner TLV record.
×
1909
                        stateRecordSize, err := tlv.ReadVarInt(r, buf)
×
1910
                        if err != nil {
1911
                                return err
1912
                        }
1913

6✔
1914
                        // Using this information, we'll create a new limited
3✔
1915
                        // reader that'll return an EOF once the end has been
3✔
1916
                        // reached so the stream stops consuming bytes.
3✔
1917
                        innerTlvReader := io.LimitedReader{
3✔
1918
                                R: r,
×
1919
                                N: int64(stateRecordSize),
×
1920
                        }
1921

1922
                        var (
1923
                                setID           [32]byte
1924
                                htlcState       uint8
3✔
1925
                                settleIndex     uint64
3✔
1926
                                settleDateBytes []byte
3✔
1927
                                invoiceKeys     = make(
3✔
1928
                                        map[models.CircuitKey]struct{},
3✔
1929
                                )
3✔
1930
                                amtPaid uint64
3✔
1931
                        )
3✔
1932
                        tlvStream, err := tlv.NewStream(
3✔
1933
                                tlv.MakePrimitiveRecord(
3✔
1934
                                        ampStateSetIDType, &setID,
3✔
1935
                                ),
3✔
1936
                                tlv.MakePrimitiveRecord(
3✔
1937
                                        ampStateHtlcStateType, &htlcState,
3✔
1938
                                ),
3✔
1939
                                tlv.MakePrimitiveRecord(
3✔
1940
                                        ampStateSettleIndexType, &settleIndex,
3✔
1941
                                ),
3✔
1942
                                tlv.MakePrimitiveRecord(
3✔
1943
                                        ampStateSettleDateType,
3✔
1944
                                        &settleDateBytes,
3✔
1945
                                ),
3✔
1946
                                tlv.MakeDynamicRecord(
3✔
1947
                                        ampStateCircuitKeysType,
3✔
1948
                                        &invoiceKeys, nil,
3✔
1949
                                        encodeCircuitKeys, decodeCircuitKeys,
3✔
1950
                                ),
3✔
1951
                                tlv.MakePrimitiveRecord(
3✔
1952
                                        ampStateAmtPaidType, &amtPaid,
3✔
1953
                                ),
3✔
1954
                        )
3✔
1955
                        if err != nil {
3✔
1956
                                return err
3✔
1957
                        }
3✔
1958

3✔
1959
                        err = tlvStream.Decode(&innerTlvReader)
3✔
1960
                        if err != nil {
3✔
1961
                                return err
3✔
1962
                        }
3✔
1963

×
1964
                        var settleDate time.Time
×
1965
                        err = settleDate.UnmarshalBinary(settleDateBytes)
1966
                        if err != nil {
3✔
1967
                                return err
3✔
1968
                        }
×
1969

×
1970
                        (*v)[setID] = invpkg.InvoiceStateAMP{
1971
                                State:       invpkg.HtlcState(htlcState),
3✔
1972
                                SettleIndex: settleIndex,
3✔
1973
                                SettleDate:  settleDate,
3✔
1974
                                InvoiceKeys: invoiceKeys,
×
1975
                                AmtPaid:     lnwire.MilliSatoshi(amtPaid),
×
1976
                        }
1977
                }
3✔
1978

3✔
1979
                return nil
3✔
1980
        }
3✔
1981

3✔
1982
        return tlv.NewTypeForDecodingErr(
3✔
1983
                val, "channeldb.AMPInvoiceState", l, l,
3✔
1984
        )
1985
}
1986

3✔
1987
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
1988
// as a map.
1989
func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
×
1990
        error) {
×
1991

×
1992
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
1993
        for {
1994
                // Read the length of the tlv stream for this htlc.
1995
                var streamLen int64
1996
                if err := binary.Read(r, byteOrder, &streamLen); err != nil {
1997
                        if err == io.EOF {
3✔
1998
                                break
3✔
1999
                        }
3✔
2000

6✔
2001
                        return nil, err
3✔
2002
                }
3✔
2003

6✔
2004
                // Limit the reader so that it stops at the end of this htlc's
6✔
2005
                // stream.
3✔
2006
                htlcReader := io.LimitReader(r, streamLen)
2007

2008
                // Decode the contents into the htlc fields.
×
2009
                var (
2010
                        htlc                    invpkg.InvoiceHTLC
2011
                        key                     models.CircuitKey
2012
                        chanID                  uint64
2013
                        state                   uint8
3✔
2014
                        acceptTime, resolveTime uint64
3✔
2015
                        amt, mppTotalAmt        uint64
3✔
2016
                        amp                     = &record.AMP{}
3✔
2017
                        hash32                  = &[32]byte{}
3✔
2018
                        preimage32              = &[32]byte{}
3✔
2019
                )
3✔
2020
                tlvStream, err := tlv.NewStream(
3✔
2021
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
3✔
2022
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
3✔
2023
                        tlv.MakePrimitiveRecord(amtType, &amt),
3✔
2024
                        tlv.MakePrimitiveRecord(
3✔
2025
                                acceptHeightType, &htlc.AcceptHeight,
3✔
2026
                        ),
3✔
2027
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
3✔
2028
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
3✔
2029
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
3✔
2030
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
3✔
2031
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
3✔
2032
                        tlv.MakeDynamicRecord(
3✔
2033
                                htlcAMPType, amp, amp.PayloadSize,
3✔
2034
                                record.AMPEncoder, record.AMPDecoder,
3✔
2035
                        ),
3✔
2036
                        tlv.MakePrimitiveRecord(htlcHashType, hash32),
3✔
2037
                        tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
3✔
2038
                )
3✔
2039
                if err != nil {
3✔
2040
                        return nil, err
3✔
2041
                }
3✔
2042

3✔
2043
                parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
3✔
2044
                if err != nil {
3✔
2045
                        return nil, err
3✔
2046
                }
3✔
2047

×
2048
                if _, ok := parsedTypes[htlcAMPType]; !ok {
×
2049
                        amp = nil
2050
                }
3✔
2051

3✔
2052
                var preimage *lntypes.Preimage
×
2053
                if _, ok := parsedTypes[htlcPreimageType]; ok {
×
2054
                        pimg := lntypes.Preimage(*preimage32)
2055
                        preimage = &pimg
6✔
2056
                }
3✔
2057

3✔
2058
                var hash *lntypes.Hash
2059
                if _, ok := parsedTypes[htlcHashType]; ok {
3✔
2060
                        h := lntypes.Hash(*hash32)
6✔
2061
                        hash = &h
3✔
2062
                }
3✔
2063

3✔
2064
                key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
2065
                htlc.AcceptTime = getNanoTime(acceptTime)
3✔
2066
                htlc.ResolveTime = getNanoTime(resolveTime)
6✔
2067
                htlc.State = invpkg.HtlcState(state)
3✔
2068
                htlc.Amt = lnwire.MilliSatoshi(amt)
3✔
2069
                htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
3✔
2070
                if amp != nil && hash != nil {
2071
                        htlc.AMP = &invpkg.InvoiceHtlcAMPData{
3✔
2072
                                Record:   *amp,
3✔
2073
                                Hash:     *hash,
3✔
2074
                                Preimage: preimage,
3✔
2075
                        }
3✔
2076
                }
3✔
2077

6✔
2078
                // Reconstruct the custom records fields from the parsed types
3✔
2079
                // map return from the tlv parser.
3✔
2080
                htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
3✔
2081

3✔
2082
                htlcs[key] = &htlc
3✔
2083
        }
3✔
2084

2085
        return htlcs, nil
2086
}
2087

3✔
2088
// invoiceSetIDKeyLen is the length of the key that's used to store the
// individual HTLCs prefixed by their ID along side the main invoice within the
// invoice bucket. We use 4 bytes for the invoice number, and 32 bytes for the
// set ID.
const invoiceSetIDKeyLen = 4 + 32

// makeInvoiceSetIDKey returns the prefix key, based on the set ID and invoice
// number where the HTLCs for this setID will be stored under. The key is the
// concatenation invoiceNum || setID; bytes that don't fit within
// invoiceSetIDKeyLen are dropped and unused trailing bytes are left as zero.
func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte {
	// Construct the prefix key we need to obtain the invoice information:
	// invoiceNum || setID.
	var invoiceSetIDKey [invoiceSetIDKeyLen]byte

	// Use the number of bytes actually copied as the offset of the set ID
	// so that an over-long invoiceNum can never produce an out-of-range
	// slice expression.
	n := copy(invoiceSetIDKey[:], invoiceNum)
	copy(invoiceSetIDKey[n:], setID)

	return invoiceSetIDKey
}
3✔
2105

3✔
2106
// delAMPInvoices attempts to delete all the "sub" invoices associated with a
3✔
2107
// greater AMP invoices. We do this by deleting the set of keys that share the
3✔
2108
// invoice number as a prefix.
3✔
2109
func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error {
3✔
2110
        // Since it isn't safe to delete using an active cursor, we'll use the
3✔
2111
        // cursor simply to collect the set of keys we need to delete, _then_
3✔
2112
        // delete them in another pass.
2113
        var keysToDel [][]byte
2114
        err := forEachAMPInvoice(
2115
                invoiceBucket, invoiceNum,
UNCOV
2116
                func(cursorKey, v []byte) error {
×
UNCOV
2117
                        keysToDel = append(keysToDel, cursorKey)
×
UNCOV
2118
                        return nil
×
UNCOV
2119
                },
×
UNCOV
2120
        )
×
UNCOV
2121
        if err != nil {
×
UNCOV
2122
                return err
×
UNCOV
2123
        }
×
UNCOV
2124

×
UNCOV
2125
        // In this next phase, we'll then delete all the relevant invoices.
×
UNCOV
2126
        for _, keyToDel := range keysToDel {
×
2127
                if err := invoiceBucket.Delete(keyToDel); err != nil {
UNCOV
2128
                        return err
×
2129
                }
×
2130
        }
×
2131

2132
        return nil
UNCOV
2133
}
×
UNCOV
2134

×
2135
// delAMPSettleIndex removes all the entries in the settle index associated
×
2136
// with a given AMP invoice.
×
2137
func delAMPSettleIndex(invoiceNum []byte, invoices,
2138
        settleIndex kvdb.RwBucket) error {
UNCOV
2139

×
2140
        // First, we need to grab the AMP invoice state to see if there's
2141
        // anything that we even need to delete.
2142
        ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices)
2143
        if err != nil {
2144
                return err
UNCOV
2145
        }
×
UNCOV
2146

×
UNCOV
2147
        // If there's no AMP state at all (non-AMP invoice), then we can return
×
UNCOV
2148
        // early.
×
UNCOV
2149
        if len(ampState) == 0 {
×
UNCOV
2150
                return nil
×
2151
        }
×
2152

×
2153
        // Otherwise, we'll need to iterate and delete each settle index within
2154
        // the set of returned entries.
2155
        var settleIndexKey [8]byte
UNCOV
2156
        for _, subState := range ampState {
×
UNCOV
2157
                byteOrder.PutUint64(
×
UNCOV
2158
                        settleIndexKey[:], subState.SettleIndex,
×
2159
                )
2160

2161
                if err := settleIndex.Delete(settleIndexKey[:]); err != nil {
UNCOV
2162
                        return err
×
UNCOV
2163
                }
×
UNCOV
2164
        }
×
UNCOV
2165

×
UNCOV
2166
        return nil
×
UNCOV
2167
}
×
UNCOV
2168

×
2169
// DeleteCanceledInvoices deletes all canceled invoices from the database.
×
2170
func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
×
2171
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
2172
                invoices := tx.ReadWriteBucket(invoiceBucket)
UNCOV
2173
                if invoices == nil {
×
2174
                        return nil
2175
                }
2176

UNCOV
2177
                invoiceIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2178
                        invoiceIndexBucket,
×
UNCOV
2179
                )
×
UNCOV
2180
                if invoiceIndex == nil {
×
2181
                        return nil
×
2182
                }
×
2183

UNCOV
2184
                invoiceAddIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2185
                        addIndexBucket,
×
UNCOV
2186
                )
×
UNCOV
2187
                if invoiceAddIndex == nil {
×
UNCOV
2188
                        return nil
×
UNCOV
2189
                }
×
2190

UNCOV
2191
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
×
UNCOV
2192

×
UNCOV
2193
                return invoiceIndex.ForEach(func(k, v []byte) error {
×
UNCOV
2194
                        // Skip the special numInvoicesKey as that does not
×
2195
                        // point to a valid invoice.
×
2196
                        if bytes.Equal(k, numInvoicesKey) {
×
2197
                                return nil
UNCOV
2198
                        }
×
UNCOV
2199

×
UNCOV
2200
                        // Skip sub-buckets.
×
UNCOV
2201
                        if v == nil {
×
UNCOV
2202
                                return nil
×
UNCOV
2203
                        }
×
UNCOV
2204

×
UNCOV
2205
                        invoice, err := fetchInvoice(v, invoices, nil, false)
×
2206
                        if err != nil {
2207
                                return err
UNCOV
2208
                        }
×
2209

×
2210
                        if invoice.State != invpkg.ContractCanceled {
×
2211
                                return nil
UNCOV
2212
                        }
×
UNCOV
2213

×
2214
                        // Delete the payment hash from the invoice index.
×
2215
                        err = invoiceIndex.Delete(k)
×
2216
                        if err != nil {
UNCOV
2217
                                return err
×
UNCOV
2218
                        }
×
UNCOV
2219

×
2220
                        // Delete payment address index reference if there's a
2221
                        // valid payment address.
UNCOV
2222
                        if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
×
UNCOV
2223
                                // To ensure consistency check that the already
×
2224
                                // fetched invoice key matches the one in the
×
2225
                                // payment address index.
×
2226
                                key := payAddrIndex.Get(
2227
                                        invoice.Terms.PaymentAddr[:],
2228
                                )
UNCOV
2229
                                if bytes.Equal(key, k) {
×
UNCOV
2230
                                        // Delete from the payment address
×
UNCOV
2231
                                        // index.
×
UNCOV
2232
                                        if err := payAddrIndex.Delete(
×
UNCOV
2233
                                                invoice.Terms.PaymentAddr[:],
×
UNCOV
2234
                                        ); err != nil {
×
UNCOV
2235
                                                return err
×
UNCOV
2236
                                        }
×
2237
                                }
×
2238
                        }
×
2239

×
2240
                        // Remove from the add index.
×
2241
                        var addIndexKey [8]byte
×
2242
                        byteOrder.PutUint64(addIndexKey[:], invoice.AddIndex)
×
2243
                        err = invoiceAddIndex.Delete(addIndexKey[:])
×
2244
                        if err != nil {
2245
                                return err
2246
                        }
2247

UNCOV
2248
                        // Note that we don't need to delete the invoice from
×
UNCOV
2249
                        // the settle index as it is not added until the
×
UNCOV
2250
                        // invoice is settled.
×
UNCOV
2251

×
2252
                        // Now remove all sub invoices.
×
2253
                        err = delAMPInvoices(k, invoices)
×
2254
                        if err != nil {
2255
                                return err
2256
                        }
2257

2258
                        // Finally remove the serialized invoice from the
2259
                        // invoice bucket.
UNCOV
2260
                        return invoices.Delete(k)
×
UNCOV
2261
                })
×
2262
        }, func() {})
×
2263
}
×
2264

2265
// DeleteInvoice attempts to delete the passed invoices from the database in
2266
// one transaction. The passed delete references hold all keys required to
UNCOV
2267
// delete the invoices without also needing to deserialize them.
×
2268
func (d *DB) DeleteInvoice(_ context.Context,
UNCOV
2269
        invoicesToDelete []invpkg.InvoiceDeleteRef) error {
×
2270

2271
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
2272
                invoices := tx.ReadWriteBucket(invoiceBucket)
2273
                if invoices == nil {
2274
                        return invpkg.ErrNoInvoicesCreated
2275
                }
UNCOV
2276

×
UNCOV
2277
                invoiceIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2278
                        invoiceIndexBucket,
×
UNCOV
2279
                )
×
UNCOV
2280
                if invoiceIndex == nil {
×
2281
                        return invpkg.ErrNoInvoicesCreated
×
2282
                }
×
2283

UNCOV
2284
                invoiceAddIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2285
                        addIndexBucket,
×
UNCOV
2286
                )
×
UNCOV
2287
                if invoiceAddIndex == nil {
×
2288
                        return invpkg.ErrNoInvoicesCreated
×
2289
                }
×
2290

UNCOV
2291
                // settleIndex can be nil, as the bucket is created lazily
×
UNCOV
2292
                // when the first invoice is settled.
×
UNCOV
2293
                settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket)
×
UNCOV
2294

×
2295
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
×
2296

×
2297
                for _, ref := range invoicesToDelete {
2298
                        // Fetch the invoice key for using it to check for
2299
                        // consistency and also to delete from the invoice
UNCOV
2300
                        // index.
×
UNCOV
2301
                        invoiceKey := invoiceIndex.Get(ref.PayHash[:])
×
UNCOV
2302
                        if invoiceKey == nil {
×
UNCOV
2303
                                return invpkg.ErrInvoiceNotFound
×
UNCOV
2304
                        }
×
UNCOV
2305

×
UNCOV
2306
                        err := invoiceIndex.Delete(ref.PayHash[:])
×
UNCOV
2307
                        if err != nil {
×
UNCOV
2308
                                return err
×
UNCOV
2309
                        }
×
UNCOV
2310

×
UNCOV
2311
                        // Delete payment address index reference if there's a
×
2312
                        // valid payment address passed.
UNCOV
2313
                        if ref.PayAddr != nil {
×
UNCOV
2314
                                // To ensure consistency check that the already
×
2315
                                // fetched invoice key matches the one in the
×
2316
                                // payment address index.
×
2317
                                key := payAddrIndex.Get(ref.PayAddr[:])
2318
                                if bytes.Equal(key, invoiceKey) {
2319
                                        // Delete from the payment address
UNCOV
2320
                                        // index. Note that since the payment
×
UNCOV
2321
                                        // address index has been introduced
×
UNCOV
2322
                                        // with an empty migration it may be
×
UNCOV
2323
                                        // possible that the index doesn't have
×
UNCOV
2324
                                        // an entry for this invoice.
×
UNCOV
2325
                                        // ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127
×
UNCOV
2326
                                        if err := payAddrIndex.Delete(
×
UNCOV
2327
                                                ref.PayAddr[:],
×
UNCOV
2328
                                        ); err != nil {
×
UNCOV
2329
                                                return err
×
UNCOV
2330
                                        }
×
UNCOV
2331
                                }
×
UNCOV
2332
                        }
×
UNCOV
2333

×
UNCOV
2334
                        var addIndexKey [8]byte
×
UNCOV
2335
                        byteOrder.PutUint64(addIndexKey[:], ref.AddIndex)
×
2336

×
2337
                        // To ensure consistency check that the key stored in
×
2338
                        // the add index also matches the previously fetched
2339
                        // invoice key.
2340
                        key := invoiceAddIndex.Get(addIndexKey[:])
UNCOV
2341
                        if !bytes.Equal(key, invoiceKey) {
×
UNCOV
2342
                                return fmt.Errorf("unknown invoice " +
×
UNCOV
2343
                                        "in add index")
×
UNCOV
2344
                        }
×
UNCOV
2345

×
UNCOV
2346
                        // Remove from the add index.
×
UNCOV
2347
                        err = invoiceAddIndex.Delete(addIndexKey[:])
×
UNCOV
2348
                        if err != nil {
×
UNCOV
2349
                                return err
×
UNCOV
2350
                        }
×
UNCOV
2351

×
2352
                        // Remove from the settle index if available and
2353
                        // if the invoice is settled.
UNCOV
2354
                        if settleIndex != nil && ref.SettleIndex > 0 {
×
UNCOV
2355
                                var settleIndexKey [8]byte
×
2356
                                byteOrder.PutUint64(
×
2357
                                        settleIndexKey[:], ref.SettleIndex,
×
2358
                                )
2359

2360
                                // To ensure consistency check that the already
UNCOV
2361
                                // fetched invoice key matches the one in the
×
UNCOV
2362
                                // settle index
×
UNCOV
2363
                                key := settleIndex.Get(settleIndexKey[:])
×
UNCOV
2364
                                if !bytes.Equal(key, invoiceKey) {
×
UNCOV
2365
                                        return fmt.Errorf("unknown invoice " +
×
UNCOV
2366
                                                "in settle index")
×
UNCOV
2367
                                }
×
UNCOV
2368

×
UNCOV
2369
                                err = settleIndex.Delete(settleIndexKey[:])
×
UNCOV
2370
                                if err != nil {
×
UNCOV
2371
                                        return err
×
UNCOV
2372
                                }
×
UNCOV
2373
                        }
×
UNCOV
2374

×
2375
                        // In addition to deleting the main invoice state, if
UNCOV
2376
                        // this is an AMP invoice, then we'll also need to
×
UNCOV
2377
                        // delete the set HTLC set stored as a key prefix. For
×
2378
                        // non-AMP invoices, this'll be a noop.
×
2379
                        err = delAMPSettleIndex(
×
2380
                                invoiceKey, invoices, settleIndex,
2381
                        )
2382
                        if err != nil {
2383
                                return err
2384
                        }
2385
                        err = delAMPInvoices(invoiceKey, invoices)
UNCOV
2386
                        if err != nil {
×
UNCOV
2387
                                return err
×
UNCOV
2388
                        }
×
UNCOV
2389

×
2390
                        // Finally remove the serialized invoice from the
×
2391
                        // invoice bucket.
×
UNCOV
2392
                        err = invoices.Delete(invoiceKey)
×
UNCOV
2393
                        if err != nil {
×
2394
                                return err
×
2395
                        }
×
2396
                }
2397

2398
                return nil
UNCOV
2399
        }, func() {})
×
UNCOV
2400

×
2401
        return err
×
2402
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc