• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 12312390362

13 Dec 2024 08:44AM UTC coverage: 57.458% (+8.5%) from 48.92%
12312390362

Pull #9343

github

ellemouton
fn: rework the ContextGuard and add tests

In this commit, the ContextGuard struct is re-worked such that the
context that its new main WithCtx method provides is cancelled in sync
with a parent context being cancelled or with its quit channel being
cancelled. Tests are added to assert the behaviour. In order for the
close of the quit channel to be consistent with the cancelling of the
derived context, the quit channel _must_ be contained internal to the
ContextGuard so that callers are only able to close the channel via the
exposed Quit method which will then take care to first cancel any
derived contexts that depend on the quit channel before returning.
Pull Request #9343: fn: expand the ContextGuard and add tests

101853 of 177264 relevant lines covered (57.46%)

24972.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.32
/channeldb/invoices.go
1
package channeldb
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/binary"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "time"
11

12
        "github.com/lightningnetwork/lnd/graph/db/models"
13
        "github.com/lightningnetwork/lnd/htlcswitch/hop"
14
        invpkg "github.com/lightningnetwork/lnd/invoices"
15
        "github.com/lightningnetwork/lnd/kvdb"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/tlv"
20
)
21

22
var (
23
        // invoiceBucket is the name of the bucket within the database that
24
        // stores all data related to invoices no matter their final state.
25
        // Within the invoice bucket, each invoice is keyed by its invoice ID
26
        // which is a monotonically increasing uint32.
27
        invoiceBucket = []byte("invoices")
28

29
        // invoiceIndexBucket is the name of the sub-bucket within the
30
        // invoiceBucket which indexes all invoices by their payment hash. The
31
        // payment hash is the sha256 of the invoice's payment preimage. This
32
        // index is used to detect duplicates, and also to provide a fast path
33
        // for looking up incoming HTLCs to determine if we're able to settle
34
        // them fully.
35
        //
36
        // maps: payHash => invoiceKey
37
        invoiceIndexBucket = []byte("paymenthashes")
38

39
        // payAddrIndexBucket is the name of the top-level bucket that maps
40
        // payment addresses to their invoice number. This can be used
41
        // to efficiently query or update non-legacy invoices. Note that legacy
42
        // invoices will not be included in this index since they all have the
43
        // same, all-zero payment address, however all newly generated invoices
44
        // will end up in this index.
45
        //
46
        // maps: payAddr => invoiceKey
47
        payAddrIndexBucket = []byte("pay-addr-index")
48

49
        // setIDIndexBucket is the name of the top-level bucket that maps set
50
        // ids to their invoice number. This can be used to efficiently query or
51
        // update AMP invoice. Note that legacy or MPP invoices will not be
52
        // included in this index, since their HTLCs do not have a set id.
53
        //
54
        // maps: setID => invoiceKey
55
        setIDIndexBucket = []byte("set-id-index")
56

57
        // numInvoicesKey is the name of key which houses the auto-incrementing
58
        // invoice ID which is essentially used as a primary key. With each
59
        // invoice inserted, the primary key is incremented by one. This key is
60
        // stored within the invoiceIndexBucket. Within the invoiceBucket
61
        // invoices are uniquely identified by the invoice ID.
62
        numInvoicesKey = []byte("nik")
63

64
        // addIndexBucket is an index bucket that we'll use to create a
65
        // monotonically increasing set of add indexes. Each time we add a new
66
        // invoice, this sequence number will be incremented and then populated
67
        // within the new invoice.
68
        //
69
        // In addition to this sequence number, we map:
70
        //
71
        //   addIndexNo => invoiceKey
72
        addIndexBucket = []byte("invoice-add-index")
73

74
        // settleIndexBucket is an index bucket that we'll use to create a
75
        // monotonically increasing integer for tracking a "settle index". Each
76
        // time an invoice is settled, this sequence number will be incremented
77
        // and populated within the newly settled invoice.
78
        //
79
        // In addition to this sequence number, we map:
80
        //
81
        //   settleIndexNo => invoiceKey
82
        settleIndexBucket = []byte("invoice-settle-index")
83
)
84

85
const (
86
        // A set of tlv type definitions used to serialize invoice htlcs to the
87
        // database.
88
        //
89
        // NOTE: A migration should be added whenever this list changes. This
90
        // prevents against the database being rolled back to an older
91
        // format where the surrounding logic might assume a different set of
92
        // fields are known.
93
        chanIDType       tlv.Type = 1
94
        htlcIDType       tlv.Type = 3
95
        amtType          tlv.Type = 5
96
        acceptHeightType tlv.Type = 7
97
        acceptTimeType   tlv.Type = 9
98
        resolveTimeType  tlv.Type = 11
99
        expiryHeightType tlv.Type = 13
100
        htlcStateType    tlv.Type = 15
101
        mppTotalAmtType  tlv.Type = 17
102
        htlcAMPType      tlv.Type = 19
103
        htlcHashType     tlv.Type = 21
104
        htlcPreimageType tlv.Type = 23
105

106
        // A set of tlv type definitions used to serialize invoice bodies.
107
        //
108
        // NOTE: A migration should be added whenever this list changes. This
109
        // prevents against the database being rolled back to an older
110
        // format where the surrounding logic might assume a different set of
111
        // fields are known.
112
        memoType            tlv.Type = 0
113
        payReqType          tlv.Type = 1
114
        createTimeType      tlv.Type = 2
115
        settleTimeType      tlv.Type = 3
116
        addIndexType        tlv.Type = 4
117
        settleIndexType     tlv.Type = 5
118
        preimageType        tlv.Type = 6
119
        valueType           tlv.Type = 7
120
        cltvDeltaType       tlv.Type = 8
121
        expiryType          tlv.Type = 9
122
        paymentAddrType     tlv.Type = 10
123
        featuresType        tlv.Type = 11
124
        invStateType        tlv.Type = 12
125
        amtPaidType         tlv.Type = 13
126
        hodlInvoiceType     tlv.Type = 14
127
        invoiceAmpStateType tlv.Type = 15
128

129
        // A set of tlv type definitions used to serialize the invoice AMP
130
        // state along-side the main invoice body.
131
        ampStateSetIDType       tlv.Type = 0
132
        ampStateHtlcStateType   tlv.Type = 1
133
        ampStateSettleIndexType tlv.Type = 2
134
        ampStateSettleDateType  tlv.Type = 3
135
        ampStateCircuitKeysType tlv.Type = 4
136
        ampStateAmtPaidType     tlv.Type = 5
137
)
138

139
// AddInvoice inserts the targeted invoice into the database. If the invoice has
140
// *any* payment hashes which already exists within the database, then the
141
// insertion will be aborted and rejected due to the strict policy banning any
142
// duplicate payment hashes. A side effect of this function is that it sets
143
// AddIndex on newInvoice.
144
func (d *DB) AddInvoice(_ context.Context, newInvoice *invpkg.Invoice,
145
        paymentHash lntypes.Hash) (uint64, error) {
622✔
146

622✔
147
        if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
624✔
148
                return 0, err
2✔
149
        }
2✔
150

151
        var invoiceAddIndex uint64
620✔
152
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
1,240✔
153
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
620✔
154
                if err != nil {
620✔
155
                        return err
×
156
                }
×
157

158
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
620✔
159
                        invoiceIndexBucket,
620✔
160
                )
620✔
161
                if err != nil {
620✔
162
                        return err
×
163
                }
×
164
                addIndex, err := invoices.CreateBucketIfNotExists(
620✔
165
                        addIndexBucket,
620✔
166
                )
620✔
167
                if err != nil {
620✔
168
                        return err
×
169
                }
×
170

171
                // Ensure that an invoice an identical payment hash doesn't
172
                // already exist within the index.
173
                if invoiceIndex.Get(paymentHash[:]) != nil {
623✔
174
                        return invpkg.ErrDuplicateInvoice
3✔
175
                }
3✔
176

177
                // Check that we aren't inserting an invoice with a duplicate
178
                // payment address. The all-zeros payment address is
179
                // special-cased to support legacy keysend invoices which don't
180
                // assign one. This is safe since later we also will avoid
181
                // indexing them and avoid collisions.
182
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
617✔
183
                if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
1,113✔
184
                        paymentAddr := newInvoice.Terms.PaymentAddr[:]
496✔
185
                        if payAddrIndex.Get(paymentAddr) != nil {
501✔
186
                                return invpkg.ErrDuplicatePayAddr
5✔
187
                        }
5✔
188
                }
189

190
                // If the current running payment ID counter hasn't yet been
191
                // created, then create it now.
192
                var invoiceNum uint32
612✔
193
                invoiceCounter := invoiceIndex.Get(numInvoicesKey)
612✔
194
                if invoiceCounter == nil {
808✔
195
                        var scratch [4]byte
196✔
196
                        byteOrder.PutUint32(scratch[:], invoiceNum)
196✔
197
                        err := invoiceIndex.Put(numInvoicesKey, scratch[:])
196✔
198
                        if err != nil {
196✔
199
                                return err
×
200
                        }
×
201
                } else {
416✔
202
                        invoiceNum = byteOrder.Uint32(invoiceCounter)
416✔
203
                }
416✔
204

205
                newIndex, err := putInvoice(
612✔
206
                        invoices, invoiceIndex, payAddrIndex, addIndex,
612✔
207
                        newInvoice, invoiceNum, paymentHash,
612✔
208
                )
612✔
209
                if err != nil {
612✔
210
                        return err
×
211
                }
×
212

213
                invoiceAddIndex = newIndex
612✔
214
                return nil
612✔
215
        }, func() {
620✔
216
                invoiceAddIndex = 0
620✔
217
        })
620✔
218
        if err != nil {
628✔
219
                return 0, err
8✔
220
        }
8✔
221

222
        return invoiceAddIndex, err
612✔
223
}
224

225
// InvoicesAddedSince can be used by callers to seek into the event time series
226
// of all the invoices added in the database. The specified sinceAddIndex
227
// should be the highest add index that the caller knows of. This method will
228
// return all invoices with an add index greater than the specified
229
// sinceAddIndex.
230
//
231
// NOTE: The index starts from 1, as a result. We enforce that specifying a
232
// value below the starting index value is a noop.
233
func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
234
        []invpkg.Invoice, error) {
20✔
235

20✔
236
        var newInvoices []invpkg.Invoice
20✔
237

20✔
238
        // If an index of zero was specified, then in order to maintain
20✔
239
        // backwards compat, we won't send out any new invoices.
20✔
240
        if sinceAddIndex == 0 {
37✔
241
                return newInvoices, nil
17✔
242
        }
17✔
243

244
        var startIndex [8]byte
3✔
245
        byteOrder.PutUint64(startIndex[:], sinceAddIndex)
3✔
246

3✔
247
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
248
                invoices := tx.ReadBucket(invoiceBucket)
3✔
249
                if invoices == nil {
3✔
250
                        return nil
×
251
                }
×
252

253
                addIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
254
                if addIndex == nil {
3✔
255
                        return nil
×
256
                }
×
257

258
                // We'll now run through each entry in the add index starting
259
                // at our starting index. We'll continue until we reach the
260
                // very end of the current key space.
261
                invoiceCursor := addIndex.ReadCursor()
3✔
262

3✔
263
                // We'll seek to the starting index, then manually advance the
3✔
264
                // cursor in order to skip the entry with the since add index.
3✔
265
                invoiceCursor.Seek(startIndex[:])
3✔
266
                addSeqNo, invoiceKey := invoiceCursor.Next()
3✔
267

3✔
268
                for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
32✔
269

29✔
270
                        // For each key found, we'll look up the actual
29✔
271
                        // invoice, then accumulate it into our return value.
29✔
272
                        invoice, err := fetchInvoice(
29✔
273
                                invoiceKey, invoices, nil, false,
29✔
274
                        )
29✔
275
                        if err != nil {
29✔
276
                                return err
×
277
                        }
×
278

279
                        newInvoices = append(newInvoices, invoice)
29✔
280
                }
281

282
                return nil
3✔
283
        }, func() {
3✔
284
                newInvoices = nil
3✔
285
        })
3✔
286
        if err != nil {
3✔
287
                return nil, err
×
288
        }
×
289

290
        return newInvoices, nil
3✔
291
}
292

293
// LookupInvoice attempts to look up an invoice according to its 32 byte
294
// payment hash. If an invoice which can settle the HTLC identified by the
295
// passed payment hash isn't found, then an error is returned. Otherwise, the
296
// full invoice is returned. Before setting the incoming HTLC, the values
297
// SHOULD be checked to ensure the payer meets the agreed upon contractual
298
// terms of the payment.
299
func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
300
        invpkg.Invoice, error) {
625✔
301

625✔
302
        var invoice invpkg.Invoice
625✔
303
        err := kvdb.View(d, func(tx kvdb.RTx) error {
1,250✔
304
                invoices := tx.ReadBucket(invoiceBucket)
625✔
305
                if invoices == nil {
625✔
306
                        return invpkg.ErrNoInvoicesCreated
×
307
                }
×
308
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
625✔
309
                if invoiceIndex == nil {
638✔
310
                        return invpkg.ErrNoInvoicesCreated
13✔
311
                }
13✔
312
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
612✔
313
                setIDIndex := tx.ReadBucket(setIDIndexBucket)
612✔
314

612✔
315
                // Retrieve the invoice number for this invoice using
612✔
316
                // the provided invoice reference.
612✔
317
                invoiceNum, err := fetchInvoiceNumByRef(
612✔
318
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
612✔
319
                )
612✔
320
                if err != nil {
620✔
321
                        return err
8✔
322
                }
8✔
323

324
                var setID *invpkg.SetID
604✔
325
                switch {
604✔
326
                // If this is a payment address ref, and the blank modified was
327
                // specified, then we'll use the zero set ID to indicate that
328
                // we won't want any HTLCs returned.
329
                case ref.PayAddr() != nil &&
330
                        ref.Modifier() == invpkg.HtlcSetBlankModifier:
1✔
331

1✔
332
                        var zeroSetID invpkg.SetID
1✔
333
                        setID = &zeroSetID
1✔
334

335
                // If this is a set ID ref, and the htlc set only modified was
336
                // specified, then we'll pass through the specified setID so
337
                // only that will be returned.
338
                case ref.SetID() != nil &&
339
                        ref.Modifier() == invpkg.HtlcSetOnlyModifier:
4✔
340

4✔
341
                        setID = (*invpkg.SetID)(ref.SetID())
4✔
342
                }
343

344
                // An invoice was found, retrieve the remainder of the invoice
345
                // body.
346
                i, err := fetchInvoice(
604✔
347
                        invoiceNum, invoices, []*invpkg.SetID{setID}, true,
604✔
348
                )
604✔
349
                if err != nil {
604✔
350
                        return err
×
351
                }
×
352
                invoice = i
604✔
353

604✔
354
                return nil
604✔
355
        }, func() {})
625✔
356
        if err != nil {
646✔
357
                return invoice, err
21✔
358
        }
21✔
359

360
        return invoice, nil
604✔
361
}
362

363
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
364
// reference. The payment address will be treated as the primary key, falling
365
// back to the payment hash if nothing is found for the payment address. An
366
// error is returned if the invoice is not found.
367
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
368
        ref invpkg.InvoiceRef) ([]byte, error) {
1,236✔
369

1,236✔
370
        // If the set id is present, we only consult the set id index for this
1,236✔
371
        // invoice. This type of query is only used to facilitate user-facing
1,236✔
372
        // requests to lookup, settle or cancel an AMP invoice.
1,236✔
373
        setID := ref.SetID()
1,236✔
374
        if setID != nil {
1,243✔
375
                invoiceNumBySetID := setIDIndex.Get(setID[:])
7✔
376
                if invoiceNumBySetID == nil {
8✔
377
                        return nil, invpkg.ErrInvoiceNotFound
1✔
378
                }
1✔
379

380
                return invoiceNumBySetID, nil
6✔
381
        }
382

383
        payHash := ref.PayHash()
1,229✔
384
        payAddr := ref.PayAddr()
1,229✔
385

1,229✔
386
        getInvoiceNumByHash := func() []byte {
2,458✔
387
                if payHash != nil {
2,434✔
388
                        return invoiceIndex.Get(payHash[:])
1,205✔
389
                }
1,205✔
390
                return nil
24✔
391
        }
392

393
        getInvoiceNumByAddr := func() []byte {
2,458✔
394
                if payAddr != nil {
1,710✔
395
                        // Only allow lookups for payment address if it is not a
481✔
396
                        // blank payment address, which is a special-cased value
481✔
397
                        // for legacy keysend invoices.
481✔
398
                        if *payAddr != invpkg.BlankPayAddr {
536✔
399
                                return payAddrIndex.Get(payAddr[:])
55✔
400
                        }
55✔
401
                }
402
                return nil
1,174✔
403
        }
404

405
        invoiceNumByHash := getInvoiceNumByHash()
1,229✔
406
        invoiceNumByAddr := getInvoiceNumByAddr()
1,229✔
407
        switch {
1,229✔
408
        // If payment address and payment hash both reference an existing
409
        // invoice, ensure they reference the _same_ invoice.
410
        case invoiceNumByAddr != nil && invoiceNumByHash != nil:
31✔
411
                if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
33✔
412
                        return nil, invpkg.ErrInvRefEquivocation
2✔
413
                }
2✔
414

415
                return invoiceNumByAddr, nil
29✔
416

417
        // Return invoices by payment addr only.
418
        //
419
        // NOTE: We constrain this lookup to only apply if the invoice ref does
420
        // not contain a payment hash. Legacy and MPP payments depend on the
421
        // payment hash index to enforce that the HTLCs payment hash matches the
422
        // payment hash for the invoice, without this check we would
423
        // inadvertently assume the invoice contains the correct preimage for
424
        // the HTLC, which we only enforce via the lookup by the invoice index.
425
        case invoiceNumByAddr != nil && payHash == nil:
23✔
426
                return invoiceNumByAddr, nil
23✔
427

428
        // If we were only able to reference the invoice by hash, return the
429
        // corresponding invoice number. This can happen when no payment address
430
        // was provided, or if it didn't match anything in our records.
431
        case invoiceNumByHash != nil:
1,166✔
432
                return invoiceNumByHash, nil
1,166✔
433

434
        // Otherwise we don't know of the target invoice.
435
        default:
9✔
436
                return nil, invpkg.ErrInvoiceNotFound
9✔
437
        }
438
}
439

440
// FetchPendingInvoices returns all invoices that have not yet been settled or
441
// canceled. The returned map is keyed by the payment hash of each respective
442
// invoice.
443
func (d *DB) FetchPendingInvoices(_ context.Context) (
444
        map[lntypes.Hash]invpkg.Invoice, error) {
386✔
445

386✔
446
        result := make(map[lntypes.Hash]invpkg.Invoice)
386✔
447

386✔
448
        err := kvdb.View(d, func(tx kvdb.RTx) error {
772✔
449
                invoices := tx.ReadBucket(invoiceBucket)
386✔
450
                if invoices == nil {
386✔
451
                        return nil
×
452
                }
×
453

454
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
386✔
455
                if invoiceIndex == nil {
769✔
456
                        // Mask the error if there's no invoice
383✔
457
                        // index as that simply means there are no
383✔
458
                        // invoices added yet to the DB. In this case
383✔
459
                        // we simply return an empty list.
383✔
460
                        return nil
383✔
461
                }
383✔
462

463
                return invoiceIndex.ForEach(func(k, v []byte) error {
41✔
464
                        // Skip the special numInvoicesKey as that does not
38✔
465
                        // point to a valid invoice.
38✔
466
                        if bytes.Equal(k, numInvoicesKey) {
41✔
467
                                return nil
3✔
468
                        }
3✔
469

470
                        // Skip sub-buckets.
471
                        if v == nil {
35✔
472
                                return nil
×
473
                        }
×
474

475
                        invoice, err := fetchInvoice(v, invoices, nil, false)
35✔
476
                        if err != nil {
35✔
477
                                return err
×
478
                        }
×
479

480
                        if invoice.IsPending() {
55✔
481
                                var paymentHash lntypes.Hash
20✔
482
                                copy(paymentHash[:], k)
20✔
483
                                result[paymentHash] = invoice
20✔
484
                        }
20✔
485

486
                        return nil
35✔
487
                })
488
        }, func() {
386✔
489
                result = make(map[lntypes.Hash]invpkg.Invoice)
386✔
490
        })
386✔
491

492
        if err != nil {
386✔
493
                return nil, err
×
494
        }
×
495

496
        return result, nil
386✔
497
}
498

499
// QueryInvoices allows a caller to query the invoice database for invoices
500
// within the specified add index range.
501
func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
502
        invpkg.InvoiceSlice, error) {
43✔
503

43✔
504
        var resp invpkg.InvoiceSlice
43✔
505

43✔
506
        err := kvdb.View(d, func(tx kvdb.RTx) error {
86✔
507
                // If the bucket wasn't found, then there aren't any invoices
43✔
508
                // within the database yet, so we can simply exit.
43✔
509
                invoices := tx.ReadBucket(invoiceBucket)
43✔
510
                if invoices == nil {
43✔
511
                        return invpkg.ErrNoInvoicesCreated
×
512
                }
×
513

514
                // Get the add index bucket which we will use to iterate through
515
                // our indexed invoices.
516
                invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
43✔
517
                if invoiceAddIndex == nil {
43✔
518
                        return invpkg.ErrNoInvoicesCreated
×
519
                }
×
520

521
                // Create a paginator which reads from our add index bucket with
522
                // the parameters provided by the invoice query.
523
                paginator := newPaginator(
43✔
524
                        invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
43✔
525
                        q.NumMaxInvoices,
43✔
526
                )
43✔
527

43✔
528
                // accumulateInvoices looks up an invoice based on the index we
43✔
529
                // are given, adds it to our set of invoices if it has the right
43✔
530
                // characteristics for our query and returns the number of items
43✔
531
                // we have added to our set of invoices.
43✔
532
                accumulateInvoices := func(_, indexValue []byte) (bool, error) {
982✔
533
                        invoice, err := fetchInvoice(
939✔
534
                                indexValue, invoices, nil, false,
939✔
535
                        )
939✔
536
                        if err != nil {
939✔
537
                                return false, err
×
538
                        }
×
539

540
                        // Skip any settled or canceled invoices if the caller
541
                        // is only interested in pending ones.
542
                        if q.PendingOnly && !invoice.IsPending() {
1,094✔
543
                                return false, nil
155✔
544
                        }
155✔
545

546
                        // Get the creation time in Unix seconds, this always
547
                        // rounds down the nanoseconds to full seconds.
548
                        createTime := invoice.CreationDate.Unix()
784✔
549

784✔
550
                        // Skip any invoices that were created before the
784✔
551
                        // specified time.
784✔
552
                        if createTime < q.CreationDateStart {
881✔
553
                                return false, nil
97✔
554
                        }
97✔
555

556
                        // Skip any invoices that were created after the
557
                        // specified time.
558
                        if q.CreationDateEnd != 0 &&
687✔
559
                                createTime > q.CreationDateEnd {
865✔
560

178✔
561
                                return false, nil
178✔
562
                        }
178✔
563

564
                        // At this point, we've exhausted the offset, so we'll
565
                        // begin collecting invoices found within the range.
566
                        resp.Invoices = append(resp.Invoices, invoice)
509✔
567

509✔
568
                        return true, nil
509✔
569
                }
570

571
                // Query our paginator using accumulateInvoices to build up a
572
                // set of invoices.
573
                if err := paginator.query(accumulateInvoices); err != nil {
43✔
574
                        return err
×
575
                }
×
576

577
                // If we iterated through the add index in reverse order, then
578
                // we'll need to reverse the slice of invoices to return them in
579
                // forward order.
580
                if q.Reversed {
57✔
581
                        numInvoices := len(resp.Invoices)
14✔
582
                        for i := 0; i < numInvoices/2; i++ {
83✔
583
                                reverse := numInvoices - i - 1
69✔
584
                                resp.Invoices[i], resp.Invoices[reverse] =
69✔
585
                                        resp.Invoices[reverse], resp.Invoices[i]
69✔
586
                        }
69✔
587
                }
588

589
                return nil
43✔
590
        }, func() {
43✔
591
                resp = invpkg.InvoiceSlice{
43✔
592
                        InvoiceQuery: q,
43✔
593
                }
43✔
594
        })
43✔
595
        if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) {
43✔
596
                return resp, err
×
597
        }
×
598

599
        // Finally, record the indexes of the first and last invoices returned
600
        // so that the caller can resume from this point later on.
601
        if len(resp.Invoices) > 0 {
81✔
602
                resp.FirstIndexOffset = resp.Invoices[0].AddIndex
38✔
603
                lastIdx := len(resp.Invoices) - 1
38✔
604
                resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex
38✔
605
        }
38✔
606

607
        return resp, nil
43✔
608
}
609

610
// UpdateInvoice attempts to update an invoice corresponding to the passed
611
// payment hash. If an invoice matching the passed payment hash doesn't exist
612
// within the database, then the action will fail with a "not found" error.
613
//
614
// The update is performed inside the same database transaction that fetches the
615
// invoice and is therefore atomic. The fields to update are controlled by the
616
// supplied callback.  When updating an invoice, the update itself happens
617
// in-memory on a copy of the invoice. Once it is written successfully to the
618
// database, the in-memory copy is returned to the caller.
619
func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
620
        setIDHint *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
621
        *invpkg.Invoice, error) {
624✔
622

624✔
623
        var updatedInvoice *invpkg.Invoice
624✔
624
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
1,248✔
625
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
624✔
626
                if err != nil {
624✔
627
                        return err
×
628
                }
×
629
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
624✔
630
                        invoiceIndexBucket,
624✔
631
                )
624✔
632
                if err != nil {
624✔
633
                        return err
×
634
                }
×
635
                settleIndex, err := invoices.CreateBucketIfNotExists(
624✔
636
                        settleIndexBucket,
624✔
637
                )
624✔
638
                if err != nil {
624✔
639
                        return err
×
640
                }
×
641
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
624✔
642
                setIDIndex := tx.ReadWriteBucket(setIDIndexBucket)
624✔
643

624✔
644
                // Retrieve the invoice number for this invoice using the
624✔
645
                // provided invoice reference.
624✔
646
                invoiceNum, err := fetchInvoiceNumByRef(
624✔
647
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
624✔
648
                )
624✔
649
                if err != nil {
628✔
650
                        return err
4✔
651
                }
4✔
652

653
                // If the set ID hint is non-nil, then we'll use that to filter
654
                // out the HTLCs for AMP invoice so we don't need to read them
655
                // all out to satisfy the invoice callback below. If it's nil,
656
                // then we pass in the zero set ID which means no HTLCs will be
657
                // read out.
658
                var invSetID invpkg.SetID
620✔
659

620✔
660
                if setIDHint != nil {
651✔
661
                        invSetID = *setIDHint
31✔
662
                }
31✔
663
                invoice, err := fetchInvoice(
620✔
664
                        invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false,
620✔
665
                )
620✔
666
                if err != nil {
620✔
667
                        return err
×
668
                }
×
669

670
                now := d.clock.Now()
620✔
671
                updater := &kvInvoiceUpdater{
620✔
672
                        db:                d,
620✔
673
                        invoicesBucket:    invoices,
620✔
674
                        settleIndexBucket: settleIndex,
620✔
675
                        setIDIndexBucket:  setIDIndex,
620✔
676
                        updateTime:        now,
620✔
677
                        invoiceNum:        invoiceNum,
620✔
678
                        invoice:           &invoice,
620✔
679
                        updatedAmpHtlcs:   make(ampHTLCsMap),
620✔
680
                        settledSetIDs:     make(map[invpkg.SetID]struct{}),
620✔
681
                }
620✔
682

620✔
683
                payHash := ref.PayHash()
620✔
684
                updatedInvoice, err = invpkg.UpdateInvoice(
620✔
685
                        payHash, updater.invoice, now, callback, updater,
620✔
686
                )
620✔
687
                if err != nil {
631✔
688
                        return err
11✔
689
                }
11✔
690

691
                // If this is an AMP update, then limit the returned AMP state
692
                // to only the requested set ID.
693
                if setIDHint != nil {
637✔
694
                        filterInvoiceAMPState(updatedInvoice, &invSetID)
28✔
695
                }
28✔
696

697
                return nil
609✔
698
        }, func() {
624✔
699
                updatedInvoice = nil
624✔
700
        })
624✔
701

702
        return updatedInvoice, err
624✔
703
}
704

705
// filterInvoiceAMPState filters the AMP state of the invoice to only include
706
// state for the specified set IDs.
707
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
52✔
708
        filteredAMPState := make(invpkg.AMPInvoiceState)
52✔
709

52✔
710
        for _, setID := range setIDs {
104✔
711
                if setID == nil {
70✔
712
                        return
18✔
713
                }
18✔
714

715
                ampState, ok := invoice.AMPState[*setID]
34✔
716
                if ok {
67✔
717
                        filteredAMPState[*setID] = ampState
33✔
718
                }
33✔
719
        }
720

721
        invoice.AMPState = filteredAMPState
34✔
722
}
723

724
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
725
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
726

727
// kvInvoiceUpdater is an implementation of the InvoiceUpdater interface that
728
// is used with the kv implementation of the invoice database. Note that this
729
// updater is not concurrency safe and synchronizaton is expected to be handled
730
// on the DB level.
731
type kvInvoiceUpdater struct {
732
        db                *DB
733
        invoicesBucket    kvdb.RwBucket
734
        settleIndexBucket kvdb.RwBucket
735
        setIDIndexBucket  kvdb.RwBucket
736

737
        // updateTime is the timestamp for the update.
738
        updateTime time.Time
739

740
        // invoiceNum is a legacy key similar to the add index that is used
741
        // only in the kv implementation.
742
        invoiceNum []byte
743

744
        // invoice is the invoice that we're updating. As a side effect of the
745
        // update this invoice will be mutated.
746
        invoice *invpkg.Invoice
747

748
        // updatedAmpHtlcs holds the set of AMP HTLCs that were added or
749
        // cancelled as part of this update.
750
        updatedAmpHtlcs ampHTLCsMap
751

752
        // settledSetIDs holds the set IDs that are settled with this update.
753
        settledSetIDs map[invpkg.SetID]struct{}
754
}
755

756
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
757
func (k *kvInvoiceUpdater) AddHtlc(_ models.CircuitKey,
758
        _ *invpkg.InvoiceHTLC) error {
491✔
759

491✔
760
        return nil
491✔
761
}
491✔
762

763
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
764
func (k *kvInvoiceUpdater) ResolveHtlc(_ models.CircuitKey, _ invpkg.HtlcState,
765
        _ time.Time) error {
485✔
766

485✔
767
        return nil
485✔
768
}
485✔
769

770
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
771
func (k *kvInvoiceUpdater) AddAmpHtlcPreimage(_ [32]byte, _ models.CircuitKey,
772
        _ lntypes.Preimage) error {
6✔
773

6✔
774
        return nil
6✔
775
}
6✔
776

777
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
778
func (k *kvInvoiceUpdater) UpdateInvoiceState(_ invpkg.ContractState,
779
        _ *lntypes.Preimage) error {
450✔
780

450✔
781
        return nil
450✔
782
}
450✔
783

784
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
785
func (k *kvInvoiceUpdater) UpdateInvoiceAmtPaid(_ lnwire.MilliSatoshi) error {
557✔
786
        return nil
557✔
787
}
557✔
788

789
// UpdateAmpState updates the state of the AMP invoice identified by the setID.
790
func (k *kvInvoiceUpdater) UpdateAmpState(setID [32]byte,
791
        state invpkg.InvoiceStateAMP, circuitKey models.CircuitKey) error {
33✔
792

33✔
793
        if _, ok := k.updatedAmpHtlcs[setID]; !ok {
61✔
794
                switch state.State {
28✔
795
                case invpkg.HtlcStateAccepted:
19✔
796
                        // If we're just now creating the HTLCs for this set
19✔
797
                        // then we'll also pull in the existing HTLCs that are
19✔
798
                        // part of this set, so we can write them all to disk
19✔
799
                        // together (same value)
19✔
800
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
19✔
801
                                &setID, invpkg.HtlcStateAccepted,
19✔
802
                        )
19✔
803

804
                case invpkg.HtlcStateCanceled:
3✔
805
                        // Only HTLCs in the accepted state, can be cancelled,
3✔
806
                        // but we also want to merge that with HTLCs that may be
3✔
807
                        // canceled as well since it can be cancelled one by
3✔
808
                        // one.
3✔
809
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
3✔
810
                                &setID, invpkg.HtlcStateAccepted,
3✔
811
                        )
3✔
812

3✔
813
                        cancelledHtlcs := k.invoice.HTLCSet(
3✔
814
                                &setID, invpkg.HtlcStateCanceled,
3✔
815
                        )
3✔
816
                        for htlcKey, htlc := range cancelledHtlcs {
7✔
817
                                k.updatedAmpHtlcs[setID][htlcKey] = htlc
4✔
818
                        }
4✔
819

820
                case invpkg.HtlcStateSettled:
6✔
821
                        k.updatedAmpHtlcs[setID] = make(
6✔
822
                                map[models.CircuitKey]*invpkg.InvoiceHTLC,
6✔
823
                        )
6✔
824
                }
825
        }
826

827
        if state.State == invpkg.HtlcStateSettled {
44✔
828
                // Add the set ID to the set that was settled in this invoice
11✔
829
                // update. We'll use this later to update the settle index.
11✔
830
                k.settledSetIDs[setID] = struct{}{}
11✔
831
        }
11✔
832

833
        k.updatedAmpHtlcs[setID][circuitKey] = k.invoice.Htlcs[circuitKey]
33✔
834

33✔
835
        return nil
33✔
836
}
837

838
// Finalize finalizes the update before it is written to the database.
839
func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error {
591✔
840
        switch updateType {
591✔
841
        case invpkg.AddHTLCsUpdate:
496✔
842
                return k.storeAddHtlcsUpdate()
496✔
843

844
        case invpkg.CancelHTLCsUpdate:
6✔
845
                return k.storeCancelHtlcsUpdate()
6✔
846

847
        case invpkg.SettleHodlInvoiceUpdate:
58✔
848
                return k.storeSettleHodlInvoiceUpdate()
58✔
849

850
        case invpkg.CancelInvoiceUpdate:
31✔
851
                return k.serializeAndStoreInvoice()
31✔
852
        }
853

854
        return fmt.Errorf("unknown update type: %v", updateType)
×
855
}
856

857
// storeCancelHtlcsUpdate updates the invoice in the database after cancelling a
858
// set of HTLCs.
859
func (k *kvInvoiceUpdater) storeCancelHtlcsUpdate() error {
6✔
860
        err := k.serializeAndStoreInvoice()
6✔
861
        if err != nil {
6✔
862
                return err
×
863
        }
×
864

865
        // If this is an AMP invoice, then we'll actually store the rest
866
        // of the HTLCs in-line with the invoice, using the invoice ID
867
        // as a prefix, and the AMP key as a suffix: invoiceNum ||
868
        // setID.
869
        if k.invoice.IsAMP() {
9✔
870
                return k.updateAMPInvoices()
3✔
871
        }
3✔
872

873
        return nil
3✔
874
}
875

876
// storeAddHtlcsUpdate updates the invoice in the database after adding a set of
877
// HTLCs.
878
func (k *kvInvoiceUpdater) storeAddHtlcsUpdate() error {
496✔
879
        invoiceIsAMP := k.invoice.IsAMP()
496✔
880

496✔
881
        for htlcSetID := range k.updatedAmpHtlcs {
521✔
882
                // Check if this SetID already exist.
25✔
883
                setIDInvNum := k.setIDIndexBucket.Get(htlcSetID[:])
25✔
884

25✔
885
                if setIDInvNum == nil {
38✔
886
                        err := k.setIDIndexBucket.Put(
13✔
887
                                htlcSetID[:], k.invoiceNum,
13✔
888
                        )
13✔
889
                        if err != nil {
13✔
890
                                return err
×
891
                        }
×
892
                } else if !bytes.Equal(setIDInvNum, k.invoiceNum) {
13✔
893
                        return invpkg.ErrDuplicateSetID{
1✔
894
                                SetID: htlcSetID,
1✔
895
                        }
1✔
896
                }
1✔
897
        }
898

899
        // If this is a non-AMP invoice, then the state can eventually go to
900
        // ContractSettled, so we pass in nil value as part of
901
        // setSettleMetaFields.
902
        if !invoiceIsAMP && k.invoice.State == invpkg.ContractSettled {
793✔
903
                err := k.setSettleMetaFields(nil)
298✔
904
                if err != nil {
298✔
905
                        return err
×
906
                }
×
907
        }
908

909
        // As we don't update the settle index above for AMP invoices, we'll do
910
        // it here for each sub-AMP invoice that was settled.
911
        for settledSetID := range k.settledSetIDs {
503✔
912
                settledSetID := settledSetID
8✔
913
                err := k.setSettleMetaFields(&settledSetID)
8✔
914
                if err != nil {
8✔
915
                        return err
×
916
                }
×
917
        }
918

919
        err := k.serializeAndStoreInvoice()
495✔
920
        if err != nil {
495✔
921
                return err
×
922
        }
×
923

924
        // If this is an AMP invoice, then we'll actually store the rest of the
925
        // HTLCs in-line with the invoice, using the invoice ID as a prefix,
926
        // and the AMP key as a suffix: invoiceNum || setID.
927
        if invoiceIsAMP {
519✔
928
                return k.updateAMPInvoices()
24✔
929
        }
24✔
930

931
        return nil
471✔
932
}
933

934
// storeSettleHodlInvoiceUpdate updates the invoice in the database after
935
// settling a hodl invoice.
936
func (k *kvInvoiceUpdater) storeSettleHodlInvoiceUpdate() error {
58✔
937
        err := k.setSettleMetaFields(nil)
58✔
938
        if err != nil {
58✔
939
                return err
×
940
        }
×
941

942
        return k.serializeAndStoreInvoice()
58✔
943
}
944

945
// setSettleMetaFields updates the metadata associated with settlement of an
946
// invoice. If a non-nil setID is passed in, then the value will be append to
947
// the invoice number as well, in order to allow us to detect repeated payments
948
// to the same AMP invoices "across time".
949
func (k *kvInvoiceUpdater) setSettleMetaFields(setID *invpkg.SetID) error {
364✔
950
        // Now that we know the invoice hasn't already been settled, we'll
364✔
951
        // update the settle index so we can place this settle event in the
364✔
952
        // proper location within our time series.
364✔
953
        nextSettleSeqNo, err := k.settleIndexBucket.NextSequence()
364✔
954
        if err != nil {
364✔
955
                return err
×
956
        }
×
957

958
        // Make a new byte array on the stack that can potentially store the 4
959
        // byte invoice number along w/ the 32 byte set ID. We capture valueLen
960
        // here which is the number of bytes copied so we can only store the 4
961
        // bytes if this is a non-AMP invoice.
962
        var indexKey [invoiceSetIDKeyLen]byte
364✔
963
        valueLen := copy(indexKey[:], k.invoiceNum)
364✔
964

364✔
965
        if setID != nil {
372✔
966
                valueLen += copy(indexKey[valueLen:], setID[:])
8✔
967
        }
8✔
968

969
        var seqNoBytes [8]byte
364✔
970
        byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
364✔
971
        err = k.settleIndexBucket.Put(seqNoBytes[:], indexKey[:valueLen])
364✔
972
        if err != nil {
364✔
973
                return err
×
974
        }
×
975

976
        // If the setID is nil, then this means that this is a non-AMP settle,
977
        // so we'll update the invoice settle index directly.
978
        if setID == nil {
720✔
979
                k.invoice.SettleDate = k.updateTime
356✔
980
                k.invoice.SettleIndex = nextSettleSeqNo
356✔
981
        } else {
364✔
982
                // If the set ID isn't blank, we'll update the AMP state map
8✔
983
                // which tracks when each of the setIDs associated with a given
8✔
984
                // AMP invoice are settled.
8✔
985
                ampState := k.invoice.AMPState[*setID]
8✔
986

8✔
987
                ampState.SettleDate = k.updateTime
8✔
988
                ampState.SettleIndex = nextSettleSeqNo
8✔
989

8✔
990
                k.invoice.AMPState[*setID] = ampState
8✔
991
        }
8✔
992

993
        return nil
364✔
994
}
995

996
// updateAMPInvoices updates the set of AMP invoices in-place. For AMP, rather
997
// then continually write the invoices to the end of the invoice value, we
998
// instead write the invoices into a new key preifx that follows the main
999
// invoice number. This ensures that we don't need to continually decode a
1000
// potentially massive HTLC set, and also allows us to quickly find the HLTCs
1001
// associated with a particular HTLC set.
1002
func (k *kvInvoiceUpdater) updateAMPInvoices() error {
27✔
1003
        for setID, htlcSet := range k.updatedAmpHtlcs {
54✔
1004
                // First write out the set of HTLCs including all the relevant
27✔
1005
                // TLV values.
27✔
1006
                var b bytes.Buffer
27✔
1007
                if err := serializeHtlcs(&b, htlcSet); err != nil {
27✔
1008
                        return err
×
1009
                }
×
1010

1011
                // Next store each HTLC in-line, using a prefix based off the
1012
                // invoice number.
1013
                invoiceSetIDKey := makeInvoiceSetIDKey(k.invoiceNum, setID[:])
27✔
1014

27✔
1015
                err := k.invoicesBucket.Put(invoiceSetIDKey[:], b.Bytes())
27✔
1016
                if err != nil {
27✔
1017
                        return err
×
1018
                }
×
1019
        }
1020

1021
        return nil
27✔
1022
}
1023

1024
// serializeAndStoreInvoice is a helper function used to store invoices.
1025
func (k *kvInvoiceUpdater) serializeAndStoreInvoice() error {
590✔
1026
        var buf bytes.Buffer
590✔
1027
        if err := serializeInvoice(&buf, k.invoice); err != nil {
590✔
1028
                return err
×
1029
        }
×
1030

1031
        return k.invoicesBucket.Put(k.invoiceNum, buf.Bytes())
590✔
1032
}
1033

1034
// InvoicesSettledSince can be used by callers to catch up any settled invoices
1035
// they missed within the settled invoice time series. We'll return all known
1036
// settled invoice that have a settle index higher than the passed
1037
// sinceSettleIndex.
1038
//
1039
// NOTE: The index starts from 1, as a result. We enforce that specifying a
1040
// value below the starting index value is a noop.
1041
func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
1042
        []invpkg.Invoice, error) {
21✔
1043

21✔
1044
        var settledInvoices []invpkg.Invoice
21✔
1045

21✔
1046
        // If an index of zero was specified, then in order to maintain
21✔
1047
        // backwards compat, we won't send out any new invoices.
21✔
1048
        if sinceSettleIndex == 0 {
39✔
1049
                return settledInvoices, nil
18✔
1050
        }
18✔
1051

1052
        var startIndex [8]byte
3✔
1053
        byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
3✔
1054

3✔
1055
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
1056
                invoices := tx.ReadBucket(invoiceBucket)
3✔
1057
                if invoices == nil {
3✔
1058
                        return nil
×
1059
                }
×
1060

1061
                settleIndex := invoices.NestedReadBucket(settleIndexBucket)
3✔
1062
                if settleIndex == nil {
3✔
1063
                        return nil
×
1064
                }
×
1065

1066
                // We'll now run through each entry in the add index starting
1067
                // at our starting index. We'll continue until we reach the
1068
                // very end of the current key space.
1069
                invoiceCursor := settleIndex.ReadCursor()
3✔
1070

3✔
1071
                // We'll seek to the starting index, then manually advance the
3✔
1072
                // cursor in order to skip the entry with the since add index.
3✔
1073
                invoiceCursor.Seek(startIndex[:])
3✔
1074
                seqNo, indexValue := invoiceCursor.Next()
3✔
1075

3✔
1076
                for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, indexValue = invoiceCursor.Next() {
14✔
1077
                        // Depending on the length of the index value, this may
11✔
1078
                        // or may not be an AMP invoice, so we'll extract the
11✔
1079
                        // invoice value into two components: the invoice num,
11✔
1080
                        // and the setID (may not be there).
11✔
1081
                        var (
11✔
1082
                                invoiceKey [4]byte
11✔
1083
                                setID      *invpkg.SetID
11✔
1084
                        )
11✔
1085

11✔
1086
                        valueLen := copy(invoiceKey[:], indexValue)
11✔
1087
                        if len(indexValue) == invoiceSetIDKeyLen {
13✔
1088
                                setID = new(invpkg.SetID)
2✔
1089
                                copy(setID[:], indexValue[valueLen:])
2✔
1090
                        }
2✔
1091

1092
                        // For each key found, we'll look up the actual
1093
                        // invoice, then accumulate it into our return value.
1094
                        invoice, err := fetchInvoice(
11✔
1095
                                invoiceKey[:], invoices, []*invpkg.SetID{setID},
11✔
1096
                                true,
11✔
1097
                        )
11✔
1098
                        if err != nil {
11✔
1099
                                return err
×
1100
                        }
×
1101

1102
                        settledInvoices = append(settledInvoices, invoice)
11✔
1103
                }
1104

1105
                return nil
3✔
1106
        }, func() {
3✔
1107
                settledInvoices = nil
3✔
1108
        })
3✔
1109
        if err != nil {
3✔
1110
                return nil, err
×
1111
        }
×
1112

1113
        return settledInvoices, nil
3✔
1114
}
1115

1116
func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
1117
        i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
1118
        uint64, error) {
612✔
1119

612✔
1120
        // Create the invoice key which is just the big-endian representation
612✔
1121
        // of the invoice number.
612✔
1122
        var invoiceKey [4]byte
612✔
1123
        byteOrder.PutUint32(invoiceKey[:], invoiceNum)
612✔
1124

612✔
1125
        // Increment the num invoice counter index so the next invoice bares
612✔
1126
        // the proper ID.
612✔
1127
        var scratch [4]byte
612✔
1128
        invoiceCounter := invoiceNum + 1
612✔
1129
        byteOrder.PutUint32(scratch[:], invoiceCounter)
612✔
1130
        if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
612✔
1131
                return 0, err
×
1132
        }
×
1133

1134
        // Add the payment hash to the invoice index. This will let us quickly
1135
        // identify if we can settle an incoming payment, and also to possibly
1136
        // allow a single invoice to have multiple payment installations.
1137
        err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
612✔
1138
        if err != nil {
612✔
1139
                return 0, err
×
1140
        }
×
1141

1142
        // Add the invoice to the payment address index, but only if the invoice
1143
        // has a non-zero payment address. The all-zero payment address is still
1144
        // in use by legacy keysend, so we special-case here to avoid
1145
        // collisions.
1146
        if i.Terms.PaymentAddr != invpkg.BlankPayAddr {
1,103✔
1147
                err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
491✔
1148
                if err != nil {
491✔
1149
                        return 0, err
×
1150
                }
×
1151
        }
1152

1153
        // Next, we'll obtain the next add invoice index (sequence
1154
        // number), so we can properly place this invoice within this
1155
        // event stream.
1156
        nextAddSeqNo, err := addIndex.NextSequence()
612✔
1157
        if err != nil {
612✔
1158
                return 0, err
×
1159
        }
×
1160

1161
        // With the next sequence obtained, we'll updating the event series in
1162
        // the add index bucket to map this current add counter to the index of
1163
        // this new invoice.
1164
        var seqNoBytes [8]byte
612✔
1165
        byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
612✔
1166
        if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
612✔
1167
                return 0, err
×
1168
        }
×
1169

1170
        i.AddIndex = nextAddSeqNo
612✔
1171

612✔
1172
        // Finally, serialize the invoice itself to be written to the disk.
612✔
1173
        var buf bytes.Buffer
612✔
1174
        if err := serializeInvoice(&buf, i); err != nil {
612✔
1175
                return 0, err
×
1176
        }
×
1177

1178
        if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
612✔
1179
                return 0, err
×
1180
        }
×
1181

1182
        return nextAddSeqNo, nil
612✔
1183
}
1184

1185
// recordSize returns the amount of bytes this TLV record will occupy when
1186
// encoded.
1187
func ampRecordSize(a *invpkg.AMPInvoiceState) func() uint64 {
1,203✔
1188
        var (
1,203✔
1189
                b   bytes.Buffer
1,203✔
1190
                buf [8]byte
1,203✔
1191
        )
1,203✔
1192

1,203✔
1193
        // We know that encoding works since the tests pass in the build this
1,203✔
1194
        // file is checked into, so we'll simplify things and simply encode it
1,203✔
1195
        // ourselves then report the total amount of bytes used.
1,203✔
1196
        if err := ampStateEncoder(&b, a, &buf); err != nil {
1,203✔
1197
                // This should never error out, but we log it just in case it
×
1198
                // does.
×
1199
                log.Errorf("encoding the amp invoice state failed: %v", err)
×
1200
        }
×
1201

1202
        return func() uint64 {
2,406✔
1203
                return uint64(len(b.Bytes()))
1,203✔
1204
        }
1,203✔
1205
}
1206

1207
// serializeInvoice serializes an invoice to a writer.
1208
//
1209
// Note: this function is in use for a migration. Before making changes that
1210
// would modify the on disk format, make a copy of the original code and store
1211
// it with the migration.
1212
func serializeInvoice(w io.Writer, i *invpkg.Invoice) error {
1,202✔
1213
        creationDateBytes, err := i.CreationDate.MarshalBinary()
1,202✔
1214
        if err != nil {
1,202✔
1215
                return err
×
1216
        }
×
1217

1218
        settleDateBytes, err := i.SettleDate.MarshalBinary()
1,202✔
1219
        if err != nil {
1,202✔
1220
                return err
×
1221
        }
×
1222

1223
        var fb bytes.Buffer
1,202✔
1224
        err = i.Terms.Features.EncodeBase256(&fb)
1,202✔
1225
        if err != nil {
1,202✔
1226
                return err
×
1227
        }
×
1228
        featureBytes := fb.Bytes()
1,202✔
1229

1,202✔
1230
        preimage := [32]byte(invpkg.UnknownPreimage)
1,202✔
1231
        if i.Terms.PaymentPreimage != nil {
2,270✔
1232
                preimage = *i.Terms.PaymentPreimage
1,068✔
1233
                if preimage == invpkg.UnknownPreimage {
1,068✔
1234
                        return errors.New("cannot use all-zeroes preimage")
×
1235
                }
×
1236
        }
1237
        value := uint64(i.Terms.Value)
1,202✔
1238
        cltvDelta := uint32(i.Terms.FinalCltvDelta)
1,202✔
1239
        expiry := uint64(i.Terms.Expiry)
1,202✔
1240

1,202✔
1241
        amtPaid := uint64(i.AmtPaid)
1,202✔
1242
        state := uint8(i.State)
1,202✔
1243

1,202✔
1244
        var hodlInvoice uint8
1,202✔
1245
        if i.HodlInvoice {
1,399✔
1246
                hodlInvoice = 1
197✔
1247
        }
197✔
1248

1249
        tlvStream, err := tlv.NewStream(
1,202✔
1250
                // Memo and payreq.
1,202✔
1251
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
1,202✔
1252
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
1,202✔
1253

1,202✔
1254
                // Add/settle metadata.
1,202✔
1255
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
1,202✔
1256
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
1,202✔
1257
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
1,202✔
1258
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
1,202✔
1259

1,202✔
1260
                // Terms.
1,202✔
1261
                tlv.MakePrimitiveRecord(preimageType, &preimage),
1,202✔
1262
                tlv.MakePrimitiveRecord(valueType, &value),
1,202✔
1263
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
1,202✔
1264
                tlv.MakePrimitiveRecord(expiryType, &expiry),
1,202✔
1265
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
1,202✔
1266
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
1,202✔
1267

1,202✔
1268
                // Invoice state.
1,202✔
1269
                tlv.MakePrimitiveRecord(invStateType, &state),
1,202✔
1270
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
1,202✔
1271

1,202✔
1272
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
1,202✔
1273

1,202✔
1274
                // Invoice AMP state.
1,202✔
1275
                tlv.MakeDynamicRecord(
1,202✔
1276
                        invoiceAmpStateType, &i.AMPState,
1,202✔
1277
                        ampRecordSize(&i.AMPState),
1,202✔
1278
                        ampStateEncoder, ampStateDecoder,
1,202✔
1279
                ),
1,202✔
1280
        )
1,202✔
1281
        if err != nil {
1,202✔
1282
                return err
×
1283
        }
×
1284

1285
        var b bytes.Buffer
1,202✔
1286
        if err = tlvStream.Encode(&b); err != nil {
1,202✔
1287
                return err
×
1288
        }
×
1289

1290
        err = binary.Write(w, byteOrder, uint64(b.Len()))
1,202✔
1291
        if err != nil {
1,202✔
1292
                return err
×
1293
        }
×
1294

1295
        if _, err = w.Write(b.Bytes()); err != nil {
1,202✔
1296
                return err
×
1297
        }
×
1298

1299
        // Only if this is a _non_ AMP invoice do we serialize the HTLCs
1300
        // in-line with the rest of the invoice.
1301
        if i.IsAMP() {
1,242✔
1302
                return nil
40✔
1303
        }
40✔
1304

1305
        return serializeHtlcs(w, i.Htlcs)
1,162✔
1306
}
1307

1308
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
1309
// a writer.
1310
func serializeHtlcs(w io.Writer,
1311
        htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error {
1,189✔
1312

1,189✔
1313
        for key, htlc := range htlcs {
1,876✔
1314
                // Encode the htlc in a tlv stream.
687✔
1315
                chanID := key.ChanID.ToUint64()
687✔
1316
                amt := uint64(htlc.Amt)
687✔
1317
                mppTotalAmt := uint64(htlc.MppTotalAmt)
687✔
1318
                acceptTime := putNanoTime(htlc.AcceptTime)
687✔
1319
                resolveTime := putNanoTime(htlc.ResolveTime)
687✔
1320
                state := uint8(htlc.State)
687✔
1321

687✔
1322
                var records []tlv.Record
687✔
1323
                records = append(records,
687✔
1324
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
687✔
1325
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
687✔
1326
                        tlv.MakePrimitiveRecord(amtType, &amt),
687✔
1327
                        tlv.MakePrimitiveRecord(
687✔
1328
                                acceptHeightType, &htlc.AcceptHeight,
687✔
1329
                        ),
687✔
1330
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
687✔
1331
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
687✔
1332
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
687✔
1333
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
687✔
1334
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
687✔
1335
                )
687✔
1336

687✔
1337
                if htlc.AMP != nil {
723✔
1338
                        setIDRecord := tlv.MakeDynamicRecord(
36✔
1339
                                htlcAMPType, &htlc.AMP.Record,
36✔
1340
                                htlc.AMP.Record.PayloadSize,
36✔
1341
                                record.AMPEncoder, record.AMPDecoder,
36✔
1342
                        )
36✔
1343
                        records = append(records, setIDRecord)
36✔
1344

36✔
1345
                        hash32 := [32]byte(htlc.AMP.Hash)
36✔
1346
                        hashRecord := tlv.MakePrimitiveRecord(
36✔
1347
                                htlcHashType, &hash32,
36✔
1348
                        )
36✔
1349
                        records = append(records, hashRecord)
36✔
1350

36✔
1351
                        if htlc.AMP.Preimage != nil {
63✔
1352
                                preimage32 := [32]byte(*htlc.AMP.Preimage)
27✔
1353
                                preimageRecord := tlv.MakePrimitiveRecord(
27✔
1354
                                        htlcPreimageType, &preimage32,
27✔
1355
                                )
27✔
1356
                                records = append(records, preimageRecord)
27✔
1357
                        }
27✔
1358
                }
1359

1360
                // Convert the custom records to tlv.Record types that are ready
1361
                // for serialization.
1362
                customRecords := tlv.MapToRecords(htlc.CustomRecords)
687✔
1363

687✔
1364
                // Append the custom records. Their ids are in the experimental
687✔
1365
                // range and sorted, so there is no need to sort again.
687✔
1366
                records = append(records, customRecords...)
687✔
1367

687✔
1368
                tlvStream, err := tlv.NewStream(records...)
687✔
1369
                if err != nil {
687✔
1370
                        return err
×
1371
                }
×
1372

1373
                var b bytes.Buffer
687✔
1374
                if err := tlvStream.Encode(&b); err != nil {
687✔
1375
                        return err
×
1376
                }
×
1377

1378
                // Write the length of the tlv stream followed by the stream
1379
                // bytes.
1380
                err = binary.Write(w, byteOrder, uint64(b.Len()))
687✔
1381
                if err != nil {
687✔
1382
                        return err
×
1383
                }
×
1384

1385
                if _, err := w.Write(b.Bytes()); err != nil {
687✔
1386
                        return err
×
1387
                }
×
1388
        }
1389

1390
        return nil
1,189✔
1391
}
1392

1393
// putNanoTime converts the passed timestamp into unix nanoseconds. The zero
// time.Time is encoded as 0, because UnixNano is undefined for that value.
func putNanoTime(t time.Time) uint64 {
	var ns uint64
	if !t.IsZero() {
		ns = uint64(t.UnixNano())
	}

	return ns
}
1402

1403
// getNanoTime is the inverse of putNanoTime: it turns a unix nanosecond count
// back into a timestamp. A value of zero yields the zero-value time.Time.
func getNanoTime(ns uint64) time.Time {
	var ts time.Time
	if ns != 0 {
		ts = time.Unix(0, int64(ns))
	}

	return ts
}
1411

1412
// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices
1413
// identified by the setID value.
1414
func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1415
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1416
        error) {
37✔
1417

37✔
1418
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
37✔
1419
        for _, setID := range setIDs {
74✔
1420
                invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:])
37✔
1421

37✔
1422
                htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:])
37✔
1423
                if htlcSetBytes == nil {
51✔
1424
                        // A set ID was passed in, but we don't have this
14✔
1425
                        // stored yet, meaning that the setID is being added
14✔
1426
                        // for the first time.
14✔
1427
                        return htlcs, invpkg.ErrInvoiceNotFound
14✔
1428
                }
14✔
1429

1430
                htlcSetReader := bytes.NewReader(htlcSetBytes)
23✔
1431
                htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
23✔
1432
                if err != nil {
23✔
1433
                        return nil, err
×
1434
                }
×
1435

1436
                for key, htlc := range htlcsBySetID {
51✔
1437
                        htlcs[key] = htlc
28✔
1438
                }
28✔
1439
        }
1440

1441
        return htlcs, nil
23✔
1442
}
1443

1444
// forEachAMPInvoice is a helper function that attempts to iterate over each of
1445
// the HTLC sets (based on their set ID) for the given AMP invoice identified
1446
// by its invoiceNum. The callback closure is called for each key within the
1447
// prefix range.
1448
func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1449
        callback func(key, htlcSet []byte) error) error {
33✔
1450

33✔
1451
        invoiceCursor := invoiceBucket.ReadCursor()
33✔
1452

33✔
1453
        // Seek to the first key that includes the invoice data itself.
33✔
1454
        invoiceCursor.Seek(invoiceNum)
33✔
1455

33✔
1456
        // Advance to the very first key _after_ the invoice data, as this is
33✔
1457
        // where we'll encounter our first HTLC (if any are present).
33✔
1458
        cursorKey, htlcSet := invoiceCursor.Next()
33✔
1459

33✔
1460
        // If at this point, the cursor key doesn't match the invoice num
33✔
1461
        // prefix, then we know that this HTLC doesn't have any set ID HTLCs
33✔
1462
        // associated with it.
33✔
1463
        if !bytes.HasPrefix(cursorKey, invoiceNum) {
51✔
1464
                return nil
18✔
1465
        }
18✔
1466

1467
        // Otherwise continue to iterate until we no longer match the prefix,
1468
        // executing the call back at each step.
1469
        for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() {
41✔
1470
                err := callback(cursorKey, htlcSet)
26✔
1471
                if err != nil {
26✔
1472
                        return err
×
1473
                }
×
1474
        }
1475

1476
        return nil
15✔
1477
}
1478

1479
// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix  within the
1480
// AMP bucket to find all the individual HTLCs (by setID) associated with a
1481
// given invoice. If a list of set IDs are specified, then only HTLCs
1482
// associated with that setID will be retrieved.
1483
func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1484
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1485
        error) {
55✔
1486

55✔
1487
        // If a set of setIDs was specified, then we can skip the cursor and
55✔
1488
        // just read out exactly what we need.
55✔
1489
        if len(setIDs) != 0 && setIDs[0] != nil {
92✔
1490
                return fetchFilteredAmpInvoices(
37✔
1491
                        invoiceBucket, invoiceNum, setIDs...,
37✔
1492
                )
37✔
1493
        }
37✔
1494

1495
        // Otherwise, iterate over all the htlc sets that are prefixed beside
1496
        // this invoice in the main invoice bucket.
1497
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
18✔
1498
        err := forEachAMPInvoice(invoiceBucket, invoiceNum,
18✔
1499
                func(key, htlcSet []byte) error {
41✔
1500
                        htlcSetReader := bytes.NewReader(htlcSet)
23✔
1501
                        htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
23✔
1502
                        if err != nil {
23✔
1503
                                return err
×
1504
                        }
×
1505

1506
                        for key, htlc := range htlcsBySetID {
55✔
1507
                                htlcs[key] = htlc
32✔
1508
                        }
32✔
1509

1510
                        return nil
23✔
1511
                },
1512
        )
1513

1514
        if err != nil {
18✔
1515
                return nil, err
×
1516
        }
×
1517

1518
        return htlcs, nil
18✔
1519
}
1520

1521
// fetchInvoice attempts to read out the relevant state for the invoice as
1522
// specified by the invoice number. If the setID fields are set, then only the
1523
// HTLC information pertaining to those set IDs is returned.
1524
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1525
        setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
2,253✔
1526

2,253✔
1527
        invoiceBytes := invoices.Get(invoiceNum)
2,253✔
1528
        if invoiceBytes == nil {
2,253✔
1529
                return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound
×
1530
        }
×
1531

1532
        invoiceReader := bytes.NewReader(invoiceBytes)
2,253✔
1533

2,253✔
1534
        invoice, err := deserializeInvoice(invoiceReader)
2,253✔
1535
        if err != nil {
2,253✔
1536
                return invpkg.Invoice{}, err
×
1537
        }
×
1538

1539
        // If this is an AMP invoice we'll also attempt to read out the set of
1540
        // HTLCs that were paid to prior set IDs, if needed.
1541
        if !invoice.IsAMP() {
4,448✔
1542
                return invoice, nil
2,195✔
1543
        }
2,195✔
1544

1545
        if shouldFetchAMPHTLCs(invoice, setIDs) {
113✔
1546
                invoice.Htlcs, err = fetchAmpSubInvoices(
55✔
1547
                        invoices, invoiceNum, setIDs...,
55✔
1548
                )
55✔
1549
                // TODO(positiveblue): we should fail when we are not able to
55✔
1550
                // fetch all the HTLCs for an AMP invoice. Multiple tests in
55✔
1551
                // the invoice and channeldb package break if we return this
55✔
1552
                // error. We need to update them when we migrate this logic to
55✔
1553
                // the sql implementation.
55✔
1554
                if err != nil {
69✔
1555
                        log.Errorf("unable to fetch amp htlcs for inv "+
14✔
1556
                                "%v and setIDs %v: %w", invoiceNum, setIDs, err)
14✔
1557
                }
14✔
1558

1559
                if filterAMPState {
79✔
1560
                        filterInvoiceAMPState(&invoice, setIDs...)
24✔
1561
                }
24✔
1562
        }
1563

1564
        return invoice, nil
58✔
1565
}
1566

1567
// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that
1568
// were paid to the relevant set IDs.
1569
func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool {
58✔
1570
        // For AMP invoice that already have HTLCs populated (created before
58✔
1571
        // recurring invoices), then we don't need to read from the prefix
58✔
1572
        // keyed section of the bucket.
58✔
1573
        if len(invoice.Htlcs) != 0 {
58✔
1574
                return false
×
1575
        }
×
1576

1577
        // If the "zero" setID was specified, then this means that no HTLC data
1578
        // should be returned alongside of it.
1579
        if len(setIDs) != 0 && setIDs[0] != nil &&
58✔
1580
                *setIDs[0] == invpkg.BlankPayAddr {
61✔
1581

3✔
1582
                return false
3✔
1583
        }
3✔
1584

1585
        return true
55✔
1586
}
1587

1588
// fetchInvoiceStateAMP retrieves the state of all the relevant sub-invoice for
1589
// an AMP invoice. This methods only decode the relevant state vs the entire
1590
// invoice.
1591
func fetchInvoiceStateAMP(invoiceNum []byte,
1592
        invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) {
8✔
1593

8✔
1594
        // Fetch the raw invoice bytes.
8✔
1595
        invoiceBytes := invoices.Get(invoiceNum)
8✔
1596
        if invoiceBytes == nil {
8✔
1597
                return nil, invpkg.ErrInvoiceNotFound
×
1598
        }
×
1599

1600
        r := bytes.NewReader(invoiceBytes)
8✔
1601

8✔
1602
        var bodyLen int64
8✔
1603
        err := binary.Read(r, byteOrder, &bodyLen)
8✔
1604
        if err != nil {
8✔
1605
                return nil, err
×
1606
        }
×
1607

1608
        // Next, we'll make a new TLV stream that only attempts to decode the
1609
        // bytes we actually need.
1610
        ampState := make(invpkg.AMPInvoiceState)
8✔
1611
        tlvStream, err := tlv.NewStream(
8✔
1612
                // Invoice AMP state.
8✔
1613
                tlv.MakeDynamicRecord(
8✔
1614
                        invoiceAmpStateType, &ampState, nil,
8✔
1615
                        ampStateEncoder, ampStateDecoder,
8✔
1616
                ),
8✔
1617
        )
8✔
1618
        if err != nil {
8✔
1619
                return nil, err
×
1620
        }
×
1621

1622
        invoiceReader := io.LimitReader(r, bodyLen)
8✔
1623
        if err = tlvStream.Decode(invoiceReader); err != nil {
8✔
1624
                return nil, err
×
1625
        }
×
1626

1627
        return ampState, nil
8✔
1628
}
1629

1630
func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) {
2,253✔
1631
        var (
2,253✔
1632
                preimageBytes [32]byte
2,253✔
1633
                value         uint64
2,253✔
1634
                cltvDelta     uint32
2,253✔
1635
                expiry        uint64
2,253✔
1636
                amtPaid       uint64
2,253✔
1637
                state         uint8
2,253✔
1638
                hodlInvoice   uint8
2,253✔
1639

2,253✔
1640
                creationDateBytes []byte
2,253✔
1641
                settleDateBytes   []byte
2,253✔
1642
                featureBytes      []byte
2,253✔
1643
        )
2,253✔
1644

2,253✔
1645
        var i invpkg.Invoice
2,253✔
1646
        i.AMPState = make(invpkg.AMPInvoiceState)
2,253✔
1647
        tlvStream, err := tlv.NewStream(
2,253✔
1648
                // Memo and payreq.
2,253✔
1649
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
2,253✔
1650
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
2,253✔
1651

2,253✔
1652
                // Add/settle metadata.
2,253✔
1653
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
2,253✔
1654
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
2,253✔
1655
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
2,253✔
1656
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
2,253✔
1657

2,253✔
1658
                // Terms.
2,253✔
1659
                tlv.MakePrimitiveRecord(preimageType, &preimageBytes),
2,253✔
1660
                tlv.MakePrimitiveRecord(valueType, &value),
2,253✔
1661
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
2,253✔
1662
                tlv.MakePrimitiveRecord(expiryType, &expiry),
2,253✔
1663
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
2,253✔
1664
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
2,253✔
1665

2,253✔
1666
                // Invoice state.
2,253✔
1667
                tlv.MakePrimitiveRecord(invStateType, &state),
2,253✔
1668
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
2,253✔
1669

2,253✔
1670
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
2,253✔
1671

2,253✔
1672
                // Invoice AMP state.
2,253✔
1673
                tlv.MakeDynamicRecord(
2,253✔
1674
                        invoiceAmpStateType, &i.AMPState, nil,
2,253✔
1675
                        ampStateEncoder, ampStateDecoder,
2,253✔
1676
                ),
2,253✔
1677
        )
2,253✔
1678
        if err != nil {
2,253✔
1679
                return i, err
×
1680
        }
×
1681

1682
        var bodyLen int64
2,253✔
1683
        err = binary.Read(r, byteOrder, &bodyLen)
2,253✔
1684
        if err != nil {
2,253✔
1685
                return i, err
×
1686
        }
×
1687

1688
        lr := io.LimitReader(r, bodyLen)
2,253✔
1689
        if err = tlvStream.Decode(lr); err != nil {
2,253✔
1690
                return i, err
×
1691
        }
×
1692

1693
        preimage := lntypes.Preimage(preimageBytes)
2,253✔
1694
        if preimage != invpkg.UnknownPreimage {
4,299✔
1695
                i.Terms.PaymentPreimage = &preimage
2,046✔
1696
        }
2,046✔
1697

1698
        i.Terms.Value = lnwire.MilliSatoshi(value)
2,253✔
1699
        i.Terms.FinalCltvDelta = int32(cltvDelta)
2,253✔
1700
        i.Terms.Expiry = time.Duration(expiry)
2,253✔
1701
        i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
2,253✔
1702
        i.State = invpkg.ContractState(state)
2,253✔
1703

2,253✔
1704
        if hodlInvoice != 0 {
2,483✔
1705
                i.HodlInvoice = true
230✔
1706
        }
230✔
1707

1708
        err = i.CreationDate.UnmarshalBinary(creationDateBytes)
2,253✔
1709
        if err != nil {
2,253✔
1710
                return i, err
×
1711
        }
×
1712

1713
        err = i.SettleDate.UnmarshalBinary(settleDateBytes)
2,253✔
1714
        if err != nil {
2,253✔
1715
                return i, err
×
1716
        }
×
1717

1718
        rawFeatures := lnwire.NewRawFeatureVector()
2,253✔
1719
        err = rawFeatures.DecodeBase256(
2,253✔
1720
                bytes.NewReader(featureBytes), len(featureBytes),
2,253✔
1721
        )
2,253✔
1722
        if err != nil {
2,253✔
1723
                return i, err
×
1724
        }
×
1725

1726
        i.Terms.Features = lnwire.NewFeatureVector(
2,253✔
1727
                rawFeatures, lnwire.Features,
2,253✔
1728
        )
2,253✔
1729

2,253✔
1730
        i.Htlcs, err = deserializeHtlcs(r)
2,253✔
1731
        return i, err
2,253✔
1732
}
1733

1734
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
98✔
1735
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
196✔
1736
                // We encode the set of circuit keys as a varint length prefix.
98✔
1737
                // followed by a series of fixed sized uint8 integers.
98✔
1738
                numKeys := uint64(len(*v))
98✔
1739

98✔
1740
                if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
98✔
1741
                        return err
×
1742
                }
×
1743

1744
                for key := range *v {
226✔
1745
                        scidInt := key.ChanID.ToUint64()
128✔
1746

128✔
1747
                        if err := tlv.EUint64(w, &scidInt, buf); err != nil {
128✔
1748
                                return err
×
1749
                        }
×
1750
                        if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
128✔
1751
                                return err
×
1752
                        }
×
1753
                }
1754

1755
                return nil
98✔
1756
        }
1757

1758
        return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
×
1759
}
1760

1761
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte,
1762
        l uint64) error {
88✔
1763

88✔
1764
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
176✔
1765
                // First, we'll read out the varint that encodes the number of
88✔
1766
                // circuit keys encoded.
88✔
1767
                numKeys, err := tlv.ReadVarInt(r, buf)
88✔
1768
                if err != nil {
88✔
1769
                        return err
×
1770
                }
×
1771

1772
                // Now that we know how many keys to expect, iterate reading
1773
                // each one until we're done.
1774
                for i := uint64(0); i < numKeys; i++ {
197✔
1775
                        var (
109✔
1776
                                key  models.CircuitKey
109✔
1777
                                scid uint64
109✔
1778
                        )
109✔
1779

109✔
1780
                        if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
109✔
1781
                                return err
×
1782
                        }
×
1783

1784
                        key.ChanID = lnwire.NewShortChanIDFromInt(scid)
109✔
1785

109✔
1786
                        err := tlv.DUint64(r, &key.HtlcID, buf, 8)
109✔
1787
                        if err != nil {
109✔
1788
                                return err
×
1789
                        }
×
1790

1791
                        (*v)[key] = struct{}{}
109✔
1792
                }
1793

1794
                return nil
88✔
1795
        }
1796

1797
        return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l)
×
1798
}
1799

1800
// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record.
1801
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
2,406✔
1802
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
4,812✔
1803
                // We'll encode the AMP state as a series of KV pairs on the
2,406✔
1804
                // wire with a length prefix.
2,406✔
1805
                numRecords := uint64(len(*v))
2,406✔
1806

2,406✔
1807
                // First, we'll write out the number of records as a var int.
2,406✔
1808
                if err := tlv.WriteVarInt(w, numRecords, buf); err != nil {
2,406✔
1809
                        return err
×
1810
                }
×
1811

1812
                // With that written out, we'll now encode the entries
1813
                // themselves as a sub-TLV record, which includes its _own_
1814
                // inner length prefix.
1815
                for setID, ampState := range *v {
2,504✔
1816
                        setID := [32]byte(setID)
98✔
1817
                        ampState := ampState
98✔
1818

98✔
1819
                        htlcState := uint8(ampState.State)
98✔
1820
                        settleDate := ampState.SettleDate
98✔
1821
                        settleDateBytes, err := settleDate.MarshalBinary()
98✔
1822
                        if err != nil {
98✔
1823
                                return err
×
1824
                        }
×
1825

1826
                        amtPaid := uint64(ampState.AmtPaid)
98✔
1827

98✔
1828
                        var ampStateTlvBytes bytes.Buffer
98✔
1829
                        tlvStream, err := tlv.NewStream(
98✔
1830
                                tlv.MakePrimitiveRecord(
98✔
1831
                                        ampStateSetIDType, &setID,
98✔
1832
                                ),
98✔
1833
                                tlv.MakePrimitiveRecord(
98✔
1834
                                        ampStateHtlcStateType, &htlcState,
98✔
1835
                                ),
98✔
1836
                                tlv.MakePrimitiveRecord(
98✔
1837
                                        ampStateSettleIndexType,
98✔
1838
                                        &ampState.SettleIndex,
98✔
1839
                                ),
98✔
1840
                                tlv.MakePrimitiveRecord(
98✔
1841
                                        ampStateSettleDateType,
98✔
1842
                                        &settleDateBytes,
98✔
1843
                                ),
98✔
1844
                                tlv.MakeDynamicRecord(
98✔
1845
                                        ampStateCircuitKeysType,
98✔
1846
                                        &ampState.InvoiceKeys,
98✔
1847
                                        func() uint64 {
196✔
1848
                                                // The record takes 8 bytes to
98✔
1849
                                                // encode the set of circuits,
98✔
1850
                                                // 8 bytes for the scid for the
98✔
1851
                                                // key, and 8 bytes for the HTLC
98✔
1852
                                                // index.
98✔
1853
                                                keys := ampState.InvoiceKeys
98✔
1854
                                                numKeys := uint64(len(keys))
98✔
1855
                                                size := tlv.VarIntSize(numKeys)
98✔
1856
                                                dataSize := (numKeys * 16)
98✔
1857

98✔
1858
                                                return size + dataSize
98✔
1859
                                        },
98✔
1860
                                        encodeCircuitKeys, decodeCircuitKeys,
1861
                                ),
1862
                                tlv.MakePrimitiveRecord(
1863
                                        ampStateAmtPaidType, &amtPaid,
1864
                                ),
1865
                        )
1866
                        if err != nil {
98✔
1867
                                return err
×
1868
                        }
×
1869

1870
                        err = tlvStream.Encode(&ampStateTlvBytes)
98✔
1871
                        if err != nil {
98✔
1872
                                return err
×
1873
                        }
×
1874

1875
                        // We encode the record with a varint length followed by
1876
                        // the _raw_ TLV bytes.
1877
                        tlvLen := uint64(len(ampStateTlvBytes.Bytes()))
98✔
1878
                        if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
98✔
1879
                                return err
×
1880
                        }
×
1881

1882
                        _, err = w.Write(ampStateTlvBytes.Bytes())
98✔
1883
                        if err != nil {
98✔
1884
                                return err
×
1885
                        }
×
1886
                }
1887

1888
                return nil
2,406✔
1889
        }
1890

1891
        return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
×
1892
}
1893

1894
// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record.
1895
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte,
1896
        l uint64) error {
2,262✔
1897

2,262✔
1898
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
4,524✔
1899
                // First, we'll decode the varint that encodes how many set IDs
2,262✔
1900
                // are encoded within the greater map.
2,262✔
1901
                numRecords, err := tlv.ReadVarInt(r, buf)
2,262✔
1902
                if err != nil {
2,262✔
1903
                        return err
×
1904
                }
×
1905

1906
                // Now that we know how many records we'll need to read, we can
1907
                // iterate and read them all out in series.
1908
                for i := uint64(0); i < numRecords; i++ {
2,350✔
1909
                        // Read out the varint that encodes the size of this
88✔
1910
                        // inner TLV record.
88✔
1911
                        stateRecordSize, err := tlv.ReadVarInt(r, buf)
88✔
1912
                        if err != nil {
88✔
1913
                                return err
×
1914
                        }
×
1915

1916
                        // Using this information, we'll create a new limited
1917
                        // reader that'll return an EOF once the end has been
1918
                        // reached so the stream stops consuming bytes.
1919
                        innerTlvReader := io.LimitedReader{
88✔
1920
                                R: r,
88✔
1921
                                N: int64(stateRecordSize),
88✔
1922
                        }
88✔
1923

88✔
1924
                        var (
88✔
1925
                                setID           [32]byte
88✔
1926
                                htlcState       uint8
88✔
1927
                                settleIndex     uint64
88✔
1928
                                settleDateBytes []byte
88✔
1929
                                invoiceKeys     = make(
88✔
1930
                                        map[models.CircuitKey]struct{},
88✔
1931
                                )
88✔
1932
                                amtPaid uint64
88✔
1933
                        )
88✔
1934
                        tlvStream, err := tlv.NewStream(
88✔
1935
                                tlv.MakePrimitiveRecord(
88✔
1936
                                        ampStateSetIDType, &setID,
88✔
1937
                                ),
88✔
1938
                                tlv.MakePrimitiveRecord(
88✔
1939
                                        ampStateHtlcStateType, &htlcState,
88✔
1940
                                ),
88✔
1941
                                tlv.MakePrimitiveRecord(
88✔
1942
                                        ampStateSettleIndexType, &settleIndex,
88✔
1943
                                ),
88✔
1944
                                tlv.MakePrimitiveRecord(
88✔
1945
                                        ampStateSettleDateType,
88✔
1946
                                        &settleDateBytes,
88✔
1947
                                ),
88✔
1948
                                tlv.MakeDynamicRecord(
88✔
1949
                                        ampStateCircuitKeysType,
88✔
1950
                                        &invoiceKeys, nil,
88✔
1951
                                        encodeCircuitKeys, decodeCircuitKeys,
88✔
1952
                                ),
88✔
1953
                                tlv.MakePrimitiveRecord(
88✔
1954
                                        ampStateAmtPaidType, &amtPaid,
88✔
1955
                                ),
88✔
1956
                        )
88✔
1957
                        if err != nil {
88✔
1958
                                return err
×
1959
                        }
×
1960

1961
                        err = tlvStream.Decode(&innerTlvReader)
88✔
1962
                        if err != nil {
88✔
1963
                                return err
×
1964
                        }
×
1965

1966
                        var settleDate time.Time
88✔
1967
                        err = settleDate.UnmarshalBinary(settleDateBytes)
88✔
1968
                        if err != nil {
88✔
1969
                                return err
×
1970
                        }
×
1971

1972
                        (*v)[setID] = invpkg.InvoiceStateAMP{
88✔
1973
                                State:       invpkg.HtlcState(htlcState),
88✔
1974
                                SettleIndex: settleIndex,
88✔
1975
                                SettleDate:  settleDate,
88✔
1976
                                InvoiceKeys: invoiceKeys,
88✔
1977
                                AmtPaid:     lnwire.MilliSatoshi(amtPaid),
88✔
1978
                        }
88✔
1979
                }
1980

1981
                return nil
2,262✔
1982
        }
1983

1984
        return tlv.NewTypeForDecodingErr(
×
1985
                val, "channeldb.AMPInvoiceState", l, l,
×
1986
        )
×
1987
}
1988

1989
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
1990
// as a map.
1991
func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1992
        error) {
2,299✔
1993

2,299✔
1994
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
2,299✔
1995
        for {
5,691✔
1996
                // Read the length of the tlv stream for this htlc.
3,392✔
1997
                var streamLen int64
3,392✔
1998
                if err := binary.Read(r, byteOrder, &streamLen); err != nil {
5,691✔
1999
                        if err == io.EOF {
4,598✔
2000
                                break
2,299✔
2001
                        }
2002

2003
                        return nil, err
×
2004
                }
2005

2006
                // Limit the reader so that it stops at the end of this htlc's
2007
                // stream.
2008
                htlcReader := io.LimitReader(r, streamLen)
1,093✔
2009

1,093✔
2010
                // Decode the contents into the htlc fields.
1,093✔
2011
                var (
1,093✔
2012
                        htlc                    invpkg.InvoiceHTLC
1,093✔
2013
                        key                     models.CircuitKey
1,093✔
2014
                        chanID                  uint64
1,093✔
2015
                        state                   uint8
1,093✔
2016
                        acceptTime, resolveTime uint64
1,093✔
2017
                        amt, mppTotalAmt        uint64
1,093✔
2018
                        amp                     = &record.AMP{}
1,093✔
2019
                        hash32                  = &[32]byte{}
1,093✔
2020
                        preimage32              = &[32]byte{}
1,093✔
2021
                )
1,093✔
2022
                tlvStream, err := tlv.NewStream(
1,093✔
2023
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
1,093✔
2024
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
1,093✔
2025
                        tlv.MakePrimitiveRecord(amtType, &amt),
1,093✔
2026
                        tlv.MakePrimitiveRecord(
1,093✔
2027
                                acceptHeightType, &htlc.AcceptHeight,
1,093✔
2028
                        ),
1,093✔
2029
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
1,093✔
2030
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
1,093✔
2031
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
1,093✔
2032
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
1,093✔
2033
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
1,093✔
2034
                        tlv.MakeDynamicRecord(
1,093✔
2035
                                htlcAMPType, amp, amp.PayloadSize,
1,093✔
2036
                                record.AMPEncoder, record.AMPDecoder,
1,093✔
2037
                        ),
1,093✔
2038
                        tlv.MakePrimitiveRecord(htlcHashType, hash32),
1,093✔
2039
                        tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
1,093✔
2040
                )
1,093✔
2041
                if err != nil {
1,093✔
2042
                        return nil, err
×
2043
                }
×
2044

2045
                parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
1,093✔
2046
                if err != nil {
1,093✔
2047
                        return nil, err
×
2048
                }
×
2049

2050
                if _, ok := parsedTypes[htlcAMPType]; !ok {
2,126✔
2051
                        amp = nil
1,033✔
2052
                }
1,033✔
2053

2054
                var preimage *lntypes.Preimage
1,093✔
2055
                if _, ok := parsedTypes[htlcPreimageType]; ok {
1,132✔
2056
                        pimg := lntypes.Preimage(*preimage32)
39✔
2057
                        preimage = &pimg
39✔
2058
                }
39✔
2059

2060
                var hash *lntypes.Hash
1,093✔
2061
                if _, ok := parsedTypes[htlcHashType]; ok {
1,153✔
2062
                        h := lntypes.Hash(*hash32)
60✔
2063
                        hash = &h
60✔
2064
                }
60✔
2065

2066
                key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
1,093✔
2067
                htlc.AcceptTime = getNanoTime(acceptTime)
1,093✔
2068
                htlc.ResolveTime = getNanoTime(resolveTime)
1,093✔
2069
                htlc.State = invpkg.HtlcState(state)
1,093✔
2070
                htlc.Amt = lnwire.MilliSatoshi(amt)
1,093✔
2071
                htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
1,093✔
2072
                if amp != nil && hash != nil {
1,153✔
2073
                        htlc.AMP = &invpkg.InvoiceHtlcAMPData{
60✔
2074
                                Record:   *amp,
60✔
2075
                                Hash:     *hash,
60✔
2076
                                Preimage: preimage,
60✔
2077
                        }
60✔
2078
                }
60✔
2079

2080
                // Reconstruct the custom records fields from the parsed types
2081
                // map return from the tlv parser.
2082
                htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
1,093✔
2083

1,093✔
2084
                htlcs[key] = &htlc
1,093✔
2085
        }
2086

2087
        return htlcs, nil
2,299✔
2088
}
2089

2090
// invoiceSetIDKeyLen is the size in bytes of the key under which the HTLCs of
// a single set ID are stored next to the main invoice: 4 bytes of invoice
// number followed by 32 bytes of set ID.
const invoiceSetIDKeyLen = 4 + 32

// makeInvoiceSetIDKey builds the prefix key (invoiceNum || setID) under which
// the HTLCs for the given set ID of the given invoice are stored.
//
// Inputs shorter than the key leave the remaining bytes zeroed, and any bytes
// of setID that do not fit are dropped.
func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte {
	var key [invoiceSetIDKeyLen]byte

	// Lay out the invoice number first, then place the set ID directly
	// behind it.
	dst := key[:]
	copy(dst, invoiceNum)

	setIDOffset := len(invoiceNum)
	copy(dst[setIDOffset:], setID)

	return key
}
64✔
2107

2108
// delAMPInvoices attempts to delete all the "sub" invoices associated with a
2109
// greater AMP invoices. We do this by deleting the set of keys that share the
2110
// invoice number as a prefix.
2111
func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error {
15✔
2112
        // Since it isn't safe to delete using an active cursor, we'll use the
15✔
2113
        // cursor simply to collect the set of keys we need to delete, _then_
15✔
2114
        // delete them in another pass.
15✔
2115
        var keysToDel [][]byte
15✔
2116
        err := forEachAMPInvoice(
15✔
2117
                invoiceBucket, invoiceNum,
15✔
2118
                func(cursorKey, v []byte) error {
18✔
2119
                        keysToDel = append(keysToDel, cursorKey)
3✔
2120
                        return nil
3✔
2121
                },
3✔
2122
        )
2123
        if err != nil {
15✔
2124
                return err
×
2125
        }
×
2126

2127
        // In this next phase, we'll then delete all the relevant invoices.
2128
        for _, keyToDel := range keysToDel {
18✔
2129
                if err := invoiceBucket.Delete(keyToDel); err != nil {
3✔
2130
                        return err
×
2131
                }
×
2132
        }
2133

2134
        return nil
15✔
2135
}
2136

2137
// delAMPSettleIndex removes all the entries in the settle index associated
2138
// with a given AMP invoice.
2139
func delAMPSettleIndex(invoiceNum []byte, invoices,
2140
        settleIndex kvdb.RwBucket) error {
8✔
2141

8✔
2142
        // First, we need to grab the AMP invoice state to see if there's
8✔
2143
        // anything that we even need to delete.
8✔
2144
        ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices)
8✔
2145
        if err != nil {
8✔
2146
                return err
×
2147
        }
×
2148

2149
        // If there's no AMP state at all (non-AMP invoice), then we can return
2150
        // early.
2151
        if len(ampState) == 0 {
15✔
2152
                return nil
7✔
2153
        }
7✔
2154

2155
        // Otherwise, we'll need to iterate and delete each settle index within
2156
        // the set of returned entries.
2157
        var settleIndexKey [8]byte
1✔
2158
        for _, subState := range ampState {
4✔
2159
                byteOrder.PutUint64(
3✔
2160
                        settleIndexKey[:], subState.SettleIndex,
3✔
2161
                )
3✔
2162

3✔
2163
                if err := settleIndex.Delete(settleIndexKey[:]); err != nil {
3✔
2164
                        return err
×
2165
                }
×
2166
        }
2167

2168
        return nil
1✔
2169
}
2170

2171
// DeleteCanceledInvoices deletes all canceled invoices from the database.
2172
func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
3✔
2173
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
6✔
2174
                invoices := tx.ReadWriteBucket(invoiceBucket)
3✔
2175
                if invoices == nil {
3✔
2176
                        return nil
×
2177
                }
×
2178

2179
                invoiceIndex := invoices.NestedReadWriteBucket(
3✔
2180
                        invoiceIndexBucket,
3✔
2181
                )
3✔
2182
                if invoiceIndex == nil {
4✔
2183
                        return nil
1✔
2184
                }
1✔
2185

2186
                invoiceAddIndex := invoices.NestedReadWriteBucket(
2✔
2187
                        addIndexBucket,
2✔
2188
                )
2✔
2189
                if invoiceAddIndex == nil {
2✔
2190
                        return nil
×
2191
                }
×
2192

2193
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
2✔
2194

2✔
2195
                return invoiceIndex.ForEach(func(k, v []byte) error {
19✔
2196
                        // Skip the special numInvoicesKey as that does not
17✔
2197
                        // point to a valid invoice.
17✔
2198
                        if bytes.Equal(k, numInvoicesKey) {
19✔
2199
                                return nil
2✔
2200
                        }
2✔
2201

2202
                        // Skip sub-buckets.
2203
                        if v == nil {
15✔
2204
                                return nil
×
2205
                        }
×
2206

2207
                        invoice, err := fetchInvoice(v, invoices, nil, false)
15✔
2208
                        if err != nil {
15✔
2209
                                return err
×
2210
                        }
×
2211

2212
                        if invoice.State != invpkg.ContractCanceled {
23✔
2213
                                return nil
8✔
2214
                        }
8✔
2215

2216
                        // Delete the payment hash from the invoice index.
2217
                        err = invoiceIndex.Delete(k)
7✔
2218
                        if err != nil {
7✔
2219
                                return err
×
2220
                        }
×
2221

2222
                        // Delete payment address index reference if there's a
2223
                        // valid payment address.
2224
                        if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
14✔
2225
                                // To ensure consistency check that the already
7✔
2226
                                // fetched invoice key matches the one in the
7✔
2227
                                // payment address index.
7✔
2228
                                key := payAddrIndex.Get(
7✔
2229
                                        invoice.Terms.PaymentAddr[:],
7✔
2230
                                )
7✔
2231
                                if bytes.Equal(key, k) {
7✔
2232
                                        // Delete from the payment address
×
2233
                                        // index.
×
2234
                                        if err := payAddrIndex.Delete(
×
2235
                                                invoice.Terms.PaymentAddr[:],
×
2236
                                        ); err != nil {
×
2237
                                                return err
×
2238
                                        }
×
2239
                                }
2240
                        }
2241

2242
                        // Remove from the add index.
2243
                        var addIndexKey [8]byte
7✔
2244
                        byteOrder.PutUint64(addIndexKey[:], invoice.AddIndex)
7✔
2245
                        err = invoiceAddIndex.Delete(addIndexKey[:])
7✔
2246
                        if err != nil {
7✔
2247
                                return err
×
2248
                        }
×
2249

2250
                        // Note that we don't need to delete the invoice from
2251
                        // the settle index as it is not added until the
2252
                        // invoice is settled.
2253

2254
                        // Now remove all sub invoices.
2255
                        err = delAMPInvoices(k, invoices)
7✔
2256
                        if err != nil {
7✔
2257
                                return err
×
2258
                        }
×
2259

2260
                        // Finally remove the serialized invoice from the
2261
                        // invoice bucket.
2262
                        return invoices.Delete(k)
7✔
2263
                })
2264
        }, func() {})
3✔
2265
}
2266

2267
// DeleteInvoice attempts to delete the passed invoices from the database in
2268
// one transaction. The passed delete references hold all keys required to
2269
// delete the invoices without also needing to deserialize them.
2270
func (d *DB) DeleteInvoice(_ context.Context,
2271
        invoicesToDelete []invpkg.InvoiceDeleteRef) error {
6✔
2272

6✔
2273
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
12✔
2274
                invoices := tx.ReadWriteBucket(invoiceBucket)
6✔
2275
                if invoices == nil {
6✔
2276
                        return invpkg.ErrNoInvoicesCreated
×
2277
                }
×
2278

2279
                invoiceIndex := invoices.NestedReadWriteBucket(
6✔
2280
                        invoiceIndexBucket,
6✔
2281
                )
6✔
2282
                if invoiceIndex == nil {
6✔
2283
                        return invpkg.ErrNoInvoicesCreated
×
2284
                }
×
2285

2286
                invoiceAddIndex := invoices.NestedReadWriteBucket(
6✔
2287
                        addIndexBucket,
6✔
2288
                )
6✔
2289
                if invoiceAddIndex == nil {
6✔
2290
                        return invpkg.ErrNoInvoicesCreated
×
2291
                }
×
2292

2293
                // settleIndex can be nil, as the bucket is created lazily
2294
                // when the first invoice is settled.
2295
                settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket)
6✔
2296

6✔
2297
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
6✔
2298

6✔
2299
                for _, ref := range invoicesToDelete {
17✔
2300
                        // Fetch the invoice key for using it to check for
11✔
2301
                        // consistency and also to delete from the invoice
11✔
2302
                        // index.
11✔
2303
                        invoiceKey := invoiceIndex.Get(ref.PayHash[:])
11✔
2304
                        if invoiceKey == nil {
12✔
2305
                                return invpkg.ErrInvoiceNotFound
1✔
2306
                        }
1✔
2307

2308
                        err := invoiceIndex.Delete(ref.PayHash[:])
10✔
2309
                        if err != nil {
10✔
2310
                                return err
×
2311
                        }
×
2312

2313
                        // Delete payment address index reference if there's a
2314
                        // valid payment address passed.
2315
                        if ref.PayAddr != nil {
19✔
2316
                                // To ensure consistency check that the already
9✔
2317
                                // fetched invoice key matches the one in the
9✔
2318
                                // payment address index.
9✔
2319
                                key := payAddrIndex.Get(ref.PayAddr[:])
9✔
2320
                                if bytes.Equal(key, invoiceKey) {
18✔
2321
                                        // Delete from the payment address
9✔
2322
                                        // index. Note that since the payment
9✔
2323
                                        // address index has been introduced
9✔
2324
                                        // with an empty migration it may be
9✔
2325
                                        // possible that the index doesn't have
9✔
2326
                                        // an entry for this invoice.
9✔
2327
                                        // ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127
9✔
2328
                                        if err := payAddrIndex.Delete(
9✔
2329
                                                ref.PayAddr[:],
9✔
2330
                                        ); err != nil {
9✔
2331
                                                return err
×
2332
                                        }
×
2333
                                }
2334
                        }
2335

2336
                        var addIndexKey [8]byte
10✔
2337
                        byteOrder.PutUint64(addIndexKey[:], ref.AddIndex)
10✔
2338

10✔
2339
                        // To ensure consistency check that the key stored in
10✔
2340
                        // the add index also matches the previously fetched
10✔
2341
                        // invoice key.
10✔
2342
                        key := invoiceAddIndex.Get(addIndexKey[:])
10✔
2343
                        if !bytes.Equal(key, invoiceKey) {
11✔
2344
                                return fmt.Errorf("unknown invoice " +
1✔
2345
                                        "in add index")
1✔
2346
                        }
1✔
2347

2348
                        // Remove from the add index.
2349
                        err = invoiceAddIndex.Delete(addIndexKey[:])
9✔
2350
                        if err != nil {
9✔
2351
                                return err
×
2352
                        }
×
2353

2354
                        // Remove from the settle index if available and
2355
                        // if the invoice is settled.
2356
                        if settleIndex != nil && ref.SettleIndex > 0 {
12✔
2357
                                var settleIndexKey [8]byte
3✔
2358
                                byteOrder.PutUint64(
3✔
2359
                                        settleIndexKey[:], ref.SettleIndex,
3✔
2360
                                )
3✔
2361

3✔
2362
                                // To ensure consistency check that the already
3✔
2363
                                // fetched invoice key matches the one in the
3✔
2364
                                // settle index
3✔
2365
                                key := settleIndex.Get(settleIndexKey[:])
3✔
2366
                                if !bytes.Equal(key, invoiceKey) {
4✔
2367
                                        return fmt.Errorf("unknown invoice " +
1✔
2368
                                                "in settle index")
1✔
2369
                                }
1✔
2370

2371
                                err = settleIndex.Delete(settleIndexKey[:])
2✔
2372
                                if err != nil {
2✔
2373
                                        return err
×
2374
                                }
×
2375
                        }
2376

2377
                        // In addition to deleting the main invoice state, if
2378
                        // this is an AMP invoice, then we'll also need to
2379
                        // delete the set HTLC set stored as a key prefix. For
2380
                        // non-AMP invoices, this'll be a noop.
2381
                        err = delAMPSettleIndex(
8✔
2382
                                invoiceKey, invoices, settleIndex,
8✔
2383
                        )
8✔
2384
                        if err != nil {
8✔
2385
                                return err
×
2386
                        }
×
2387
                        err = delAMPInvoices(invoiceKey, invoices)
8✔
2388
                        if err != nil {
8✔
2389
                                return err
×
2390
                        }
×
2391

2392
                        // Finally remove the serialized invoice from the
2393
                        // invoice bucket.
2394
                        err = invoices.Delete(invoiceKey)
8✔
2395
                        if err != nil {
8✔
2396
                                return err
×
2397
                        }
×
2398
                }
2399

2400
                return nil
3✔
2401
        }, func() {})
6✔
2402

2403
        return err
6✔
2404
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc