• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13586005509

28 Feb 2025 10:14AM UTC coverage: 68.629% (+9.9%) from 58.77%
13586005509

Pull #9521

github

web-flow
Merge 37d3a70a5 into 8532955b3
Pull Request #9521: unit: remove GOACC, use Go 1.20 native coverage functionality

129950 of 189351 relevant lines covered (68.63%)

23726.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.23
/channeldb/invoices.go
1
package channeldb
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/binary"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "time"
11

12
        "github.com/lightningnetwork/lnd/graph/db/models"
13
        "github.com/lightningnetwork/lnd/htlcswitch/hop"
14
        invpkg "github.com/lightningnetwork/lnd/invoices"
15
        "github.com/lightningnetwork/lnd/kvdb"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/tlv"
20
)
21

22
var (
	// invoiceBucket is the name of the bucket within the database that
	// stores all data related to invoices no matter their final state.
	// Within the invoice bucket, each invoice is keyed by its invoice ID
	// which is a monotonically increasing uint32.
	invoiceBucket = []byte("invoices")

	// paymentHashIndexBucket is the name of the sub-bucket within the
	// invoiceBucket which indexes all invoices by their payment hash. The
	// payment hash is the sha256 of the invoice's payment preimage. This
	// index is used to detect duplicates, and also to provide a fast path
	// for looking up incoming HTLCs to determine if we're able to settle
	// them fully.
	//
	// maps: payHash => invoiceKey
	invoiceIndexBucket = []byte("paymenthashes")

	// payAddrIndexBucket is the name of the top-level bucket that maps
	// payment addresses to their invoice number. This can be used
	// to efficiently query or update non-legacy invoices. Note that legacy
	// invoices will not be included in this index since they all have the
	// same, all-zero payment address, however all newly generated invoices
	// will end up in this index.
	//
	// maps: payAddr => invoiceKey
	payAddrIndexBucket = []byte("pay-addr-index")

	// setIDIndexBucket is the name of the top-level bucket that maps set
	// ids to their invoice number. This can be used to efficiently query or
	// update AMP invoices. Note that legacy or MPP invoices will not be
	// included in this index, since their HTLCs do not have a set id.
	//
	// maps: setID => invoiceKey
	setIDIndexBucket = []byte("set-id-index")

	// numInvoicesKey is the name of key which houses the auto-incrementing
	// invoice ID which is essentially used as a primary key. With each
	// invoice inserted, the primary key is incremented by one. This key is
	// stored within the invoiceIndexBucket. Within the invoiceBucket
	// invoices are uniquely identified by the invoice ID.
	numInvoicesKey = []byte("nik")

	// addIndexBucket is an index bucket that we'll use to create a
	// monotonically increasing set of add indexes. Each time we add a new
	// invoice, this sequence number will be incremented and then populated
	// within the new invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   addIndexNo => invoiceKey
	addIndexBucket = []byte("invoice-add-index")

	// settleIndexBucket is an index bucket that we'll use to create a
	// monotonically increasing integer for tracking a "settle index". Each
	// time an invoice is settled, this sequence number will be incremented
	// and populated within the newly settled invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   settleIndexNo => invoiceKey
	settleIndexBucket = []byte("invoice-settle-index")

	// invoiceBucketTombstone is a special key that indicates the invoice
	// bucket has been permanently closed. Its purpose is to prevent the
	// invoice bucket from being reopened in the future. A key use case for
	// the tombstone is to ensure users cannot switch back to the KV invoice
	// database after migrating to the native SQL database.
	invoiceBucketTombstone = []byte("invoice-tombstone")
)
91

92
const (
93
        // A set of tlv type definitions used to serialize invoice htlcs to the
94
        // database.
95
        //
96
        // NOTE: A migration should be added whenever this list changes. This
97
        // prevents against the database being rolled back to an older
98
        // format where the surrounding logic might assume a different set of
99
        // fields are known.
100
        chanIDType       tlv.Type = 1
101
        htlcIDType       tlv.Type = 3
102
        amtType          tlv.Type = 5
103
        acceptHeightType tlv.Type = 7
104
        acceptTimeType   tlv.Type = 9
105
        resolveTimeType  tlv.Type = 11
106
        expiryHeightType tlv.Type = 13
107
        htlcStateType    tlv.Type = 15
108
        mppTotalAmtType  tlv.Type = 17
109
        htlcAMPType      tlv.Type = 19
110
        htlcHashType     tlv.Type = 21
111
        htlcPreimageType tlv.Type = 23
112

113
        // A set of tlv type definitions used to serialize invoice bodiees.
114
        //
115
        // NOTE: A migration should be added whenever this list changes. This
116
        // prevents against the database being rolled back to an older
117
        // format where the surrounding logic might assume a different set of
118
        // fields are known.
119
        memoType            tlv.Type = 0
120
        payReqType          tlv.Type = 1
121
        createTimeType      tlv.Type = 2
122
        settleTimeType      tlv.Type = 3
123
        addIndexType        tlv.Type = 4
124
        settleIndexType     tlv.Type = 5
125
        preimageType        tlv.Type = 6
126
        valueType           tlv.Type = 7
127
        cltvDeltaType       tlv.Type = 8
128
        expiryType          tlv.Type = 9
129
        paymentAddrType     tlv.Type = 10
130
        featuresType        tlv.Type = 11
131
        invStateType        tlv.Type = 12
132
        amtPaidType         tlv.Type = 13
133
        hodlInvoiceType     tlv.Type = 14
134
        invoiceAmpStateType tlv.Type = 15
135

136
        // A set of tlv type definitions used to serialize the invoice AMP
137
        // state along-side the main invoice body.
138
        ampStateSetIDType       tlv.Type = 0
139
        ampStateHtlcStateType   tlv.Type = 1
140
        ampStateSettleIndexType tlv.Type = 2
141
        ampStateSettleDateType  tlv.Type = 3
142
        ampStateCircuitKeysType tlv.Type = 4
143
        ampStateAmtPaidType     tlv.Type = 5
144
)
145

146
// AddInvoice inserts the targeted invoice into the database. If the invoice has
147
// *any* payment hashes which already exists within the database, then the
148
// insertion will be aborted and rejected due to the strict policy banning any
149
// duplicate payment hashes. A side effect of this function is that it sets
150
// AddIndex on newInvoice.
151
func (d *DB) AddInvoice(_ context.Context, newInvoice *invpkg.Invoice,
152
        paymentHash lntypes.Hash) (uint64, error) {
628✔
153

628✔
154
        if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
630✔
155
                return 0, err
2✔
156
        }
2✔
157

158
        var invoiceAddIndex uint64
626✔
159
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
1,252✔
160
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
626✔
161
                if err != nil {
626✔
162
                        return err
×
163
                }
×
164

165
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
626✔
166
                        invoiceIndexBucket,
626✔
167
                )
626✔
168
                if err != nil {
626✔
169
                        return err
×
170
                }
×
171
                addIndex, err := invoices.CreateBucketIfNotExists(
626✔
172
                        addIndexBucket,
626✔
173
                )
626✔
174
                if err != nil {
626✔
175
                        return err
×
176
                }
×
177

178
                // Ensure that an invoice an identical payment hash doesn't
179
                // already exist within the index.
180
                if invoiceIndex.Get(paymentHash[:]) != nil {
629✔
181
                        return invpkg.ErrDuplicateInvoice
3✔
182
                }
3✔
183

184
                // Check that we aren't inserting an invoice with a duplicate
185
                // payment address. The all-zeros payment address is
186
                // special-cased to support legacy keysend invoices which don't
187
                // assign one. This is safe since later we also will avoid
188
                // indexing them and avoid collisions.
189
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
623✔
190
                if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
1,124✔
191
                        paymentAddr := newInvoice.Terms.PaymentAddr[:]
501✔
192
                        if payAddrIndex.Get(paymentAddr) != nil {
509✔
193
                                return invpkg.ErrDuplicatePayAddr
8✔
194
                        }
8✔
195
                }
196

197
                // If the current running payment ID counter hasn't yet been
198
                // created, then create it now.
199
                var invoiceNum uint32
618✔
200
                invoiceCounter := invoiceIndex.Get(numInvoicesKey)
618✔
201
                if invoiceCounter == nil {
820✔
202
                        var scratch [4]byte
202✔
203
                        byteOrder.PutUint32(scratch[:], invoiceNum)
202✔
204
                        err := invoiceIndex.Put(numInvoicesKey, scratch[:])
202✔
205
                        if err != nil {
202✔
206
                                return err
×
207
                        }
×
208
                } else {
419✔
209
                        invoiceNum = byteOrder.Uint32(invoiceCounter)
419✔
210
                }
419✔
211

212
                newIndex, err := putInvoice(
618✔
213
                        invoices, invoiceIndex, payAddrIndex, addIndex,
618✔
214
                        newInvoice, invoiceNum, paymentHash,
618✔
215
                )
618✔
216
                if err != nil {
618✔
217
                        return err
×
218
                }
×
219

220
                invoiceAddIndex = newIndex
618✔
221
                return nil
618✔
222
        }, func() {
626✔
223
                invoiceAddIndex = 0
626✔
224
        })
626✔
225
        if err != nil {
637✔
226
                return 0, err
11✔
227
        }
11✔
228

229
        return invoiceAddIndex, err
618✔
230
}
231

232
// InvoicesAddedSince can be used by callers to seek into the event time series
233
// of all the invoices added in the database. The specified sinceAddIndex
234
// should be the highest add index that the caller knows of. This method will
235
// return all invoices with an add index greater than the specified
236
// sinceAddIndex.
237
//
238
// NOTE: The index starts from 1, as a result. We enforce that specifying a
239
// value below the starting index value is a noop.
240
func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
241
        []invpkg.Invoice, error) {
23✔
242

23✔
243
        var newInvoices []invpkg.Invoice
23✔
244

23✔
245
        // If an index of zero was specified, then in order to maintain
23✔
246
        // backwards compat, we won't send out any new invoices.
23✔
247
        if sinceAddIndex == 0 {
43✔
248
                return newInvoices, nil
20✔
249
        }
20✔
250

251
        var startIndex [8]byte
6✔
252
        byteOrder.PutUint64(startIndex[:], sinceAddIndex)
6✔
253

6✔
254
        err := kvdb.View(d, func(tx kvdb.RTx) error {
12✔
255
                invoices := tx.ReadBucket(invoiceBucket)
6✔
256
                if invoices == nil {
6✔
257
                        return nil
×
258
                }
×
259

260
                addIndex := invoices.NestedReadBucket(addIndexBucket)
6✔
261
                if addIndex == nil {
6✔
262
                        return nil
×
263
                }
×
264

265
                // We'll now run through each entry in the add index starting
266
                // at our starting index. We'll continue until we reach the
267
                // very end of the current key space.
268
                invoiceCursor := addIndex.ReadCursor()
6✔
269

6✔
270
                // We'll seek to the starting index, then manually advance the
6✔
271
                // cursor in order to skip the entry with the since add index.
6✔
272
                invoiceCursor.Seek(startIndex[:])
6✔
273
                addSeqNo, invoiceKey := invoiceCursor.Next()
6✔
274

6✔
275
                for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
38✔
276

32✔
277
                        // For each key found, we'll look up the actual
32✔
278
                        // invoice, then accumulate it into our return value.
32✔
279
                        invoice, err := fetchInvoice(
32✔
280
                                invoiceKey, invoices, nil, false,
32✔
281
                        )
32✔
282
                        if err != nil {
32✔
283
                                return err
×
284
                        }
×
285

286
                        newInvoices = append(newInvoices, invoice)
32✔
287
                }
288

289
                return nil
6✔
290
        }, func() {
6✔
291
                newInvoices = nil
6✔
292
        })
6✔
293
        if err != nil {
6✔
294
                return nil, err
×
295
        }
×
296

297
        return newInvoices, nil
6✔
298
}
299

300
// LookupInvoice attempts to look up an invoice according to its 32 byte
301
// payment hash. If an invoice which can settle the HTLC identified by the
302
// passed payment hash isn't found, then an error is returned. Otherwise, the
303
// full invoice is returned. Before setting the incoming HTLC, the values
304
// SHOULD be checked to ensure the payer meets the agreed upon contractual
305
// terms of the payment.
306
func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
307
        invpkg.Invoice, error) {
645✔
308

645✔
309
        var invoice invpkg.Invoice
645✔
310
        err := kvdb.View(d, func(tx kvdb.RTx) error {
1,290✔
311
                invoices := tx.ReadBucket(invoiceBucket)
645✔
312
                if invoices == nil {
645✔
313
                        return invpkg.ErrNoInvoicesCreated
×
314
                }
×
315
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
645✔
316
                if invoiceIndex == nil {
661✔
317
                        return invpkg.ErrNoInvoicesCreated
16✔
318
                }
16✔
319
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
632✔
320
                setIDIndex := tx.ReadBucket(setIDIndexBucket)
632✔
321

632✔
322
                // Retrieve the invoice number for this invoice using
632✔
323
                // the provided invoice reference.
632✔
324
                invoiceNum, err := fetchInvoiceNumByRef(
632✔
325
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
632✔
326
                )
632✔
327
                if err != nil {
643✔
328
                        return err
11✔
329
                }
11✔
330

331
                var setID *invpkg.SetID
624✔
332
                switch {
624✔
333
                // If this is a payment address ref, and the blank modified was
334
                // specified, then we'll use the zero set ID to indicate that
335
                // we won't want any HTLCs returned.
336
                case ref.PayAddr() != nil &&
337
                        ref.Modifier() == invpkg.HtlcSetBlankModifier:
4✔
338

4✔
339
                        var zeroSetID invpkg.SetID
4✔
340
                        setID = &zeroSetID
4✔
341

342
                // If this is a set ID ref, and the htlc set only modified was
343
                // specified, then we'll pass through the specified setID so
344
                // only that will be returned.
345
                case ref.SetID() != nil &&
346
                        ref.Modifier() == invpkg.HtlcSetOnlyModifier:
7✔
347

7✔
348
                        setID = (*invpkg.SetID)(ref.SetID())
7✔
349
                }
350

351
                // An invoice was found, retrieve the remainder of the invoice
352
                // body.
353
                i, err := fetchInvoice(
624✔
354
                        invoiceNum, invoices, []*invpkg.SetID{setID}, true,
624✔
355
                )
624✔
356
                if err != nil {
624✔
357
                        return err
×
358
                }
×
359
                invoice = i
624✔
360

624✔
361
                return nil
624✔
362
        }, func() {})
645✔
363
        if err != nil {
669✔
364
                return invoice, err
24✔
365
        }
24✔
366

367
        return invoice, nil
624✔
368
}
369

370
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
371
// reference. The payment address will be treated as the primary key, falling
372
// back to the payment hash if nothing is found for the payment address. An
373
// error is returned if the invoice is not found.
374
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
375
        ref invpkg.InvoiceRef) ([]byte, error) {
1,273✔
376

1,273✔
377
        // If the set id is present, we only consult the set id index for this
1,273✔
378
        // invoice. This type of query is only used to facilitate user-facing
1,273✔
379
        // requests to lookup, settle or cancel an AMP invoice.
1,273✔
380
        setID := ref.SetID()
1,273✔
381
        if setID != nil {
1,288✔
382
                invoiceNumBySetID := setIDIndex.Get(setID[:])
15✔
383
                if invoiceNumBySetID == nil {
16✔
384
                        return nil, invpkg.ErrInvoiceNotFound
1✔
385
                }
1✔
386

387
                return invoiceNumBySetID, nil
14✔
388
        }
389

390
        payHash := ref.PayHash()
1,261✔
391
        payAddr := ref.PayAddr()
1,261✔
392

1,261✔
393
        getInvoiceNumByHash := func() []byte {
2,522✔
394
                if payHash != nil {
2,488✔
395
                        return invoiceIndex.Get(payHash[:])
1,227✔
396
                }
1,227✔
397
                return nil
37✔
398
        }
399

400
        getInvoiceNumByAddr := func() []byte {
2,522✔
401
                if payAddr != nil {
1,760✔
402
                        // Only allow lookups for payment address if it is not a
499✔
403
                        // blank payment address, which is a special-cased value
499✔
404
                        // for legacy keysend invoices.
499✔
405
                        if *payAddr != invpkg.BlankPayAddr {
567✔
406
                                return payAddrIndex.Get(payAddr[:])
68✔
407
                        }
68✔
408
                }
409
                return nil
1,196✔
410
        }
411

412
        invoiceNumByHash := getInvoiceNumByHash()
1,261✔
413
        invoiceNumByAddr := getInvoiceNumByAddr()
1,261✔
414
        switch {
1,261✔
415
        // If payment address and payment hash both reference an existing
416
        // invoice, ensure they reference the _same_ invoice.
417
        case invoiceNumByAddr != nil && invoiceNumByHash != nil:
34✔
418
                if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
36✔
419
                        return nil, invpkg.ErrInvRefEquivocation
2✔
420
                }
2✔
421

422
                return invoiceNumByAddr, nil
32✔
423

424
        // Return invoices by payment addr only.
425
        //
426
        // NOTE: We constrain this lookup to only apply if the invoice ref does
427
        // not contain a payment hash. Legacy and MPP payments depend on the
428
        // payment hash index to enforce that the HTLCs payment hash matches the
429
        // payment hash for the invoice, without this check we would
430
        // inadvertently assume the invoice contains the correct preimage for
431
        // the HTLC, which we only enforce via the lookup by the invoice index.
432
        case invoiceNumByAddr != nil && payHash == nil:
36✔
433
                return invoiceNumByAddr, nil
36✔
434

435
        // If we were only able to reference the invoice by hash, return the
436
        // corresponding invoice number. This can happen when no payment address
437
        // was provided, or if it didn't match anything in our records.
438
        case invoiceNumByHash != nil:
1,188✔
439
                return invoiceNumByHash, nil
1,188✔
440

441
        // Otherwise we don't know of the target invoice.
442
        default:
12✔
443
                return nil, invpkg.ErrInvoiceNotFound
12✔
444
        }
445
}
446

447
// FetchPendingInvoices returns all invoices that have not yet been settled or
448
// canceled. The returned map is keyed by the payment hash of each respective
449
// invoice.
450
func (d *DB) FetchPendingInvoices(_ context.Context) (
451
        map[lntypes.Hash]invpkg.Invoice, error) {
392✔
452

392✔
453
        result := make(map[lntypes.Hash]invpkg.Invoice)
392✔
454

392✔
455
        err := kvdb.View(d, func(tx kvdb.RTx) error {
784✔
456
                invoices := tx.ReadBucket(invoiceBucket)
392✔
457
                if invoices == nil {
392✔
458
                        return nil
×
459
                }
×
460

461
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
392✔
462
                if invoiceIndex == nil {
781✔
463
                        // Mask the error if there's no invoice
389✔
464
                        // index as that simply means there are no
389✔
465
                        // invoices added yet to the DB. In this case
389✔
466
                        // we simply return an empty list.
389✔
467
                        return nil
389✔
468
                }
389✔
469

470
                return invoiceIndex.ForEach(func(k, v []byte) error {
47✔
471
                        // Skip the special numInvoicesKey as that does not
41✔
472
                        // point to a valid invoice.
41✔
473
                        if bytes.Equal(k, numInvoicesKey) {
47✔
474
                                return nil
6✔
475
                        }
6✔
476

477
                        // Skip sub-buckets.
478
                        if v == nil {
38✔
479
                                return nil
×
480
                        }
×
481

482
                        invoice, err := fetchInvoice(v, invoices, nil, false)
38✔
483
                        if err != nil {
38✔
484
                                return err
×
485
                        }
×
486

487
                        if invoice.IsPending() {
61✔
488
                                var paymentHash lntypes.Hash
23✔
489
                                copy(paymentHash[:], k)
23✔
490
                                result[paymentHash] = invoice
23✔
491
                        }
23✔
492

493
                        return nil
38✔
494
                })
495
        }, func() {
392✔
496
                result = make(map[lntypes.Hash]invpkg.Invoice)
392✔
497
        })
392✔
498

499
        if err != nil {
392✔
500
                return nil, err
×
501
        }
×
502

503
        return result, nil
392✔
504
}
505

506
// QueryInvoices allows a caller to query the invoice database for invoices
507
// within the specified add index range.
508
func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
509
        invpkg.InvoiceSlice, error) {
56✔
510

56✔
511
        var resp invpkg.InvoiceSlice
56✔
512

56✔
513
        err := kvdb.View(d, func(tx kvdb.RTx) error {
112✔
514
                // If the bucket wasn't found, then there aren't any invoices
56✔
515
                // within the database yet, so we can simply exit.
56✔
516
                invoices := tx.ReadBucket(invoiceBucket)
56✔
517
                if invoices == nil {
56✔
518
                        return invpkg.ErrNoInvoicesCreated
×
519
                }
×
520

521
                // Get the add index bucket which we will use to iterate through
522
                // our indexed invoices.
523
                invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
56✔
524
                if invoiceAddIndex == nil {
63✔
525
                        return invpkg.ErrNoInvoicesCreated
7✔
526
                }
7✔
527

528
                // Create a paginator which reads from our add index bucket with
529
                // the parameters provided by the invoice query.
530
                paginator := newPaginator(
52✔
531
                        invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
52✔
532
                        q.NumMaxInvoices,
52✔
533
                )
52✔
534

52✔
535
                // accumulateInvoices looks up an invoice based on the index we
52✔
536
                // are given, adds it to our set of invoices if it has the right
52✔
537
                // characteristics for our query and returns the number of items
52✔
538
                // we have added to our set of invoices.
52✔
539
                accumulateInvoices := func(_, indexValue []byte) (bool, error) {
1,002✔
540
                        invoice, err := fetchInvoice(
950✔
541
                                indexValue, invoices, nil, false,
950✔
542
                        )
950✔
543
                        if err != nil {
950✔
544
                                return false, err
×
545
                        }
×
546

547
                        // Skip any settled or canceled invoices if the caller
548
                        // is only interested in pending ones.
549
                        if q.PendingOnly && !invoice.IsPending() {
1,105✔
550
                                return false, nil
155✔
551
                        }
155✔
552

553
                        // Get the creation time in Unix seconds, this always
554
                        // rounds down the nanoseconds to full seconds.
555
                        createTime := invoice.CreationDate.Unix()
795✔
556

795✔
557
                        // Skip any invoices that were created before the
795✔
558
                        // specified time.
795✔
559
                        if createTime < q.CreationDateStart {
895✔
560
                                return false, nil
100✔
561
                        }
100✔
562

563
                        // Skip any invoices that were created after the
564
                        // specified time.
565
                        if q.CreationDateEnd != 0 &&
698✔
566
                                createTime > q.CreationDateEnd {
879✔
567

181✔
568
                                return false, nil
181✔
569
                        }
181✔
570

571
                        // At this point, we've exhausted the offset, so we'll
572
                        // begin collecting invoices found within the range.
573
                        resp.Invoices = append(resp.Invoices, invoice)
520✔
574

520✔
575
                        return true, nil
520✔
576
                }
577

578
                // Query our paginator using accumulateInvoices to build up a
579
                // set of invoices.
580
                if err := paginator.query(accumulateInvoices); err != nil {
52✔
581
                        return err
×
582
                }
×
583

584
                // If we iterated through the add index in reverse order, then
585
                // we'll need to reverse the slice of invoices to return them in
586
                // forward order.
587
                if q.Reversed {
66✔
588
                        numInvoices := len(resp.Invoices)
14✔
589
                        for i := 0; i < numInvoices/2; i++ {
83✔
590
                                reverse := numInvoices - i - 1
69✔
591
                                resp.Invoices[i], resp.Invoices[reverse] =
69✔
592
                                        resp.Invoices[reverse], resp.Invoices[i]
69✔
593
                        }
69✔
594
                }
595

596
                return nil
52✔
597
        }, func() {
56✔
598
                resp = invpkg.InvoiceSlice{
56✔
599
                        InvoiceQuery: q,
56✔
600
                }
56✔
601
        })
56✔
602
        if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) {
56✔
603
                return resp, err
×
604
        }
×
605

606
        // Finally, record the indexes of the first and last invoices returned
607
        // so that the caller can resume from this point later on.
608
        if len(resp.Invoices) > 0 {
101✔
609
                resp.FirstIndexOffset = resp.Invoices[0].AddIndex
45✔
610
                lastIdx := len(resp.Invoices) - 1
45✔
611
                resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex
45✔
612
        }
45✔
613

614
        return resp, nil
56✔
615
}
616

617
// UpdateInvoice attempts to update an invoice corresponding to the passed
618
// payment hash. If an invoice matching the passed payment hash doesn't exist
619
// within the database, then the action will fail with a "not found" error.
620
//
621
// The update is performed inside the same database transaction that fetches the
622
// invoice and is therefore atomic. The fields to update are controlled by the
623
// supplied callback.  When updating an invoice, the update itself happens
624
// in-memory on a copy of the invoice. Once it is written successfully to the
625
// database, the in-memory copy is returned to the caller.
626
func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
627
        setIDHint *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
628
        *invpkg.Invoice, error) {
644✔
629

644✔
630
        var updatedInvoice *invpkg.Invoice
644✔
631
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
1,288✔
632
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
644✔
633
                if err != nil {
644✔
634
                        return err
×
635
                }
×
636
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
644✔
637
                        invoiceIndexBucket,
644✔
638
                )
644✔
639
                if err != nil {
644✔
640
                        return err
×
641
                }
×
642
                settleIndex, err := invoices.CreateBucketIfNotExists(
644✔
643
                        settleIndexBucket,
644✔
644
                )
644✔
645
                if err != nil {
644✔
646
                        return err
×
647
                }
×
648
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
644✔
649
                setIDIndex := tx.ReadWriteBucket(setIDIndexBucket)
644✔
650

644✔
651
                // Retrieve the invoice number for this invoice using the
644✔
652
                // provided invoice reference.
644✔
653
                invoiceNum, err := fetchInvoiceNumByRef(
644✔
654
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
644✔
655
                )
644✔
656
                if err != nil {
648✔
657
                        return err
4✔
658
                }
4✔
659

660
                // setIDHint can also be nil here, which means all the HTLCs
661
                // for AMP invoices are fetched. If the blank setID is passed
662
                // in, then no HTLCs are fetched for the AMP invoice. If a
663
                // specific setID is passed in, then only the HTLCs for that
664
                // setID are fetched for a particular sub-AMP invoice.
665
                invoice, err := fetchInvoice(
640✔
666
                        invoiceNum, invoices, []*invpkg.SetID{setIDHint}, false,
640✔
667
                )
640✔
668
                if err != nil {
640✔
669
                        return err
×
670
                }
×
671

672
                now := d.clock.Now()
640✔
673
                updater := &kvInvoiceUpdater{
640✔
674
                        db:                d,
640✔
675
                        invoicesBucket:    invoices,
640✔
676
                        settleIndexBucket: settleIndex,
640✔
677
                        setIDIndexBucket:  setIDIndex,
640✔
678
                        updateTime:        now,
640✔
679
                        invoiceNum:        invoiceNum,
640✔
680
                        invoice:           &invoice,
640✔
681
                        updatedAmpHtlcs:   make(ampHTLCsMap),
640✔
682
                        settledSetIDs:     make(map[invpkg.SetID]struct{}),
640✔
683
                }
640✔
684

640✔
685
                payHash := ref.PayHash()
640✔
686
                updatedInvoice, err = invpkg.UpdateInvoice(
640✔
687
                        payHash, updater.invoice, now, callback, updater,
640✔
688
                )
640✔
689
                if err != nil {
655✔
690
                        return err
15✔
691
                }
15✔
692

693
                // If this is an AMP update, then limit the returned AMP state
694
                // to only the requested set ID.
695
                if setIDHint != nil {
669✔
696
                        filterInvoiceAMPState(updatedInvoice, setIDHint)
41✔
697
                }
41✔
698

699
                return nil
628✔
700
        }, func() {
644✔
701
                updatedInvoice = nil
644✔
702
        })
644✔
703

704
        return updatedInvoice, err
644✔
705
}
706

707
// filterInvoiceAMPState filters the AMP state of the invoice to only include
708
// state for the specified set IDs.
709
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
77✔
710
        filteredAMPState := make(invpkg.AMPInvoiceState)
77✔
711

77✔
712
        for _, setID := range setIDs {
154✔
713
                if setID == nil {
110✔
714
                        return
33✔
715
                }
33✔
716

717
                ampState, ok := invoice.AMPState[*setID]
47✔
718
                if ok {
93✔
719
                        filteredAMPState[*setID] = ampState
46✔
720
                }
46✔
721
        }
722

723
        invoice.AMPState = filteredAMPState
47✔
724
}
725

726
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
727
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
728

729
// kvInvoiceUpdater is an implementation of the InvoiceUpdater interface that
730
// is used with the kv implementation of the invoice database. Note that this
731
// updater is not concurrency safe and synchronizaton is expected to be handled
732
// on the DB level.
733
type kvInvoiceUpdater struct {
734
        db                *DB
735
        invoicesBucket    kvdb.RwBucket
736
        settleIndexBucket kvdb.RwBucket
737
        setIDIndexBucket  kvdb.RwBucket
738

739
        // updateTime is the timestamp for the update.
740
        updateTime time.Time
741

742
        // invoiceNum is a legacy key similar to the add index that is used
743
        // only in the kv implementation.
744
        invoiceNum []byte
745

746
        // invoice is the invoice that we're updating. As a side effect of the
747
        // update this invoice will be mutated.
748
        invoice *invpkg.Invoice
749

750
        // updatedAmpHtlcs holds the set of AMP HTLCs that were added or
751
        // cancelled as part of this update.
752
        updatedAmpHtlcs ampHTLCsMap
753

754
        // settledSetIDs holds the set IDs that are settled with this update.
755
        settledSetIDs map[invpkg.SetID]struct{}
756
}
757

758
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
759
func (k *kvInvoiceUpdater) AddHtlc(_ models.CircuitKey,
760
        _ *invpkg.InvoiceHTLC) error {
500✔
761

500✔
762
        return nil
500✔
763
}
500✔
764

765
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
766
func (k *kvInvoiceUpdater) ResolveHtlc(_ models.CircuitKey, _ invpkg.HtlcState,
767
        _ time.Time) error {
494✔
768

494✔
769
        return nil
494✔
770
}
494✔
771

772
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
773
func (k *kvInvoiceUpdater) AddAmpHtlcPreimage(_ [32]byte, _ models.CircuitKey,
774
        _ lntypes.Preimage) error {
9✔
775

9✔
776
        return nil
9✔
777
}
9✔
778

779
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
780
func (k *kvInvoiceUpdater) UpdateInvoiceState(_ invpkg.ContractState,
781
        _ *lntypes.Preimage) error {
456✔
782

456✔
783
        return nil
456✔
784
}
456✔
785

786
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
787
func (k *kvInvoiceUpdater) UpdateInvoiceAmtPaid(_ lnwire.MilliSatoshi) error {
573✔
788
        return nil
573✔
789
}
573✔
790

791
// UpdateAmpState updates the state of the AMP invoice identified by the setID.
792
func (k *kvInvoiceUpdater) UpdateAmpState(setID [32]byte,
793
        state invpkg.InvoiceStateAMP, circuitKey models.CircuitKey) error {
48✔
794

48✔
795
        if _, ok := k.updatedAmpHtlcs[setID]; !ok {
90✔
796
                switch state.State {
42✔
797
                case invpkg.HtlcStateAccepted:
26✔
798
                        // If we're just now creating the HTLCs for this set
26✔
799
                        // then we'll also pull in the existing HTLCs that are
26✔
800
                        // part of this set, so we can write them all to disk
26✔
801
                        // together (same value)
26✔
802
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
26✔
803
                                &setID, invpkg.HtlcStateAccepted,
26✔
804
                        )
26✔
805

806
                case invpkg.HtlcStateCanceled:
10✔
807
                        // Only HTLCs in the accepted state, can be cancelled,
10✔
808
                        // but we also want to merge that with HTLCs that may be
10✔
809
                        // canceled as well since it can be cancelled one by
10✔
810
                        // one.
10✔
811
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
10✔
812
                                &setID, invpkg.HtlcStateAccepted,
10✔
813
                        )
10✔
814

10✔
815
                        cancelledHtlcs := k.invoice.HTLCSet(
10✔
816
                                &setID, invpkg.HtlcStateCanceled,
10✔
817
                        )
10✔
818
                        for htlcKey, htlc := range cancelledHtlcs {
25✔
819
                                k.updatedAmpHtlcs[setID][htlcKey] = htlc
15✔
820
                        }
15✔
821

822
                case invpkg.HtlcStateSettled:
6✔
823
                        k.updatedAmpHtlcs[setID] = make(
6✔
824
                                map[models.CircuitKey]*invpkg.InvoiceHTLC,
6✔
825
                        )
6✔
826
                }
827
        }
828

829
        if state.State == invpkg.HtlcStateSettled {
62✔
830
                // Add the set ID to the set that was settled in this invoice
14✔
831
                // update. We'll use this later to update the settle index.
14✔
832
                k.settledSetIDs[setID] = struct{}{}
14✔
833
        }
14✔
834

835
        k.updatedAmpHtlcs[setID][circuitKey] = k.invoice.Htlcs[circuitKey]
48✔
836

48✔
837
        return nil
48✔
838
}
839

840
// Finalize finalizes the update before it is written to the database.
841
func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error {
609✔
842
        switch updateType {
609✔
843
        case invpkg.AddHTLCsUpdate:
505✔
844
                return k.storeAddHtlcsUpdate()
505✔
845

846
        case invpkg.CancelHTLCsUpdate:
15✔
847
                return k.storeCancelHtlcsUpdate()
15✔
848

849
        case invpkg.SettleHodlInvoiceUpdate:
61✔
850
                return k.storeSettleHodlInvoiceUpdate()
61✔
851

852
        case invpkg.CancelInvoiceUpdate:
37✔
853
                // Persist all changes which where made when cancelling the
37✔
854
                // invoice. All HTLCs which were accepted are now canceled, so
37✔
855
                // we persist this state.
37✔
856
                return k.storeCancelHtlcsUpdate()
37✔
857
        }
858

859
        return fmt.Errorf("unknown update type: %v", updateType)
×
860
}
861

862
// storeCancelHtlcsUpdate updates the invoice in the database after cancelling a
863
// set of HTLCs.
864
func (k *kvInvoiceUpdater) storeCancelHtlcsUpdate() error {
49✔
865
        err := k.serializeAndStoreInvoice()
49✔
866
        if err != nil {
49✔
867
                return err
×
868
        }
×
869

870
        // If this is an AMP invoice, then we'll actually store the rest
871
        // of the HTLCs in-line with the invoice, using the invoice ID
872
        // as a prefix, and the AMP key as a suffix: invoiceNum ||
873
        // setID.
874
        if k.invoice.IsAMP() {
61✔
875
                return k.updateAMPInvoices()
12✔
876
        }
12✔
877

878
        return nil
37✔
879
}
880

881
// storeAddHtlcsUpdate updates the invoice in the database after adding a set of
882
// HTLCs.
883
func (k *kvInvoiceUpdater) storeAddHtlcsUpdate() error {
505✔
884
        invoiceIsAMP := k.invoice.IsAMP()
505✔
885

505✔
886
        for htlcSetID := range k.updatedAmpHtlcs {
538✔
887
                // Check if this SetID already exist.
33✔
888
                setIDInvNum := k.setIDIndexBucket.Get(htlcSetID[:])
33✔
889

33✔
890
                if setIDInvNum == nil {
52✔
891
                        err := k.setIDIndexBucket.Put(
19✔
892
                                htlcSetID[:], k.invoiceNum,
19✔
893
                        )
19✔
894
                        if err != nil {
19✔
895
                                return err
×
896
                        }
×
897
                } else if !bytes.Equal(setIDInvNum, k.invoiceNum) {
18✔
898
                        return invpkg.ErrDuplicateSetID{
1✔
899
                                SetID: htlcSetID,
1✔
900
                        }
1✔
901
                }
1✔
902
        }
903

904
        // If this is a non-AMP invoice, then the state can eventually go to
905
        // ContractSettled, so we pass in nil value as part of
906
        // setSettleMetaFields.
907
        if !invoiceIsAMP && k.invoice.State == invpkg.ContractSettled {
805✔
908
                err := k.setSettleMetaFields(nil)
301✔
909
                if err != nil {
301✔
910
                        return err
×
911
                }
×
912
        }
913

914
        // As we don't update the settle index above for AMP invoices, we'll do
915
        // it here for each sub-AMP invoice that was settled.
916
        for settledSetID := range k.settledSetIDs {
515✔
917
                settledSetID := settledSetID
11✔
918
                err := k.setSettleMetaFields(&settledSetID)
11✔
919
                if err != nil {
11✔
920
                        return err
×
921
                }
×
922
        }
923

924
        err := k.serializeAndStoreInvoice()
504✔
925
        if err != nil {
504✔
926
                return err
×
927
        }
×
928

929
        // If this is an AMP invoice, then we'll actually store the rest of the
930
        // HTLCs in-line with the invoice, using the invoice ID as a prefix,
931
        // and the AMP key as a suffix: invoiceNum || setID.
932
        if invoiceIsAMP {
536✔
933
                return k.updateAMPInvoices()
32✔
934
        }
32✔
935

936
        return nil
475✔
937
}
938

939
// storeSettleHodlInvoiceUpdate updates the invoice in the database after
940
// settling a hodl invoice.
941
func (k *kvInvoiceUpdater) storeSettleHodlInvoiceUpdate() error {
61✔
942
        err := k.setSettleMetaFields(nil)
61✔
943
        if err != nil {
61✔
944
                return err
×
945
        }
×
946

947
        return k.serializeAndStoreInvoice()
61✔
948
}
949

950
// setSettleMetaFields updates the metadata associated with settlement of an
951
// invoice. If a non-nil setID is passed in, then the value will be append to
952
// the invoice number as well, in order to allow us to detect repeated payments
953
// to the same AMP invoices "across time".
954
func (k *kvInvoiceUpdater) setSettleMetaFields(setID *invpkg.SetID) error {
367✔
955
        // Now that we know the invoice hasn't already been settled, we'll
367✔
956
        // update the settle index so we can place this settle event in the
367✔
957
        // proper location within our time series.
367✔
958
        nextSettleSeqNo, err := k.settleIndexBucket.NextSequence()
367✔
959
        if err != nil {
367✔
960
                return err
×
961
        }
×
962

963
        // Make a new byte array on the stack that can potentially store the 4
964
        // byte invoice number along w/ the 32 byte set ID. We capture valueLen
965
        // here which is the number of bytes copied so we can only store the 4
966
        // bytes if this is a non-AMP invoice.
967
        var indexKey [invoiceSetIDKeyLen]byte
367✔
968
        valueLen := copy(indexKey[:], k.invoiceNum)
367✔
969

367✔
970
        if setID != nil {
378✔
971
                valueLen += copy(indexKey[valueLen:], setID[:])
11✔
972
        }
11✔
973

974
        var seqNoBytes [8]byte
367✔
975
        byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
367✔
976
        err = k.settleIndexBucket.Put(seqNoBytes[:], indexKey[:valueLen])
367✔
977
        if err != nil {
367✔
978
                return err
×
979
        }
×
980

981
        // If the setID is nil, then this means that this is a non-AMP settle,
982
        // so we'll update the invoice settle index directly.
983
        if setID == nil {
726✔
984
                k.invoice.SettleDate = k.updateTime
359✔
985
                k.invoice.SettleIndex = nextSettleSeqNo
359✔
986
        } else {
370✔
987
                // If the set ID isn't blank, we'll update the AMP state map
11✔
988
                // which tracks when each of the setIDs associated with a given
11✔
989
                // AMP invoice are settled.
11✔
990
                ampState := k.invoice.AMPState[*setID]
11✔
991

11✔
992
                ampState.SettleDate = k.updateTime
11✔
993
                ampState.SettleIndex = nextSettleSeqNo
11✔
994

11✔
995
                k.invoice.AMPState[*setID] = ampState
11✔
996
        }
11✔
997

998
        return nil
367✔
999
}
1000

1001
// updateAMPInvoices updates the set of AMP invoices in-place. For AMP, rather
1002
// then continually write the invoices to the end of the invoice value, we
1003
// instead write the invoices into a new key preifx that follows the main
1004
// invoice number. This ensures that we don't need to continually decode a
1005
// potentially massive HTLC set, and also allows us to quickly find the HLTCs
1006
// associated with a particular HTLC set.
1007
func (k *kvInvoiceUpdater) updateAMPInvoices() error {
44✔
1008
        for setID, htlcSet := range k.updatedAmpHtlcs {
85✔
1009
                // First write out the set of HTLCs including all the relevant
41✔
1010
                // TLV values.
41✔
1011
                var b bytes.Buffer
41✔
1012
                if err := serializeHtlcs(&b, htlcSet); err != nil {
41✔
1013
                        return err
×
1014
                }
×
1015

1016
                // Next store each HTLC in-line, using a prefix based off the
1017
                // invoice number.
1018
                invoiceSetIDKey := makeInvoiceSetIDKey(k.invoiceNum, setID[:])
41✔
1019

41✔
1020
                err := k.invoicesBucket.Put(invoiceSetIDKey[:], b.Bytes())
41✔
1021
                if err != nil {
41✔
1022
                        return err
×
1023
                }
×
1024
        }
1025

1026
        return nil
44✔
1027
}
1028

1029
// serializeAndStoreInvoice is a helper function used to store invoices.
1030
func (k *kvInvoiceUpdater) serializeAndStoreInvoice() error {
608✔
1031
        var buf bytes.Buffer
608✔
1032
        if err := serializeInvoice(&buf, k.invoice); err != nil {
608✔
1033
                return err
×
1034
        }
×
1035

1036
        return k.invoicesBucket.Put(k.invoiceNum, buf.Bytes())
608✔
1037
}
1038

1039
// InvoicesSettledSince can be used by callers to catch up any settled invoices
1040
// they missed within the settled invoice time series. We'll return all known
1041
// settled invoice that have a settle index higher than the passed
1042
// sinceSettleIndex.
1043
//
1044
// NOTE: The index starts from 1, as a result. We enforce that specifying a
1045
// value below the starting index value is a noop.
1046
func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
1047
        []invpkg.Invoice, error) {
24✔
1048

24✔
1049
        var settledInvoices []invpkg.Invoice
24✔
1050

24✔
1051
        // If an index of zero was specified, then in order to maintain
24✔
1052
        // backwards compat, we won't send out any new invoices.
24✔
1053
        if sinceSettleIndex == 0 {
45✔
1054
                return settledInvoices, nil
21✔
1055
        }
21✔
1056

1057
        var startIndex [8]byte
6✔
1058
        byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
6✔
1059

6✔
1060
        err := kvdb.View(d, func(tx kvdb.RTx) error {
12✔
1061
                invoices := tx.ReadBucket(invoiceBucket)
6✔
1062
                if invoices == nil {
6✔
1063
                        return nil
×
1064
                }
×
1065

1066
                settleIndex := invoices.NestedReadBucket(settleIndexBucket)
6✔
1067
                if settleIndex == nil {
6✔
1068
                        return nil
×
1069
                }
×
1070

1071
                // We'll now run through each entry in the add index starting
1072
                // at our starting index. We'll continue until we reach the
1073
                // very end of the current key space.
1074
                invoiceCursor := settleIndex.ReadCursor()
6✔
1075

6✔
1076
                // We'll seek to the starting index, then manually advance the
6✔
1077
                // cursor in order to skip the entry with the since add index.
6✔
1078
                invoiceCursor.Seek(startIndex[:])
6✔
1079
                seqNo, indexValue := invoiceCursor.Next()
6✔
1080

6✔
1081
                for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, indexValue = invoiceCursor.Next() {
20✔
1082
                        // Depending on the length of the index value, this may
14✔
1083
                        // or may not be an AMP invoice, so we'll extract the
14✔
1084
                        // invoice value into two components: the invoice num,
14✔
1085
                        // and the setID (may not be there).
14✔
1086
                        var (
14✔
1087
                                invoiceKey [4]byte
14✔
1088
                                setID      *invpkg.SetID
14✔
1089
                        )
14✔
1090

14✔
1091
                        valueLen := copy(invoiceKey[:], indexValue)
14✔
1092
                        if len(indexValue) == invoiceSetIDKeyLen {
19✔
1093
                                setID = new(invpkg.SetID)
5✔
1094
                                copy(setID[:], indexValue[valueLen:])
5✔
1095
                        }
5✔
1096

1097
                        // For each key found, we'll look up the actual
1098
                        // invoice, then accumulate it into our return value.
1099
                        invoice, err := fetchInvoice(
14✔
1100
                                invoiceKey[:], invoices, []*invpkg.SetID{setID},
14✔
1101
                                true,
14✔
1102
                        )
14✔
1103
                        if err != nil {
14✔
1104
                                return err
×
1105
                        }
×
1106

1107
                        settledInvoices = append(settledInvoices, invoice)
14✔
1108
                }
1109

1110
                return nil
6✔
1111
        }, func() {
6✔
1112
                settledInvoices = nil
6✔
1113
        })
6✔
1114
        if err != nil {
6✔
1115
                return nil, err
×
1116
        }
×
1117

1118
        return settledInvoices, nil
6✔
1119
}
1120

1121
func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
1122
        i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
1123
        uint64, error) {
618✔
1124

618✔
1125
        // Create the invoice key which is just the big-endian representation
618✔
1126
        // of the invoice number.
618✔
1127
        var invoiceKey [4]byte
618✔
1128
        byteOrder.PutUint32(invoiceKey[:], invoiceNum)
618✔
1129

618✔
1130
        // Increment the num invoice counter index so the next invoice bares
618✔
1131
        // the proper ID.
618✔
1132
        var scratch [4]byte
618✔
1133
        invoiceCounter := invoiceNum + 1
618✔
1134
        byteOrder.PutUint32(scratch[:], invoiceCounter)
618✔
1135
        if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
618✔
1136
                return 0, err
×
1137
        }
×
1138

1139
        // Add the payment hash to the invoice index. This will let us quickly
1140
        // identify if we can settle an incoming payment, and also to possibly
1141
        // allow a single invoice to have multiple payment installations.
1142
        err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
618✔
1143
        if err != nil {
618✔
1144
                return 0, err
×
1145
        }
×
1146

1147
        // Add the invoice to the payment address index, but only if the invoice
1148
        // has a non-zero payment address. The all-zero payment address is still
1149
        // in use by legacy keysend, so we special-case here to avoid
1150
        // collisions.
1151
        if i.Terms.PaymentAddr != invpkg.BlankPayAddr {
1,114✔
1152
                err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
496✔
1153
                if err != nil {
496✔
1154
                        return 0, err
×
1155
                }
×
1156
        }
1157

1158
        // Next, we'll obtain the next add invoice index (sequence
1159
        // number), so we can properly place this invoice within this
1160
        // event stream.
1161
        nextAddSeqNo, err := addIndex.NextSequence()
618✔
1162
        if err != nil {
618✔
1163
                return 0, err
×
1164
        }
×
1165

1166
        // With the next sequence obtained, we'll updating the event series in
1167
        // the add index bucket to map this current add counter to the index of
1168
        // this new invoice.
1169
        var seqNoBytes [8]byte
618✔
1170
        byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
618✔
1171
        if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
618✔
1172
                return 0, err
×
1173
        }
×
1174

1175
        i.AddIndex = nextAddSeqNo
618✔
1176

618✔
1177
        // Finally, serialize the invoice itself to be written to the disk.
618✔
1178
        var buf bytes.Buffer
618✔
1179
        if err := serializeInvoice(&buf, i); err != nil {
618✔
1180
                return 0, err
×
1181
        }
×
1182

1183
        if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
618✔
1184
                return 0, err
×
1185
        }
×
1186

1187
        return nextAddSeqNo, nil
618✔
1188
}
1189

1190
// recordSize returns the amount of bytes this TLV record will occupy when
1191
// encoded.
1192
func ampRecordSize(a *invpkg.AMPInvoiceState) func() uint64 {
1,224✔
1193
        var (
1,224✔
1194
                b   bytes.Buffer
1,224✔
1195
                buf [8]byte
1,224✔
1196
        )
1,224✔
1197

1,224✔
1198
        // We know that encoding works since the tests pass in the build this
1,224✔
1199
        // file is checked into, so we'll simplify things and simply encode it
1,224✔
1200
        // ourselves then report the total amount of bytes used.
1,224✔
1201
        if err := ampStateEncoder(&b, a, &buf); err != nil {
1,224✔
1202
                // This should never error out, but we log it just in case it
×
1203
                // does.
×
1204
                log.Errorf("encoding the amp invoice state failed: %v", err)
×
1205
        }
×
1206

1207
        return func() uint64 {
2,448✔
1208
                return uint64(len(b.Bytes()))
1,224✔
1209
        }
1,224✔
1210
}
1211

1212
// serializeInvoice serializes an invoice to a writer.
1213
//
1214
// Note: this function is in use for a migration. Before making changes that
1215
// would modify the on disk format, make a copy of the original code and store
1216
// it with the migration.
1217
func serializeInvoice(w io.Writer, i *invpkg.Invoice) error {
1,223✔
1218
        creationDateBytes, err := i.CreationDate.MarshalBinary()
1,223✔
1219
        if err != nil {
1,223✔
1220
                return err
×
1221
        }
×
1222

1223
        settleDateBytes, err := i.SettleDate.MarshalBinary()
1,223✔
1224
        if err != nil {
1,223✔
1225
                return err
×
1226
        }
×
1227

1228
        var fb bytes.Buffer
1,223✔
1229
        err = i.Terms.Features.EncodeBase256(&fb)
1,223✔
1230
        if err != nil {
1,223✔
1231
                return err
×
1232
        }
×
1233
        featureBytes := fb.Bytes()
1,223✔
1234

1,223✔
1235
        preimage := [32]byte(invpkg.UnknownPreimage)
1,223✔
1236
        if i.Terms.PaymentPreimage != nil {
2,298✔
1237
                preimage = *i.Terms.PaymentPreimage
1,075✔
1238
                if preimage == invpkg.UnknownPreimage {
1,075✔
1239
                        return errors.New("cannot use all-zeroes preimage")
×
1240
                }
×
1241
        }
1242
        value := uint64(i.Terms.Value)
1,223✔
1243
        cltvDelta := uint32(i.Terms.FinalCltvDelta)
1,223✔
1244
        expiry := uint64(i.Terms.Expiry)
1,223✔
1245

1,223✔
1246
        amtPaid := uint64(i.AmtPaid)
1,223✔
1247
        state := uint8(i.State)
1,223✔
1248

1,223✔
1249
        var hodlInvoice uint8
1,223✔
1250
        if i.HodlInvoice {
1,423✔
1251
                hodlInvoice = 1
200✔
1252
        }
200✔
1253

1254
        tlvStream, err := tlv.NewStream(
1,223✔
1255
                // Memo and payreq.
1,223✔
1256
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
1,223✔
1257
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
1,223✔
1258

1,223✔
1259
                // Add/settle metadata.
1,223✔
1260
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
1,223✔
1261
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
1,223✔
1262
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
1,223✔
1263
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
1,223✔
1264

1,223✔
1265
                // Terms.
1,223✔
1266
                tlv.MakePrimitiveRecord(preimageType, &preimage),
1,223✔
1267
                tlv.MakePrimitiveRecord(valueType, &value),
1,223✔
1268
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
1,223✔
1269
                tlv.MakePrimitiveRecord(expiryType, &expiry),
1,223✔
1270
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
1,223✔
1271
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
1,223✔
1272

1,223✔
1273
                // Invoice state.
1,223✔
1274
                tlv.MakePrimitiveRecord(invStateType, &state),
1,223✔
1275
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
1,223✔
1276

1,223✔
1277
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
1,223✔
1278

1,223✔
1279
                // Invoice AMP state.
1,223✔
1280
                tlv.MakeDynamicRecord(
1,223✔
1281
                        invoiceAmpStateType, &i.AMPState,
1,223✔
1282
                        ampRecordSize(&i.AMPState),
1,223✔
1283
                        ampStateEncoder, ampStateDecoder,
1,223✔
1284
                ),
1,223✔
1285
        )
1,223✔
1286
        if err != nil {
1,223✔
1287
                return err
×
1288
        }
×
1289

1290
        var b bytes.Buffer
1,223✔
1291
        if err = tlvStream.Encode(&b); err != nil {
1,223✔
1292
                return err
×
1293
        }
×
1294

1295
        err = binary.Write(w, byteOrder, uint64(b.Len()))
1,223✔
1296
        if err != nil {
1,223✔
1297
                return err
×
1298
        }
×
1299

1300
        if _, err = w.Write(b.Bytes()); err != nil {
1,223✔
1301
                return err
×
1302
        }
×
1303

1304
        // Only if this is a _non_ AMP invoice do we serialize the HTLCs
1305
        // in-line with the rest of the invoice.
1306
        if i.IsAMP() {
1,280✔
1307
                return nil
57✔
1308
        }
57✔
1309

1310
        return serializeHtlcs(w, i.Htlcs)
1,169✔
1311
}
1312

1313
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
1314
// a writer.
1315
func serializeHtlcs(w io.Writer,
1316
        htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error {
1,207✔
1317

1,207✔
1318
        for key, htlc := range htlcs {
1,919✔
1319
                // Encode the htlc in a tlv stream.
712✔
1320
                chanID := key.ChanID.ToUint64()
712✔
1321
                amt := uint64(htlc.Amt)
712✔
1322
                mppTotalAmt := uint64(htlc.MppTotalAmt)
712✔
1323
                acceptTime := putNanoTime(htlc.AcceptTime)
712✔
1324
                resolveTime := putNanoTime(htlc.ResolveTime)
712✔
1325
                state := uint8(htlc.State)
712✔
1326

712✔
1327
                var records []tlv.Record
712✔
1328
                records = append(records,
712✔
1329
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
712✔
1330
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
712✔
1331
                        tlv.MakePrimitiveRecord(amtType, &amt),
712✔
1332
                        tlv.MakePrimitiveRecord(
712✔
1333
                                acceptHeightType, &htlc.AcceptHeight,
712✔
1334
                        ),
712✔
1335
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
712✔
1336
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
712✔
1337
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
712✔
1338
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
712✔
1339
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
712✔
1340
                )
712✔
1341

712✔
1342
                if htlc.AMP != nil {
770✔
1343
                        setIDRecord := tlv.MakeDynamicRecord(
58✔
1344
                                htlcAMPType, &htlc.AMP.Record,
58✔
1345
                                htlc.AMP.Record.PayloadSize,
58✔
1346
                                record.AMPEncoder, record.AMPDecoder,
58✔
1347
                        )
58✔
1348
                        records = append(records, setIDRecord)
58✔
1349

58✔
1350
                        hash32 := [32]byte(htlc.AMP.Hash)
58✔
1351
                        hashRecord := tlv.MakePrimitiveRecord(
58✔
1352
                                htlcHashType, &hash32,
58✔
1353
                        )
58✔
1354
                        records = append(records, hashRecord)
58✔
1355

58✔
1356
                        if htlc.AMP.Preimage != nil {
88✔
1357
                                preimage32 := [32]byte(*htlc.AMP.Preimage)
30✔
1358
                                preimageRecord := tlv.MakePrimitiveRecord(
30✔
1359
                                        htlcPreimageType, &preimage32,
30✔
1360
                                )
30✔
1361
                                records = append(records, preimageRecord)
30✔
1362
                        }
30✔
1363
                }
1364

1365
                // Convert the custom records to tlv.Record types that are ready
1366
                // for serialization.
1367
                customRecords := tlv.MapToRecords(htlc.CustomRecords)
712✔
1368

712✔
1369
                // Append the custom records. Their ids are in the experimental
712✔
1370
                // range and sorted, so there is no need to sort again.
712✔
1371
                records = append(records, customRecords...)
712✔
1372

712✔
1373
                tlvStream, err := tlv.NewStream(records...)
712✔
1374
                if err != nil {
712✔
1375
                        return err
×
1376
                }
×
1377

1378
                var b bytes.Buffer
712✔
1379
                if err := tlvStream.Encode(&b); err != nil {
712✔
1380
                        return err
×
1381
                }
×
1382

1383
                // Write the length of the tlv stream followed by the stream
1384
                // bytes.
1385
                err = binary.Write(w, byteOrder, uint64(b.Len()))
712✔
1386
                if err != nil {
712✔
1387
                        return err
×
1388
                }
×
1389

1390
                if _, err := w.Write(b.Bytes()); err != nil {
712✔
1391
                        return err
×
1392
                }
×
1393
        }
1394

1395
        return nil
1,207✔
1396
}
1397

1398
// putNanoTime converts t to nanoseconds since the Unix epoch, as a uint64.
// The zero time.Time is mapped to 0 instead, because the result of UnixNano
// is undefined for the zero value.
func putNanoTime(t time.Time) uint64 {
	var nanos uint64
	if !t.IsZero() {
		nanos = uint64(t.UnixNano())
	}

	return nanos
}
1407

1408
// getNanoTime reverses the mapping of putNanoTime. It interprets ns as
// nanoseconds since the Unix epoch and returns the matching time.Time.
// A value of 0 yields the zero-value time.Time rather than the Unix epoch.
func getNanoTime(ns uint64) time.Time {
	switch ns {
	case 0:
		return time.Time{}

	default:
		return time.Unix(0, int64(ns))
	}
}
1416

1417
// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices
1418
// identified by the setID value.
1419
func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1420
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1421
        error) {
50✔
1422

50✔
1423
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
50✔
1424
        for _, setID := range setIDs {
100✔
1425
                invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:])
50✔
1426

50✔
1427
                htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:])
50✔
1428
                if htlcSetBytes == nil {
70✔
1429
                        // A set ID was passed in, but we don't have this
20✔
1430
                        // stored yet, meaning that the setID is being added
20✔
1431
                        // for the first time.
20✔
1432
                        return htlcs, invpkg.ErrInvoiceNotFound
20✔
1433
                }
20✔
1434

1435
                htlcSetReader := bytes.NewReader(htlcSetBytes)
33✔
1436
                htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
33✔
1437
                if err != nil {
33✔
1438
                        return nil, err
×
1439
                }
×
1440

1441
                for key, htlc := range htlcsBySetID {
76✔
1442
                        htlcs[key] = htlc
43✔
1443
                }
43✔
1444
        }
1445

1446
        return htlcs, nil
33✔
1447
}
1448

1449
// forEachAMPInvoice is a helper function that attempts to iterate over each of
1450
// the HTLC sets (based on their set ID) for the given AMP invoice identified
1451
// by its invoiceNum. The callback closure is called for each key within the
1452
// prefix range.
1453
func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1454
        callback func(key, htlcSet []byte) error) error {
60✔
1455

60✔
1456
        invoiceCursor := invoiceBucket.ReadCursor()
60✔
1457

60✔
1458
        // Seek to the first key that includes the invoice data itself.
60✔
1459
        invoiceCursor.Seek(invoiceNum)
60✔
1460

60✔
1461
        // Advance to the very first key _after_ the invoice data, as this is
60✔
1462
        // where we'll encounter our first HTLC (if any are present).
60✔
1463
        cursorKey, htlcSet := invoiceCursor.Next()
60✔
1464

60✔
1465
        // If at this point, the cursor key doesn't match the invoice num
60✔
1466
        // prefix, then we know that this HTLC doesn't have any set ID HTLCs
60✔
1467
        // associated with it.
60✔
1468
        if !bytes.HasPrefix(cursorKey, invoiceNum) {
88✔
1469
                return nil
28✔
1470
        }
28✔
1471

1472
        // Otherwise continue to iterate until we no longer match the prefix,
1473
        // executing the call back at each step.
1474
        for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() {
86✔
1475
                err := callback(cursorKey, htlcSet)
51✔
1476
                if err != nil {
51✔
1477
                        return err
×
1478
                }
×
1479
        }
1480

1481
        return nil
35✔
1482
}
1483

1484
// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix  within the
1485
// AMP bucket to find all the individual HTLCs (by setID) associated with a
1486
// given invoice. If a list of set IDs are specified, then only HTLCs
1487
// associated with that setID will be retrieved.
1488
func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1489
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1490
        error) {
92✔
1491

92✔
1492
        // If a set of setIDs was specified, then we can skip the cursor and
92✔
1493
        // just read out exactly what we need.
92✔
1494
        if len(setIDs) != 0 && setIDs[0] != nil {
142✔
1495
                return fetchFilteredAmpInvoices(
50✔
1496
                        invoiceBucket, invoiceNum, setIDs...,
50✔
1497
                )
50✔
1498
        }
50✔
1499

1500
        // Otherwise, iterate over all the htlc sets that are prefixed beside
1501
        // this invoice in the main invoice bucket.
1502
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
45✔
1503
        err := forEachAMPInvoice(invoiceBucket, invoiceNum,
45✔
1504
                func(key, htlcSet []byte) error {
93✔
1505
                        htlcSetReader := bytes.NewReader(htlcSet)
48✔
1506
                        htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
48✔
1507
                        if err != nil {
48✔
1508
                                return err
×
1509
                        }
×
1510

1511
                        for key, htlc := range htlcsBySetID {
124✔
1512
                                htlcs[key] = htlc
76✔
1513
                        }
76✔
1514

1515
                        return nil
48✔
1516
                },
1517
        )
1518

1519
        if err != nil {
45✔
1520
                return nil, err
×
1521
        }
×
1522

1523
        return htlcs, nil
45✔
1524
}
1525

1526
// fetchInvoice attempts to read out the relevant state for the invoice as
1527
// specified by the invoice number. If the setID fields are set, then only the
1528
// HTLC information pertaining to those set IDs is returned.
1529
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1530
        setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
2,298✔
1531

2,298✔
1532
        invoiceBytes := invoices.Get(invoiceNum)
2,298✔
1533
        if invoiceBytes == nil {
2,298✔
1534
                return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound
×
1535
        }
×
1536

1537
        invoiceReader := bytes.NewReader(invoiceBytes)
2,298✔
1538

2,298✔
1539
        invoice, err := deserializeInvoice(invoiceReader)
2,298✔
1540
        if err != nil {
2,298✔
1541
                return invpkg.Invoice{}, err
×
1542
        }
×
1543

1544
        // If this is an AMP invoice we'll also attempt to read out the set of
1545
        // HTLCs that were paid to prior set IDs, if needed.
1546
        if !invoice.IsAMP() {
4,506✔
1547
                return invoice, nil
2,208✔
1548
        }
2,208✔
1549

1550
        if shouldFetchAMPHTLCs(invoice, setIDs) {
185✔
1551
                invoice.Htlcs, err = fetchAmpSubInvoices(
92✔
1552
                        invoices, invoiceNum, setIDs...,
92✔
1553
                )
92✔
1554
                // TODO(positiveblue): we should fail when we are not able to
92✔
1555
                // fetch all the HTLCs for an AMP invoice. Multiple tests in
92✔
1556
                // the invoice and channeldb package break if we return this
92✔
1557
                // error. We need to update them when we migrate this logic to
92✔
1558
                // the sql implementation.
92✔
1559
                if err != nil {
112✔
1560
                        log.Errorf("unable to fetch amp htlcs for inv "+
20✔
1561
                                "%v and setIDs %v: %w", invoiceNum, setIDs, err)
20✔
1562
                }
20✔
1563

1564
                if filterAMPState {
131✔
1565
                        filterInvoiceAMPState(&invoice, setIDs...)
39✔
1566
                }
39✔
1567
        }
1568

1569
        return invoice, nil
93✔
1570
}
1571

1572
// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that
1573
// were paid to the relevant set IDs.
1574
func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool {
93✔
1575
        // For AMP invoice that already have HTLCs populated (created before
93✔
1576
        // recurring invoices), then we don't need to read from the prefix
93✔
1577
        // keyed section of the bucket.
93✔
1578
        if len(invoice.Htlcs) != 0 {
93✔
1579
                return false
×
1580
        }
×
1581

1582
        // If the "zero" setID was specified, then this means that no HTLC data
1583
        // should be returned alongside of it.
1584
        if len(setIDs) != 0 && setIDs[0] != nil &&
93✔
1585
                *setIDs[0] == invpkg.BlankPayAddr {
97✔
1586

4✔
1587
                return false
4✔
1588
        }
4✔
1589

1590
        return true
92✔
1591
}
1592

1593
// fetchInvoiceStateAMP retrieves the state of all the relevant sub-invoice for
1594
// an AMP invoice. This methods only decode the relevant state vs the entire
1595
// invoice.
1596
func fetchInvoiceStateAMP(invoiceNum []byte,
1597
        invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) {
8✔
1598

8✔
1599
        // Fetch the raw invoice bytes.
8✔
1600
        invoiceBytes := invoices.Get(invoiceNum)
8✔
1601
        if invoiceBytes == nil {
8✔
1602
                return nil, invpkg.ErrInvoiceNotFound
×
1603
        }
×
1604

1605
        r := bytes.NewReader(invoiceBytes)
8✔
1606

8✔
1607
        var bodyLen int64
8✔
1608
        err := binary.Read(r, byteOrder, &bodyLen)
8✔
1609
        if err != nil {
8✔
1610
                return nil, err
×
1611
        }
×
1612

1613
        // Next, we'll make a new TLV stream that only attempts to decode the
1614
        // bytes we actually need.
1615
        ampState := make(invpkg.AMPInvoiceState)
8✔
1616
        tlvStream, err := tlv.NewStream(
8✔
1617
                // Invoice AMP state.
8✔
1618
                tlv.MakeDynamicRecord(
8✔
1619
                        invoiceAmpStateType, &ampState, nil,
8✔
1620
                        ampStateEncoder, ampStateDecoder,
8✔
1621
                ),
8✔
1622
        )
8✔
1623
        if err != nil {
8✔
1624
                return nil, err
×
1625
        }
×
1626

1627
        invoiceReader := io.LimitReader(r, bodyLen)
8✔
1628
        if err = tlvStream.Decode(invoiceReader); err != nil {
8✔
1629
                return nil, err
×
1630
        }
×
1631

1632
        return ampState, nil
8✔
1633
}
1634

1635
func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) {
2,298✔
1636
        var (
2,298✔
1637
                preimageBytes [32]byte
2,298✔
1638
                value         uint64
2,298✔
1639
                cltvDelta     uint32
2,298✔
1640
                expiry        uint64
2,298✔
1641
                amtPaid       uint64
2,298✔
1642
                state         uint8
2,298✔
1643
                hodlInvoice   uint8
2,298✔
1644

2,298✔
1645
                creationDateBytes []byte
2,298✔
1646
                settleDateBytes   []byte
2,298✔
1647
                featureBytes      []byte
2,298✔
1648
        )
2,298✔
1649

2,298✔
1650
        var i invpkg.Invoice
2,298✔
1651
        i.AMPState = make(invpkg.AMPInvoiceState)
2,298✔
1652
        tlvStream, err := tlv.NewStream(
2,298✔
1653
                // Memo and payreq.
2,298✔
1654
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
2,298✔
1655
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
2,298✔
1656

2,298✔
1657
                // Add/settle metadata.
2,298✔
1658
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
2,298✔
1659
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
2,298✔
1660
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
2,298✔
1661
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
2,298✔
1662

2,298✔
1663
                // Terms.
2,298✔
1664
                tlv.MakePrimitiveRecord(preimageType, &preimageBytes),
2,298✔
1665
                tlv.MakePrimitiveRecord(valueType, &value),
2,298✔
1666
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
2,298✔
1667
                tlv.MakePrimitiveRecord(expiryType, &expiry),
2,298✔
1668
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
2,298✔
1669
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
2,298✔
1670

2,298✔
1671
                // Invoice state.
2,298✔
1672
                tlv.MakePrimitiveRecord(invStateType, &state),
2,298✔
1673
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
2,298✔
1674

2,298✔
1675
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
2,298✔
1676

2,298✔
1677
                // Invoice AMP state.
2,298✔
1678
                tlv.MakeDynamicRecord(
2,298✔
1679
                        invoiceAmpStateType, &i.AMPState, nil,
2,298✔
1680
                        ampStateEncoder, ampStateDecoder,
2,298✔
1681
                ),
2,298✔
1682
        )
2,298✔
1683
        if err != nil {
2,298✔
1684
                return i, err
×
1685
        }
×
1686

1687
        var bodyLen int64
2,298✔
1688
        err = binary.Read(r, byteOrder, &bodyLen)
2,298✔
1689
        if err != nil {
2,298✔
1690
                return i, err
×
1691
        }
×
1692

1693
        lr := io.LimitReader(r, bodyLen)
2,298✔
1694
        if err = tlvStream.Decode(lr); err != nil {
2,298✔
1695
                return i, err
×
1696
        }
×
1697

1698
        preimage := lntypes.Preimage(preimageBytes)
2,298✔
1699
        if preimage != invpkg.UnknownPreimage {
4,357✔
1700
                i.Terms.PaymentPreimage = &preimage
2,059✔
1701
        }
2,059✔
1702

1703
        i.Terms.Value = lnwire.MilliSatoshi(value)
2,298✔
1704
        i.Terms.FinalCltvDelta = int32(cltvDelta)
2,298✔
1705
        i.Terms.Expiry = time.Duration(expiry)
2,298✔
1706
        i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
2,298✔
1707
        i.State = invpkg.ContractState(state)
2,298✔
1708

2,298✔
1709
        if hodlInvoice != 0 {
2,532✔
1710
                i.HodlInvoice = true
234✔
1711
        }
234✔
1712

1713
        err = i.CreationDate.UnmarshalBinary(creationDateBytes)
2,298✔
1714
        if err != nil {
2,298✔
1715
                return i, err
×
1716
        }
×
1717

1718
        err = i.SettleDate.UnmarshalBinary(settleDateBytes)
2,298✔
1719
        if err != nil {
2,298✔
1720
                return i, err
×
1721
        }
×
1722

1723
        rawFeatures := lnwire.NewRawFeatureVector()
2,298✔
1724
        err = rawFeatures.DecodeBase256(
2,298✔
1725
                bytes.NewReader(featureBytes), len(featureBytes),
2,298✔
1726
        )
2,298✔
1727
        if err != nil {
2,298✔
1728
                return i, err
×
1729
        }
×
1730

1731
        i.Terms.Features = lnwire.NewFeatureVector(
2,298✔
1732
                rawFeatures, lnwire.Features,
2,298✔
1733
        )
2,298✔
1734

2,298✔
1735
        i.Htlcs, err = deserializeHtlcs(r)
2,298✔
1736
        return i, err
2,298✔
1737
}
1738

1739
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
133✔
1740
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
266✔
1741
                // We encode the set of circuit keys as a varint length prefix.
133✔
1742
                // followed by a series of fixed sized uint8 integers.
133✔
1743
                numKeys := uint64(len(*v))
133✔
1744

133✔
1745
                if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
133✔
1746
                        return err
×
1747
                }
×
1748

1749
                for key := range *v {
314✔
1750
                        scidInt := key.ChanID.ToUint64()
181✔
1751

181✔
1752
                        if err := tlv.EUint64(w, &scidInt, buf); err != nil {
181✔
1753
                                return err
×
1754
                        }
×
1755
                        if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
181✔
1756
                                return err
×
1757
                        }
×
1758
                }
1759

1760
                return nil
133✔
1761
        }
1762

1763
        return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
×
1764
}
1765

1766
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte,
1767
        l uint64) error {
121✔
1768

121✔
1769
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
242✔
1770
                // First, we'll read out the varint that encodes the number of
121✔
1771
                // circuit keys encoded.
121✔
1772
                numKeys, err := tlv.ReadVarInt(r, buf)
121✔
1773
                if err != nil {
121✔
1774
                        return err
×
1775
                }
×
1776

1777
                // Now that we know how many keys to expect, iterate reading
1778
                // each one until we're done.
1779
                for i := uint64(0); i < numKeys; i++ {
286✔
1780
                        var (
165✔
1781
                                key  models.CircuitKey
165✔
1782
                                scid uint64
165✔
1783
                        )
165✔
1784

165✔
1785
                        if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
165✔
1786
                                return err
×
1787
                        }
×
1788

1789
                        key.ChanID = lnwire.NewShortChanIDFromInt(scid)
165✔
1790

165✔
1791
                        err := tlv.DUint64(r, &key.HtlcID, buf, 8)
165✔
1792
                        if err != nil {
165✔
1793
                                return err
×
1794
                        }
×
1795

1796
                        (*v)[key] = struct{}{}
165✔
1797
                }
1798

1799
                return nil
121✔
1800
        }
1801

1802
        return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l)
×
1803
}
1804

1805
// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record.
1806
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
2,445✔
1807
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
4,890✔
1808
                // We'll encode the AMP state as a series of KV pairs on the
2,445✔
1809
                // wire with a length prefix.
2,445✔
1810
                numRecords := uint64(len(*v))
2,445✔
1811

2,445✔
1812
                // First, we'll write out the number of records as a var int.
2,445✔
1813
                if err := tlv.WriteVarInt(w, numRecords, buf); err != nil {
2,445✔
1814
                        return err
×
1815
                }
×
1816

1817
                // With that written out, we'll now encode the entries
1818
                // themselves as a sub-TLV record, which includes its _own_
1819
                // inner length prefix.
1820
                for setID, ampState := range *v {
2,578✔
1821
                        setID := [32]byte(setID)
133✔
1822
                        ampState := ampState
133✔
1823

133✔
1824
                        htlcState := uint8(ampState.State)
133✔
1825
                        settleDate := ampState.SettleDate
133✔
1826
                        settleDateBytes, err := settleDate.MarshalBinary()
133✔
1827
                        if err != nil {
133✔
1828
                                return err
×
1829
                        }
×
1830

1831
                        amtPaid := uint64(ampState.AmtPaid)
133✔
1832

133✔
1833
                        var ampStateTlvBytes bytes.Buffer
133✔
1834
                        tlvStream, err := tlv.NewStream(
133✔
1835
                                tlv.MakePrimitiveRecord(
133✔
1836
                                        ampStateSetIDType, &setID,
133✔
1837
                                ),
133✔
1838
                                tlv.MakePrimitiveRecord(
133✔
1839
                                        ampStateHtlcStateType, &htlcState,
133✔
1840
                                ),
133✔
1841
                                tlv.MakePrimitiveRecord(
133✔
1842
                                        ampStateSettleIndexType,
133✔
1843
                                        &ampState.SettleIndex,
133✔
1844
                                ),
133✔
1845
                                tlv.MakePrimitiveRecord(
133✔
1846
                                        ampStateSettleDateType,
133✔
1847
                                        &settleDateBytes,
133✔
1848
                                ),
133✔
1849
                                tlv.MakeDynamicRecord(
133✔
1850
                                        ampStateCircuitKeysType,
133✔
1851
                                        &ampState.InvoiceKeys,
133✔
1852
                                        func() uint64 {
266✔
1853
                                                // The record takes 8 bytes to
133✔
1854
                                                // encode the set of circuits,
133✔
1855
                                                // 8 bytes for the scid for the
133✔
1856
                                                // key, and 8 bytes for the HTLC
133✔
1857
                                                // index.
133✔
1858
                                                keys := ampState.InvoiceKeys
133✔
1859
                                                numKeys := uint64(len(keys))
133✔
1860
                                                size := tlv.VarIntSize(numKeys)
133✔
1861
                                                dataSize := (numKeys * 16)
133✔
1862

133✔
1863
                                                return size + dataSize
133✔
1864
                                        },
133✔
1865
                                        encodeCircuitKeys, decodeCircuitKeys,
1866
                                ),
1867
                                tlv.MakePrimitiveRecord(
1868
                                        ampStateAmtPaidType, &amtPaid,
1869
                                ),
1870
                        )
1871
                        if err != nil {
133✔
1872
                                return err
×
1873
                        }
×
1874

1875
                        err = tlvStream.Encode(&ampStateTlvBytes)
133✔
1876
                        if err != nil {
133✔
1877
                                return err
×
1878
                        }
×
1879

1880
                        // We encode the record with a varint length followed by
1881
                        // the _raw_ TLV bytes.
1882
                        tlvLen := uint64(len(ampStateTlvBytes.Bytes()))
133✔
1883
                        if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
133✔
1884
                                return err
×
1885
                        }
×
1886

1887
                        _, err = w.Write(ampStateTlvBytes.Bytes())
133✔
1888
                        if err != nil {
133✔
1889
                                return err
×
1890
                        }
×
1891
                }
1892

1893
                return nil
2,445✔
1894
        }
1895

1896
        return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
×
1897
}
1898

1899
// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record.
1900
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte,
1901
        l uint64) error {
2,307✔
1902

2,307✔
1903
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
4,614✔
1904
                // First, we'll decode the varint that encodes how many set IDs
2,307✔
1905
                // are encoded within the greater map.
2,307✔
1906
                numRecords, err := tlv.ReadVarInt(r, buf)
2,307✔
1907
                if err != nil {
2,307✔
1908
                        return err
×
1909
                }
×
1910

1911
                // Now that we know how many records we'll need to read, we can
1912
                // iterate and read them all out in series.
1913
                for i := uint64(0); i < numRecords; i++ {
2,428✔
1914
                        // Read out the varint that encodes the size of this
121✔
1915
                        // inner TLV record.
121✔
1916
                        stateRecordSize, err := tlv.ReadVarInt(r, buf)
121✔
1917
                        if err != nil {
121✔
1918
                                return err
×
1919
                        }
×
1920

1921
                        // Using this information, we'll create a new limited
1922
                        // reader that'll return an EOF once the end has been
1923
                        // reached so the stream stops consuming bytes.
1924
                        innerTlvReader := io.LimitedReader{
121✔
1925
                                R: r,
121✔
1926
                                N: int64(stateRecordSize),
121✔
1927
                        }
121✔
1928

121✔
1929
                        var (
121✔
1930
                                setID           [32]byte
121✔
1931
                                htlcState       uint8
121✔
1932
                                settleIndex     uint64
121✔
1933
                                settleDateBytes []byte
121✔
1934
                                invoiceKeys     = make(
121✔
1935
                                        map[models.CircuitKey]struct{},
121✔
1936
                                )
121✔
1937
                                amtPaid uint64
121✔
1938
                        )
121✔
1939
                        tlvStream, err := tlv.NewStream(
121✔
1940
                                tlv.MakePrimitiveRecord(
121✔
1941
                                        ampStateSetIDType, &setID,
121✔
1942
                                ),
121✔
1943
                                tlv.MakePrimitiveRecord(
121✔
1944
                                        ampStateHtlcStateType, &htlcState,
121✔
1945
                                ),
121✔
1946
                                tlv.MakePrimitiveRecord(
121✔
1947
                                        ampStateSettleIndexType, &settleIndex,
121✔
1948
                                ),
121✔
1949
                                tlv.MakePrimitiveRecord(
121✔
1950
                                        ampStateSettleDateType,
121✔
1951
                                        &settleDateBytes,
121✔
1952
                                ),
121✔
1953
                                tlv.MakeDynamicRecord(
121✔
1954
                                        ampStateCircuitKeysType,
121✔
1955
                                        &invoiceKeys, nil,
121✔
1956
                                        encodeCircuitKeys, decodeCircuitKeys,
121✔
1957
                                ),
121✔
1958
                                tlv.MakePrimitiveRecord(
121✔
1959
                                        ampStateAmtPaidType, &amtPaid,
121✔
1960
                                ),
121✔
1961
                        )
121✔
1962
                        if err != nil {
121✔
1963
                                return err
×
1964
                        }
×
1965

1966
                        err = tlvStream.Decode(&innerTlvReader)
121✔
1967
                        if err != nil {
121✔
1968
                                return err
×
1969
                        }
×
1970

1971
                        var settleDate time.Time
121✔
1972
                        err = settleDate.UnmarshalBinary(settleDateBytes)
121✔
1973
                        if err != nil {
121✔
1974
                                return err
×
1975
                        }
×
1976

1977
                        (*v)[setID] = invpkg.InvoiceStateAMP{
121✔
1978
                                State:       invpkg.HtlcState(htlcState),
121✔
1979
                                SettleIndex: settleIndex,
121✔
1980
                                SettleDate:  settleDate,
121✔
1981
                                InvoiceKeys: invoiceKeys,
121✔
1982
                                AmtPaid:     lnwire.MilliSatoshi(amtPaid),
121✔
1983
                        }
121✔
1984
                }
1985

1986
                return nil
2,307✔
1987
        }
1988

1989
        return tlv.NewTypeForDecodingErr(
×
1990
                val, "channeldb.AMPInvoiceState", l, l,
×
1991
        )
×
1992
}
1993

1994
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
1995
// as a map.
1996
func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1997
        error) {
2,373✔
1998

2,373✔
1999
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
2,373✔
2000
        for {
5,900✔
2001
                // Read the length of the tlv stream for this htlc.
3,527✔
2002
                var streamLen int64
3,527✔
2003
                if err := binary.Read(r, byteOrder, &streamLen); err != nil {
5,900✔
2004
                        if err == io.EOF {
4,746✔
2005
                                break
2,373✔
2006
                        }
2007

2008
                        return nil, err
×
2009
                }
2010

2011
                // Limit the reader so that it stops at the end of this htlc's
2012
                // stream.
2013
                htlcReader := io.LimitReader(r, streamLen)
1,157✔
2014

1,157✔
2015
                // Decode the contents into the htlc fields.
1,157✔
2016
                var (
1,157✔
2017
                        htlc                    invpkg.InvoiceHTLC
1,157✔
2018
                        key                     models.CircuitKey
1,157✔
2019
                        chanID                  uint64
1,157✔
2020
                        state                   uint8
1,157✔
2021
                        acceptTime, resolveTime uint64
1,157✔
2022
                        amt, mppTotalAmt        uint64
1,157✔
2023
                        amp                     = &record.AMP{}
1,157✔
2024
                        hash32                  = &[32]byte{}
1,157✔
2025
                        preimage32              = &[32]byte{}
1,157✔
2026
                )
1,157✔
2027
                tlvStream, err := tlv.NewStream(
1,157✔
2028
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
1,157✔
2029
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
1,157✔
2030
                        tlv.MakePrimitiveRecord(amtType, &amt),
1,157✔
2031
                        tlv.MakePrimitiveRecord(
1,157✔
2032
                                acceptHeightType, &htlc.AcceptHeight,
1,157✔
2033
                        ),
1,157✔
2034
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
1,157✔
2035
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
1,157✔
2036
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
1,157✔
2037
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
1,157✔
2038
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
1,157✔
2039
                        tlv.MakeDynamicRecord(
1,157✔
2040
                                htlcAMPType, amp, amp.PayloadSize,
1,157✔
2041
                                record.AMPEncoder, record.AMPDecoder,
1,157✔
2042
                        ),
1,157✔
2043
                        tlv.MakePrimitiveRecord(htlcHashType, hash32),
1,157✔
2044
                        tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
1,157✔
2045
                )
1,157✔
2046
                if err != nil {
1,157✔
2047
                        return nil, err
×
2048
                }
×
2049

2050
                parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
1,157✔
2051
                if err != nil {
1,157✔
2052
                        return nil, err
×
2053
                }
×
2054

2055
                if _, ok := parsedTypes[htlcAMPType]; !ok {
2,201✔
2056
                        amp = nil
1,044✔
2057
                }
1,044✔
2058

2059
                var preimage *lntypes.Preimage
1,157✔
2060
                if _, ok := parsedTypes[htlcPreimageType]; ok {
1,212✔
2061
                        pimg := lntypes.Preimage(*preimage32)
55✔
2062
                        preimage = &pimg
55✔
2063
                }
55✔
2064

2065
                var hash *lntypes.Hash
1,157✔
2066
                if _, ok := parsedTypes[htlcHashType]; ok {
1,273✔
2067
                        h := lntypes.Hash(*hash32)
116✔
2068
                        hash = &h
116✔
2069
                }
116✔
2070

2071
                key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
1,157✔
2072
                htlc.AcceptTime = getNanoTime(acceptTime)
1,157✔
2073
                htlc.ResolveTime = getNanoTime(resolveTime)
1,157✔
2074
                htlc.State = invpkg.HtlcState(state)
1,157✔
2075
                htlc.Amt = lnwire.MilliSatoshi(amt)
1,157✔
2076
                htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
1,157✔
2077
                if amp != nil && hash != nil {
1,273✔
2078
                        htlc.AMP = &invpkg.InvoiceHtlcAMPData{
116✔
2079
                                Record:   *amp,
116✔
2080
                                Hash:     *hash,
116✔
2081
                                Preimage: preimage,
116✔
2082
                        }
116✔
2083
                }
116✔
2084

2085
                // Reconstruct the custom records fields from the parsed types
2086
                // map return from the tlv parser.
2087
                htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
1,157✔
2088

1,157✔
2089
                htlcs[key] = &htlc
1,157✔
2090
        }
2091

2092
        return htlcs, nil
2,373✔
2093
}
2094

2095
// invoiceSetIDKeyLen is the length of the key under which the HTLCs of a
// single AMP set are stored alongside the main invoice in the invoice bucket:
// 4 bytes for the invoice number followed by 32 bytes for the set ID.
const invoiceSetIDKeyLen = 4 + 32

// makeInvoiceSetIDKey builds the invoiceNum || setID prefix key under which
// the HTLCs for the given set ID are stored. Key bytes not covered by the
// inputs are left zeroed.
func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte {
	var key [invoiceSetIDKeyLen]byte

	// The set ID starts immediately after the invoice number.
	setIDOffset := len(invoiceNum)

	copy(key[:], invoiceNum)
	copy(key[setIDOffset:], setID)

	return key
}
88✔
2112

2113
// delAMPInvoices attempts to delete all the "sub" invoices associated with a
2114
// greater AMP invoices. We do this by deleting the set of keys that share the
2115
// invoice number as a prefix.
2116
func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error {
15✔
2117
        // Since it isn't safe to delete using an active cursor, we'll use the
15✔
2118
        // cursor simply to collect the set of keys we need to delete, _then_
15✔
2119
        // delete them in another pass.
15✔
2120
        var keysToDel [][]byte
15✔
2121
        err := forEachAMPInvoice(
15✔
2122
                invoiceBucket, invoiceNum,
15✔
2123
                func(cursorKey, v []byte) error {
18✔
2124
                        keysToDel = append(keysToDel, cursorKey)
3✔
2125
                        return nil
3✔
2126
                },
3✔
2127
        )
2128
        if err != nil {
15✔
2129
                return err
×
2130
        }
×
2131

2132
        // In this next phase, we'll then delete all the relevant invoices.
2133
        for _, keyToDel := range keysToDel {
18✔
2134
                if err := invoiceBucket.Delete(keyToDel); err != nil {
3✔
2135
                        return err
×
2136
                }
×
2137
        }
2138

2139
        return nil
15✔
2140
}
2141

2142
// delAMPSettleIndex removes all the entries in the settle index associated
2143
// with a given AMP invoice.
2144
func delAMPSettleIndex(invoiceNum []byte, invoices,
2145
        settleIndex kvdb.RwBucket) error {
8✔
2146

8✔
2147
        // First, we need to grab the AMP invoice state to see if there's
8✔
2148
        // anything that we even need to delete.
8✔
2149
        ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices)
8✔
2150
        if err != nil {
8✔
2151
                return err
×
2152
        }
×
2153

2154
        // If there's no AMP state at all (non-AMP invoice), then we can return
2155
        // early.
2156
        if len(ampState) == 0 {
15✔
2157
                return nil
7✔
2158
        }
7✔
2159

2160
        // Otherwise, we'll need to iterate and delete each settle index within
2161
        // the set of returned entries.
2162
        var settleIndexKey [8]byte
1✔
2163
        for _, subState := range ampState {
4✔
2164
                byteOrder.PutUint64(
3✔
2165
                        settleIndexKey[:], subState.SettleIndex,
3✔
2166
                )
3✔
2167

3✔
2168
                if err := settleIndex.Delete(settleIndexKey[:]); err != nil {
3✔
2169
                        return err
×
2170
                }
×
2171
        }
2172

2173
        return nil
1✔
2174
}
2175

2176
// DeleteCanceledInvoices deletes all canceled invoices from the database.
2177
func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
3✔
2178
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
6✔
2179
                invoices := tx.ReadWriteBucket(invoiceBucket)
3✔
2180
                if invoices == nil {
3✔
2181
                        return nil
×
2182
                }
×
2183

2184
                invoiceIndex := invoices.NestedReadWriteBucket(
3✔
2185
                        invoiceIndexBucket,
3✔
2186
                )
3✔
2187
                if invoiceIndex == nil {
4✔
2188
                        return nil
1✔
2189
                }
1✔
2190

2191
                invoiceAddIndex := invoices.NestedReadWriteBucket(
2✔
2192
                        addIndexBucket,
2✔
2193
                )
2✔
2194
                if invoiceAddIndex == nil {
2✔
2195
                        return nil
×
2196
                }
×
2197

2198
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
2✔
2199

2✔
2200
                return invoiceIndex.ForEach(func(k, v []byte) error {
19✔
2201
                        // Skip the special numInvoicesKey as that does not
17✔
2202
                        // point to a valid invoice.
17✔
2203
                        if bytes.Equal(k, numInvoicesKey) {
19✔
2204
                                return nil
2✔
2205
                        }
2✔
2206

2207
                        // Skip sub-buckets.
2208
                        if v == nil {
15✔
2209
                                return nil
×
2210
                        }
×
2211

2212
                        invoice, err := fetchInvoice(v, invoices, nil, false)
15✔
2213
                        if err != nil {
15✔
2214
                                return err
×
2215
                        }
×
2216

2217
                        if invoice.State != invpkg.ContractCanceled {
23✔
2218
                                return nil
8✔
2219
                        }
8✔
2220

2221
                        // Delete the payment hash from the invoice index.
2222
                        err = invoiceIndex.Delete(k)
7✔
2223
                        if err != nil {
7✔
2224
                                return err
×
2225
                        }
×
2226

2227
                        // Delete payment address index reference if there's a
2228
                        // valid payment address.
2229
                        if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
14✔
2230
                                // To ensure consistency check that the already
7✔
2231
                                // fetched invoice key matches the one in the
7✔
2232
                                // payment address index.
7✔
2233
                                key := payAddrIndex.Get(
7✔
2234
                                        invoice.Terms.PaymentAddr[:],
7✔
2235
                                )
7✔
2236
                                if bytes.Equal(key, k) {
7✔
2237
                                        // Delete from the payment address
×
2238
                                        // index.
×
2239
                                        if err := payAddrIndex.Delete(
×
2240
                                                invoice.Terms.PaymentAddr[:],
×
2241
                                        ); err != nil {
×
2242
                                                return err
×
2243
                                        }
×
2244
                                }
2245
                        }
2246

2247
                        // Remove from the add index.
2248
                        var addIndexKey [8]byte
7✔
2249
                        byteOrder.PutUint64(addIndexKey[:], invoice.AddIndex)
7✔
2250
                        err = invoiceAddIndex.Delete(addIndexKey[:])
7✔
2251
                        if err != nil {
7✔
2252
                                return err
×
2253
                        }
×
2254

2255
                        // Note that we don't need to delete the invoice from
2256
                        // the settle index as it is not added until the
2257
                        // invoice is settled.
2258

2259
                        // Now remove all sub invoices.
2260
                        err = delAMPInvoices(k, invoices)
7✔
2261
                        if err != nil {
7✔
2262
                                return err
×
2263
                        }
×
2264

2265
                        // Finally remove the serialized invoice from the
2266
                        // invoice bucket.
2267
                        return invoices.Delete(k)
7✔
2268
                })
2269
        }, func() {})
3✔
2270
}
2271

2272
// DeleteInvoice attempts to delete the passed invoices from the database in
2273
// one transaction. The passed delete references hold all keys required to
2274
// delete the invoices without also needing to deserialize them.
2275
func (d *DB) DeleteInvoice(_ context.Context,
2276
        invoicesToDelete []invpkg.InvoiceDeleteRef) error {
6✔
2277

6✔
2278
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
12✔
2279
                invoices := tx.ReadWriteBucket(invoiceBucket)
6✔
2280
                if invoices == nil {
6✔
2281
                        return invpkg.ErrNoInvoicesCreated
×
2282
                }
×
2283

2284
                invoiceIndex := invoices.NestedReadWriteBucket(
6✔
2285
                        invoiceIndexBucket,
6✔
2286
                )
6✔
2287
                if invoiceIndex == nil {
6✔
2288
                        return invpkg.ErrNoInvoicesCreated
×
2289
                }
×
2290

2291
                invoiceAddIndex := invoices.NestedReadWriteBucket(
6✔
2292
                        addIndexBucket,
6✔
2293
                )
6✔
2294
                if invoiceAddIndex == nil {
6✔
2295
                        return invpkg.ErrNoInvoicesCreated
×
2296
                }
×
2297

2298
                // settleIndex can be nil, as the bucket is created lazily
2299
                // when the first invoice is settled.
2300
                settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket)
6✔
2301

6✔
2302
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
6✔
2303

6✔
2304
                for _, ref := range invoicesToDelete {
17✔
2305
                        // Fetch the invoice key for using it to check for
11✔
2306
                        // consistency and also to delete from the invoice
11✔
2307
                        // index.
11✔
2308
                        invoiceKey := invoiceIndex.Get(ref.PayHash[:])
11✔
2309
                        if invoiceKey == nil {
12✔
2310
                                return invpkg.ErrInvoiceNotFound
1✔
2311
                        }
1✔
2312

2313
                        err := invoiceIndex.Delete(ref.PayHash[:])
10✔
2314
                        if err != nil {
10✔
2315
                                return err
×
2316
                        }
×
2317

2318
                        // Delete payment address index reference if there's a
2319
                        // valid payment address passed.
2320
                        if ref.PayAddr != nil {
19✔
2321
                                // To ensure consistency check that the already
9✔
2322
                                // fetched invoice key matches the one in the
9✔
2323
                                // payment address index.
9✔
2324
                                key := payAddrIndex.Get(ref.PayAddr[:])
9✔
2325
                                if bytes.Equal(key, invoiceKey) {
18✔
2326
                                        // Delete from the payment address
9✔
2327
                                        // index. Note that since the payment
9✔
2328
                                        // address index has been introduced
9✔
2329
                                        // with an empty migration it may be
9✔
2330
                                        // possible that the index doesn't have
9✔
2331
                                        // an entry for this invoice.
9✔
2332
                                        // ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127
9✔
2333
                                        if err := payAddrIndex.Delete(
9✔
2334
                                                ref.PayAddr[:],
9✔
2335
                                        ); err != nil {
9✔
2336
                                                return err
×
2337
                                        }
×
2338
                                }
2339
                        }
2340

2341
                        var addIndexKey [8]byte
10✔
2342
                        byteOrder.PutUint64(addIndexKey[:], ref.AddIndex)
10✔
2343

10✔
2344
                        // To ensure consistency check that the key stored in
10✔
2345
                        // the add index also matches the previously fetched
10✔
2346
                        // invoice key.
10✔
2347
                        key := invoiceAddIndex.Get(addIndexKey[:])
10✔
2348
                        if !bytes.Equal(key, invoiceKey) {
11✔
2349
                                return fmt.Errorf("unknown invoice " +
1✔
2350
                                        "in add index")
1✔
2351
                        }
1✔
2352

2353
                        // Remove from the add index.
2354
                        err = invoiceAddIndex.Delete(addIndexKey[:])
9✔
2355
                        if err != nil {
9✔
2356
                                return err
×
2357
                        }
×
2358

2359
                        // Remove from the settle index if available and
2360
                        // if the invoice is settled.
2361
                        if settleIndex != nil && ref.SettleIndex > 0 {
12✔
2362
                                var settleIndexKey [8]byte
3✔
2363
                                byteOrder.PutUint64(
3✔
2364
                                        settleIndexKey[:], ref.SettleIndex,
3✔
2365
                                )
3✔
2366

3✔
2367
                                // To ensure consistency check that the already
3✔
2368
                                // fetched invoice key matches the one in the
3✔
2369
                                // settle index
3✔
2370
                                key := settleIndex.Get(settleIndexKey[:])
3✔
2371
                                if !bytes.Equal(key, invoiceKey) {
4✔
2372
                                        return fmt.Errorf("unknown invoice " +
1✔
2373
                                                "in settle index")
1✔
2374
                                }
1✔
2375

2376
                                err = settleIndex.Delete(settleIndexKey[:])
2✔
2377
                                if err != nil {
2✔
2378
                                        return err
×
2379
                                }
×
2380
                        }
2381

2382
                        // In addition to deleting the main invoice state, if
2383
                        // this is an AMP invoice, then we'll also need to
2384
                        // delete the set HTLC set stored as a key prefix. For
2385
                        // non-AMP invoices, this'll be a noop.
2386
                        err = delAMPSettleIndex(
8✔
2387
                                invoiceKey, invoices, settleIndex,
8✔
2388
                        )
8✔
2389
                        if err != nil {
8✔
2390
                                return err
×
2391
                        }
×
2392
                        err = delAMPInvoices(invoiceKey, invoices)
8✔
2393
                        if err != nil {
8✔
2394
                                return err
×
2395
                        }
×
2396

2397
                        // Finally remove the serialized invoice from the
2398
                        // invoice bucket.
2399
                        err = invoices.Delete(invoiceKey)
8✔
2400
                        if err != nil {
8✔
2401
                                return err
×
2402
                        }
×
2403
                }
2404

2405
                return nil
3✔
2406
        }, func() {})
6✔
2407

2408
        return err
6✔
2409
}
2410

2411
// SetInvoiceBucketTombstone sets the tombstone key in the invoice bucket to
2412
// mark the bucket as permanently closed. This prevents it from being reopened
2413
// in the future.
2414
func (d *DB) SetInvoiceBucketTombstone() error {
1✔
2415
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
2✔
2416
                // Access the top-level invoice bucket.
1✔
2417
                invoices := tx.ReadWriteBucket(invoiceBucket)
1✔
2418
                if invoices == nil {
1✔
2419
                        return fmt.Errorf("invoice bucket does not exist")
×
2420
                }
×
2421

2422
                // Add the tombstone key to the invoice bucket.
2423
                err := invoices.Put(invoiceBucketTombstone, []byte("1"))
1✔
2424
                if err != nil {
1✔
2425
                        return fmt.Errorf("failed to set tombstone: %w", err)
×
2426
                }
×
2427

2428
                return nil
1✔
2429
        }, func() {})
1✔
2430
}
2431

2432
// GetInvoiceBucketTombstone checks if the tombstone key exists in the invoice
2433
// bucket. It returns true if the tombstone is present and false otherwise.
2434
func (d *DB) GetInvoiceBucketTombstone() (bool, error) {
5✔
2435
        var tombstoneExists bool
5✔
2436

5✔
2437
        err := kvdb.View(d, func(tx kvdb.RTx) error {
10✔
2438
                // Access the top-level invoice bucket.
5✔
2439
                invoices := tx.ReadBucket(invoiceBucket)
5✔
2440
                if invoices == nil {
5✔
2441
                        return fmt.Errorf("invoice bucket does not exist")
×
2442
                }
×
2443

2444
                // Check if the tombstone key exists.
2445
                tombstone := invoices.Get(invoiceBucketTombstone)
5✔
2446
                tombstoneExists = tombstone != nil
5✔
2447

5✔
2448
                return nil
5✔
2449
        }, func() {})
5✔
2450
        if err != nil {
5✔
2451
                return false, err
×
2452
        }
×
2453

2454
        return tombstoneExists, nil
5✔
2455
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc