• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13211764208

08 Feb 2025 03:08AM UTC coverage: 49.288% (-9.5%) from 58.815%
13211764208

Pull #9489

github

calvinrzachman
itest: verify switchrpc server enforces send then track

We prevent the RPC server from allowing onion dispatches for
attempt IDs which have already been tracked by RPC clients.

This helps protect the client from leaking a duplicate onion
attempt. NOTE: This is not the only way to solve this
issue! The issue could also be addressed via careful client-side
programming which accounts for the uncertainty and asynchronous
nature of dispatching onions to a remote process via RPC.
That approach would, however, require some changes to how lnd's
ChannelRouter is intended to use these RPCs.
Pull Request #9489: multi: add BuildOnion, SendOnion, and TrackOnion RPCs

474 of 990 new or added lines in 11 files covered. (47.88%)

27321 existing lines in 435 files now uncovered.

101192 of 205306 relevant lines covered (49.29%)

1.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.7
/channeldb/invoices.go
1
package channeldb
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/binary"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "time"
11

12
        "github.com/lightningnetwork/lnd/graph/db/models"
13
        "github.com/lightningnetwork/lnd/htlcswitch/hop"
14
        invpkg "github.com/lightningnetwork/lnd/invoices"
15
        "github.com/lightningnetwork/lnd/kvdb"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/tlv"
20
)
21

22
var (
	// invoiceBucket is the name of the bucket within the database that
	// stores all data related to invoices no matter their final state.
	// Within the invoice bucket, each invoice is keyed by its invoice ID
	// which is a monotonically increasing uint32.
	invoiceBucket = []byte("invoices")

	// invoiceIndexBucket is the name of the sub-bucket within the
	// invoiceBucket which indexes all invoices by their payment hash. The
	// payment hash is the sha256 of the invoice's payment preimage. This
	// index is used to detect duplicates, and also to provide a fast path
	// for looking up incoming HTLCs to determine if we're able to settle
	// them fully. The same bucket also houses the numInvoicesKey counter.
	//
	// maps: payHash => invoiceKey
	invoiceIndexBucket = []byte("paymenthashes")

	// payAddrIndexBucket is the name of the top-level bucket that maps
	// payment addresses to their invoice number. This can be used
	// to efficiently query or update non-legacy invoices. Note that legacy
	// invoices will not be included in this index since they all have the
	// same, all-zero payment address, however all newly generated invoices
	// will end up in this index.
	//
	// maps: payAddr => invoiceKey
	payAddrIndexBucket = []byte("pay-addr-index")

	// setIDIndexBucket is the name of the top-level bucket that maps set
	// ids to their invoice number. This can be used to efficiently query or
	// update AMP invoices. Note that legacy or MPP invoices will not be
	// included in this index, since their HTLCs do not have a set id.
	//
	// maps: setID => invoiceKey
	setIDIndexBucket = []byte("set-id-index")

	// numInvoicesKey is the name of key which houses the auto-incrementing
	// invoice ID which is essentially used as a primary key. With each
	// invoice inserted, the primary key is incremented by one. This key is
	// stored within the invoiceIndexBucket. Within the invoiceBucket
	// invoices are uniquely identified by the invoice ID.
	numInvoicesKey = []byte("nik")

	// addIndexBucket is an index bucket that we'll use to create a
	// monotonically increasing set of add indexes. Each time we add a new
	// invoice, this sequence number will be incremented and then populated
	// within the new invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   addIndexNo => invoiceKey
	addIndexBucket = []byte("invoice-add-index")

	// settleIndexBucket is an index bucket that we'll use to create a
	// monotonically increasing integer for tracking a "settle index". Each
	// time an invoice is settled, this sequence number will be incremented
	// and populated within the newly settled invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   settleIndexNo => invoiceKey
	settleIndexBucket = []byte("invoice-settle-index")

	// invoiceBucketTombstone is a special key that indicates the invoice
	// bucket has been permanently closed. Its purpose is to prevent the
	// invoice bucket from being reopened in the future. A key use case for
	// the tombstone is to ensure users cannot switch back to the KV invoice
	// database after migrating to the native SQL database.
	invoiceBucketTombstone = []byte("invoice-tombstone")
)
91

92
const (
	// A set of tlv type definitions used to serialize invoice htlcs to the
	// database.
	//
	// NOTE: A migration should be added whenever this list changes. This
	// guards against the database being rolled back to an older
	// format where the surrounding logic might assume a different set of
	// fields are known.
	chanIDType       tlv.Type = 1
	htlcIDType       tlv.Type = 3
	amtType          tlv.Type = 5
	acceptHeightType tlv.Type = 7
	acceptTimeType   tlv.Type = 9
	resolveTimeType  tlv.Type = 11
	expiryHeightType tlv.Type = 13
	htlcStateType    tlv.Type = 15
	mppTotalAmtType  tlv.Type = 17
	htlcAMPType      tlv.Type = 19
	htlcHashType     tlv.Type = 21
	htlcPreimageType tlv.Type = 23

	// A set of tlv type definitions used to serialize invoice bodies.
	//
	// NOTE: A migration should be added whenever this list changes. This
	// guards against the database being rolled back to an older
	// format where the surrounding logic might assume a different set of
	// fields are known.
	memoType            tlv.Type = 0
	payReqType          tlv.Type = 1
	createTimeType      tlv.Type = 2
	settleTimeType      tlv.Type = 3
	addIndexType        tlv.Type = 4
	settleIndexType     tlv.Type = 5
	preimageType        tlv.Type = 6
	valueType           tlv.Type = 7
	cltvDeltaType       tlv.Type = 8
	expiryType          tlv.Type = 9
	paymentAddrType     tlv.Type = 10
	featuresType        tlv.Type = 11
	invStateType        tlv.Type = 12
	amtPaidType         tlv.Type = 13
	hodlInvoiceType     tlv.Type = 14
	invoiceAmpStateType tlv.Type = 15

	// A set of tlv type definitions used to serialize the invoice AMP
	// state alongside the main invoice body.
	ampStateSetIDType       tlv.Type = 0
	ampStateHtlcStateType   tlv.Type = 1
	ampStateSettleIndexType tlv.Type = 2
	ampStateSettleDateType  tlv.Type = 3
	ampStateCircuitKeysType tlv.Type = 4
	ampStateAmtPaidType     tlv.Type = 5
)
145

146
// AddInvoice inserts the targeted invoice into the database. If the invoice has
147
// *any* payment hashes which already exists within the database, then the
148
// insertion will be aborted and rejected due to the strict policy banning any
149
// duplicate payment hashes. A side effect of this function is that it sets
150
// AddIndex on newInvoice.
151
func (d *DB) AddInvoice(_ context.Context, newInvoice *invpkg.Invoice,
152
        paymentHash lntypes.Hash) (uint64, error) {
3✔
153

3✔
154
        if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
3✔
UNCOV
155
                return 0, err
×
UNCOV
156
        }
×
157

158
        var invoiceAddIndex uint64
3✔
159
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
6✔
160
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
3✔
161
                if err != nil {
3✔
162
                        return err
×
163
                }
×
164

165
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
3✔
166
                        invoiceIndexBucket,
3✔
167
                )
3✔
168
                if err != nil {
3✔
169
                        return err
×
170
                }
×
171
                addIndex, err := invoices.CreateBucketIfNotExists(
3✔
172
                        addIndexBucket,
3✔
173
                )
3✔
174
                if err != nil {
3✔
175
                        return err
×
176
                }
×
177

178
                // Ensure that an invoice an identical payment hash doesn't
179
                // already exist within the index.
180
                if invoiceIndex.Get(paymentHash[:]) != nil {
3✔
UNCOV
181
                        return invpkg.ErrDuplicateInvoice
×
UNCOV
182
                }
×
183

184
                // Check that we aren't inserting an invoice with a duplicate
185
                // payment address. The all-zeros payment address is
186
                // special-cased to support legacy keysend invoices which don't
187
                // assign one. This is safe since later we also will avoid
188
                // indexing them and avoid collisions.
189
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
3✔
190
                if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
6✔
191
                        paymentAddr := newInvoice.Terms.PaymentAddr[:]
3✔
192
                        if payAddrIndex.Get(paymentAddr) != nil {
6✔
193
                                return invpkg.ErrDuplicatePayAddr
3✔
194
                        }
3✔
195
                }
196

197
                // If the current running payment ID counter hasn't yet been
198
                // created, then create it now.
199
                var invoiceNum uint32
3✔
200
                invoiceCounter := invoiceIndex.Get(numInvoicesKey)
3✔
201
                if invoiceCounter == nil {
6✔
202
                        var scratch [4]byte
3✔
203
                        byteOrder.PutUint32(scratch[:], invoiceNum)
3✔
204
                        err := invoiceIndex.Put(numInvoicesKey, scratch[:])
3✔
205
                        if err != nil {
3✔
206
                                return err
×
207
                        }
×
208
                } else {
3✔
209
                        invoiceNum = byteOrder.Uint32(invoiceCounter)
3✔
210
                }
3✔
211

212
                newIndex, err := putInvoice(
3✔
213
                        invoices, invoiceIndex, payAddrIndex, addIndex,
3✔
214
                        newInvoice, invoiceNum, paymentHash,
3✔
215
                )
3✔
216
                if err != nil {
3✔
217
                        return err
×
218
                }
×
219

220
                invoiceAddIndex = newIndex
3✔
221
                return nil
3✔
222
        }, func() {
3✔
223
                invoiceAddIndex = 0
3✔
224
        })
3✔
225
        if err != nil {
6✔
226
                return 0, err
3✔
227
        }
3✔
228

229
        return invoiceAddIndex, err
3✔
230
}
231

232
// InvoicesAddedSince can be used by callers to seek into the event time series
233
// of all the invoices added in the database. The specified sinceAddIndex
234
// should be the highest add index that the caller knows of. This method will
235
// return all invoices with an add index greater than the specified
236
// sinceAddIndex.
237
//
238
// NOTE: The index starts from 1, as a result. We enforce that specifying a
239
// value below the starting index value is a noop.
240
func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
241
        []invpkg.Invoice, error) {
3✔
242

3✔
243
        var newInvoices []invpkg.Invoice
3✔
244

3✔
245
        // If an index of zero was specified, then in order to maintain
3✔
246
        // backwards compat, we won't send out any new invoices.
3✔
247
        if sinceAddIndex == 0 {
6✔
248
                return newInvoices, nil
3✔
249
        }
3✔
250

251
        var startIndex [8]byte
3✔
252
        byteOrder.PutUint64(startIndex[:], sinceAddIndex)
3✔
253

3✔
254
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
255
                invoices := tx.ReadBucket(invoiceBucket)
3✔
256
                if invoices == nil {
3✔
257
                        return nil
×
258
                }
×
259

260
                addIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
261
                if addIndex == nil {
3✔
262
                        return nil
×
263
                }
×
264

265
                // We'll now run through each entry in the add index starting
266
                // at our starting index. We'll continue until we reach the
267
                // very end of the current key space.
268
                invoiceCursor := addIndex.ReadCursor()
3✔
269

3✔
270
                // We'll seek to the starting index, then manually advance the
3✔
271
                // cursor in order to skip the entry with the since add index.
3✔
272
                invoiceCursor.Seek(startIndex[:])
3✔
273
                addSeqNo, invoiceKey := invoiceCursor.Next()
3✔
274

3✔
275
                for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
6✔
276

3✔
277
                        // For each key found, we'll look up the actual
3✔
278
                        // invoice, then accumulate it into our return value.
3✔
279
                        invoice, err := fetchInvoice(
3✔
280
                                invoiceKey, invoices, nil, false,
3✔
281
                        )
3✔
282
                        if err != nil {
3✔
283
                                return err
×
284
                        }
×
285

286
                        newInvoices = append(newInvoices, invoice)
3✔
287
                }
288

289
                return nil
3✔
290
        }, func() {
3✔
291
                newInvoices = nil
3✔
292
        })
3✔
293
        if err != nil {
3✔
294
                return nil, err
×
295
        }
×
296

297
        return newInvoices, nil
3✔
298
}
299

300
// LookupInvoice attempts to look up an invoice according to its 32 byte
301
// payment hash. If an invoice which can settle the HTLC identified by the
302
// passed payment hash isn't found, then an error is returned. Otherwise, the
303
// full invoice is returned. Before setting the incoming HTLC, the values
304
// SHOULD be checked to ensure the payer meets the agreed upon contractual
305
// terms of the payment.
306
func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
307
        invpkg.Invoice, error) {
3✔
308

3✔
309
        var invoice invpkg.Invoice
3✔
310
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
311
                invoices := tx.ReadBucket(invoiceBucket)
3✔
312
                if invoices == nil {
3✔
313
                        return invpkg.ErrNoInvoicesCreated
×
314
                }
×
315
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
3✔
316
                if invoiceIndex == nil {
6✔
317
                        return invpkg.ErrNoInvoicesCreated
3✔
318
                }
3✔
319
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
3✔
320
                setIDIndex := tx.ReadBucket(setIDIndexBucket)
3✔
321

3✔
322
                // Retrieve the invoice number for this invoice using
3✔
323
                // the provided invoice reference.
3✔
324
                invoiceNum, err := fetchInvoiceNumByRef(
3✔
325
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
3✔
326
                )
3✔
327
                if err != nil {
6✔
328
                        return err
3✔
329
                }
3✔
330

331
                var setID *invpkg.SetID
3✔
332
                switch {
3✔
333
                // If this is a payment address ref, and the blank modified was
334
                // specified, then we'll use the zero set ID to indicate that
335
                // we won't want any HTLCs returned.
336
                case ref.PayAddr() != nil &&
337
                        ref.Modifier() == invpkg.HtlcSetBlankModifier:
3✔
338

3✔
339
                        var zeroSetID invpkg.SetID
3✔
340
                        setID = &zeroSetID
3✔
341

342
                // If this is a set ID ref, and the htlc set only modified was
343
                // specified, then we'll pass through the specified setID so
344
                // only that will be returned.
345
                case ref.SetID() != nil &&
346
                        ref.Modifier() == invpkg.HtlcSetOnlyModifier:
3✔
347

3✔
348
                        setID = (*invpkg.SetID)(ref.SetID())
3✔
349
                }
350

351
                // An invoice was found, retrieve the remainder of the invoice
352
                // body.
353
                i, err := fetchInvoice(
3✔
354
                        invoiceNum, invoices, []*invpkg.SetID{setID}, true,
3✔
355
                )
3✔
356
                if err != nil {
3✔
357
                        return err
×
358
                }
×
359
                invoice = i
3✔
360

3✔
361
                return nil
3✔
362
        }, func() {})
3✔
363
        if err != nil {
6✔
364
                return invoice, err
3✔
365
        }
3✔
366

367
        return invoice, nil
3✔
368
}
369

370
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
371
// reference. The payment address will be treated as the primary key, falling
372
// back to the payment hash if nothing is found for the payment address. An
373
// error is returned if the invoice is not found.
374
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
375
        ref invpkg.InvoiceRef) ([]byte, error) {
3✔
376

3✔
377
        // If the set id is present, we only consult the set id index for this
3✔
378
        // invoice. This type of query is only used to facilitate user-facing
3✔
379
        // requests to lookup, settle or cancel an AMP invoice.
3✔
380
        setID := ref.SetID()
3✔
381
        if setID != nil {
6✔
382
                invoiceNumBySetID := setIDIndex.Get(setID[:])
3✔
383
                if invoiceNumBySetID == nil {
3✔
UNCOV
384
                        return nil, invpkg.ErrInvoiceNotFound
×
UNCOV
385
                }
×
386

387
                return invoiceNumBySetID, nil
3✔
388
        }
389

390
        payHash := ref.PayHash()
3✔
391
        payAddr := ref.PayAddr()
3✔
392

3✔
393
        getInvoiceNumByHash := func() []byte {
6✔
394
                if payHash != nil {
6✔
395
                        return invoiceIndex.Get(payHash[:])
3✔
396
                }
3✔
397
                return nil
3✔
398
        }
399

400
        getInvoiceNumByAddr := func() []byte {
6✔
401
                if payAddr != nil {
6✔
402
                        // Only allow lookups for payment address if it is not a
3✔
403
                        // blank payment address, which is a special-cased value
3✔
404
                        // for legacy keysend invoices.
3✔
405
                        if *payAddr != invpkg.BlankPayAddr {
6✔
406
                                return payAddrIndex.Get(payAddr[:])
3✔
407
                        }
3✔
408
                }
409
                return nil
3✔
410
        }
411

412
        invoiceNumByHash := getInvoiceNumByHash()
3✔
413
        invoiceNumByAddr := getInvoiceNumByAddr()
3✔
414
        switch {
3✔
415
        // If payment address and payment hash both reference an existing
416
        // invoice, ensure they reference the _same_ invoice.
417
        case invoiceNumByAddr != nil && invoiceNumByHash != nil:
3✔
418
                if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
3✔
UNCOV
419
                        return nil, invpkg.ErrInvRefEquivocation
×
UNCOV
420
                }
×
421

422
                return invoiceNumByAddr, nil
3✔
423

424
        // Return invoices by payment addr only.
425
        //
426
        // NOTE: We constrain this lookup to only apply if the invoice ref does
427
        // not contain a payment hash. Legacy and MPP payments depend on the
428
        // payment hash index to enforce that the HTLCs payment hash matches the
429
        // payment hash for the invoice, without this check we would
430
        // inadvertently assume the invoice contains the correct preimage for
431
        // the HTLC, which we only enforce via the lookup by the invoice index.
432
        case invoiceNumByAddr != nil && payHash == nil:
3✔
433
                return invoiceNumByAddr, nil
3✔
434

435
        // If we were only able to reference the invoice by hash, return the
436
        // corresponding invoice number. This can happen when no payment address
437
        // was provided, or if it didn't match anything in our records.
438
        case invoiceNumByHash != nil:
3✔
439
                return invoiceNumByHash, nil
3✔
440

441
        // Otherwise we don't know of the target invoice.
442
        default:
3✔
443
                return nil, invpkg.ErrInvoiceNotFound
3✔
444
        }
445
}
446

447
// FetchPendingInvoices returns all invoices that have not yet been settled or
448
// canceled. The returned map is keyed by the payment hash of each respective
449
// invoice.
450
func (d *DB) FetchPendingInvoices(_ context.Context) (
451
        map[lntypes.Hash]invpkg.Invoice, error) {
3✔
452

3✔
453
        result := make(map[lntypes.Hash]invpkg.Invoice)
3✔
454

3✔
455
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
456
                invoices := tx.ReadBucket(invoiceBucket)
3✔
457
                if invoices == nil {
3✔
458
                        return nil
×
459
                }
×
460

461
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
3✔
462
                if invoiceIndex == nil {
6✔
463
                        // Mask the error if there's no invoice
3✔
464
                        // index as that simply means there are no
3✔
465
                        // invoices added yet to the DB. In this case
3✔
466
                        // we simply return an empty list.
3✔
467
                        return nil
3✔
468
                }
3✔
469

470
                return invoiceIndex.ForEach(func(k, v []byte) error {
6✔
471
                        // Skip the special numInvoicesKey as that does not
3✔
472
                        // point to a valid invoice.
3✔
473
                        if bytes.Equal(k, numInvoicesKey) {
6✔
474
                                return nil
3✔
475
                        }
3✔
476

477
                        // Skip sub-buckets.
478
                        if v == nil {
3✔
479
                                return nil
×
480
                        }
×
481

482
                        invoice, err := fetchInvoice(v, invoices, nil, false)
3✔
483
                        if err != nil {
3✔
484
                                return err
×
485
                        }
×
486

487
                        if invoice.IsPending() {
6✔
488
                                var paymentHash lntypes.Hash
3✔
489
                                copy(paymentHash[:], k)
3✔
490
                                result[paymentHash] = invoice
3✔
491
                        }
3✔
492

493
                        return nil
3✔
494
                })
495
        }, func() {
3✔
496
                result = make(map[lntypes.Hash]invpkg.Invoice)
3✔
497
        })
3✔
498

499
        if err != nil {
3✔
500
                return nil, err
×
501
        }
×
502

503
        return result, nil
3✔
504
}
505

506
// QueryInvoices allows a caller to query the invoice database for invoices
507
// within the specified add index range.
508
func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
509
        invpkg.InvoiceSlice, error) {
3✔
510

3✔
511
        var resp invpkg.InvoiceSlice
3✔
512

3✔
513
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
514
                // If the bucket wasn't found, then there aren't any invoices
3✔
515
                // within the database yet, so we can simply exit.
3✔
516
                invoices := tx.ReadBucket(invoiceBucket)
3✔
517
                if invoices == nil {
3✔
518
                        return invpkg.ErrNoInvoicesCreated
×
519
                }
×
520

521
                // Get the add index bucket which we will use to iterate through
522
                // our indexed invoices.
523
                invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
524
                if invoiceAddIndex == nil {
6✔
525
                        return invpkg.ErrNoInvoicesCreated
3✔
526
                }
3✔
527

528
                // Create a paginator which reads from our add index bucket with
529
                // the parameters provided by the invoice query.
530
                paginator := newPaginator(
3✔
531
                        invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
3✔
532
                        q.NumMaxInvoices,
3✔
533
                )
3✔
534

3✔
535
                // accumulateInvoices looks up an invoice based on the index we
3✔
536
                // are given, adds it to our set of invoices if it has the right
3✔
537
                // characteristics for our query and returns the number of items
3✔
538
                // we have added to our set of invoices.
3✔
539
                accumulateInvoices := func(_, indexValue []byte) (bool, error) {
6✔
540
                        invoice, err := fetchInvoice(
3✔
541
                                indexValue, invoices, nil, false,
3✔
542
                        )
3✔
543
                        if err != nil {
3✔
544
                                return false, err
×
545
                        }
×
546

547
                        // Skip any settled or canceled invoices if the caller
548
                        // is only interested in pending ones.
549
                        if q.PendingOnly && !invoice.IsPending() {
3✔
UNCOV
550
                                return false, nil
×
UNCOV
551
                        }
×
552

553
                        // Get the creation time in Unix seconds, this always
554
                        // rounds down the nanoseconds to full seconds.
555
                        createTime := invoice.CreationDate.Unix()
3✔
556

3✔
557
                        // Skip any invoices that were created before the
3✔
558
                        // specified time.
3✔
559
                        if createTime < q.CreationDateStart {
6✔
560
                                return false, nil
3✔
561
                        }
3✔
562

563
                        // Skip any invoices that were created after the
564
                        // specified time.
565
                        if q.CreationDateEnd != 0 &&
3✔
566
                                createTime > q.CreationDateEnd {
6✔
567

3✔
568
                                return false, nil
3✔
569
                        }
3✔
570

571
                        // At this point, we've exhausted the offset, so we'll
572
                        // begin collecting invoices found within the range.
573
                        resp.Invoices = append(resp.Invoices, invoice)
3✔
574

3✔
575
                        return true, nil
3✔
576
                }
577

578
                // Query our paginator using accumulateInvoices to build up a
579
                // set of invoices.
580
                if err := paginator.query(accumulateInvoices); err != nil {
3✔
581
                        return err
×
582
                }
×
583

584
                // If we iterated through the add index in reverse order, then
585
                // we'll need to reverse the slice of invoices to return them in
586
                // forward order.
587
                if q.Reversed {
3✔
UNCOV
588
                        numInvoices := len(resp.Invoices)
×
UNCOV
589
                        for i := 0; i < numInvoices/2; i++ {
×
UNCOV
590
                                reverse := numInvoices - i - 1
×
UNCOV
591
                                resp.Invoices[i], resp.Invoices[reverse] =
×
UNCOV
592
                                        resp.Invoices[reverse], resp.Invoices[i]
×
UNCOV
593
                        }
×
594
                }
595

596
                return nil
3✔
597
        }, func() {
3✔
598
                resp = invpkg.InvoiceSlice{
3✔
599
                        InvoiceQuery: q,
3✔
600
                }
3✔
601
        })
3✔
602
        if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) {
3✔
603
                return resp, err
×
604
        }
×
605

606
        // Finally, record the indexes of the first and last invoices returned
607
        // so that the caller can resume from this point later on.
608
        if len(resp.Invoices) > 0 {
6✔
609
                resp.FirstIndexOffset = resp.Invoices[0].AddIndex
3✔
610
                lastIdx := len(resp.Invoices) - 1
3✔
611
                resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex
3✔
612
        }
3✔
613

614
        return resp, nil
3✔
615
}
616

617
// UpdateInvoice attempts to update an invoice corresponding to the passed
618
// payment hash. If an invoice matching the passed payment hash doesn't exist
619
// within the database, then the action will fail with a "not found" error.
620
//
621
// The update is performed inside the same database transaction that fetches the
622
// invoice and is therefore atomic. The fields to update are controlled by the
623
// supplied callback.  When updating an invoice, the update itself happens
624
// in-memory on a copy of the invoice. Once it is written successfully to the
625
// database, the in-memory copy is returned to the caller.
626
func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
627
        setIDHint *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
628
        *invpkg.Invoice, error) {
3✔
629

3✔
630
        var updatedInvoice *invpkg.Invoice
3✔
631
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
6✔
632
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
3✔
633
                if err != nil {
3✔
634
                        return err
×
635
                }
×
636
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
3✔
637
                        invoiceIndexBucket,
3✔
638
                )
3✔
639
                if err != nil {
3✔
640
                        return err
×
641
                }
×
642
                settleIndex, err := invoices.CreateBucketIfNotExists(
3✔
643
                        settleIndexBucket,
3✔
644
                )
3✔
645
                if err != nil {
3✔
646
                        return err
×
647
                }
×
648
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
3✔
649
                setIDIndex := tx.ReadWriteBucket(setIDIndexBucket)
3✔
650

3✔
651
                // Retrieve the invoice number for this invoice using the
3✔
652
                // provided invoice reference.
3✔
653
                invoiceNum, err := fetchInvoiceNumByRef(
3✔
654
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
3✔
655
                )
3✔
656
                if err != nil {
3✔
UNCOV
657
                        return err
×
UNCOV
658
                }
×
659

660
                // setIDHint can also be nil here, which means all the HTLCs
661
                // for AMP invoices are fetched. If the blank setID is passed
662
                // in, then no HTLCs are fetched for the AMP invoice. If a
663
                // specific setID is passed in, then only the HTLCs for that
664
                // setID are fetched for a particular sub-AMP invoice.
665
                invoice, err := fetchInvoice(
3✔
666
                        invoiceNum, invoices, []*invpkg.SetID{setIDHint}, false,
3✔
667
                )
3✔
668
                if err != nil {
3✔
669
                        return err
×
670
                }
×
671

672
                now := d.clock.Now()
3✔
673
                updater := &kvInvoiceUpdater{
3✔
674
                        db:                d,
3✔
675
                        invoicesBucket:    invoices,
3✔
676
                        settleIndexBucket: settleIndex,
3✔
677
                        setIDIndexBucket:  setIDIndex,
3✔
678
                        updateTime:        now,
3✔
679
                        invoiceNum:        invoiceNum,
3✔
680
                        invoice:           &invoice,
3✔
681
                        updatedAmpHtlcs:   make(ampHTLCsMap),
3✔
682
                        settledSetIDs:     make(map[invpkg.SetID]struct{}),
3✔
683
                }
3✔
684

3✔
685
                payHash := ref.PayHash()
3✔
686
                updatedInvoice, err = invpkg.UpdateInvoice(
3✔
687
                        payHash, updater.invoice, now, callback, updater,
3✔
688
                )
3✔
689
                if err != nil {
6✔
690
                        return err
3✔
691
                }
3✔
692

693
                // If this is an AMP update, then limit the returned AMP state
694
                // to only the requested set ID.
695
                if setIDHint != nil {
6✔
696
                        filterInvoiceAMPState(updatedInvoice, setIDHint)
3✔
697
                }
3✔
698

699
                return nil
3✔
700
        }, func() {
3✔
701
                updatedInvoice = nil
3✔
702
        })
3✔
703

704
        return updatedInvoice, err
3✔
705
}
706

707
// filterInvoiceAMPState filters the AMP state of the invoice to only include
708
// state for the specified set IDs.
709
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
3✔
710
        filteredAMPState := make(invpkg.AMPInvoiceState)
3✔
711

3✔
712
        for _, setID := range setIDs {
6✔
713
                if setID == nil {
6✔
714
                        return
3✔
715
                }
3✔
716

717
                ampState, ok := invoice.AMPState[*setID]
3✔
718
                if ok {
6✔
719
                        filteredAMPState[*setID] = ampState
3✔
720
                }
3✔
721
        }
722

723
        invoice.AMPState = filteredAMPState
3✔
724
}
725

726
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
727
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
728

729
// kvInvoiceUpdater is an implementation of the InvoiceUpdater interface that
730
// is used with the kv implementation of the invoice database. Note that this
731
// updater is not concurrency safe and synchronizaton is expected to be handled
732
// on the DB level.
733
type kvInvoiceUpdater struct {
734
        db                *DB
735
        invoicesBucket    kvdb.RwBucket
736
        settleIndexBucket kvdb.RwBucket
737
        setIDIndexBucket  kvdb.RwBucket
738

739
        // updateTime is the timestamp for the update.
740
        updateTime time.Time
741

742
        // invoiceNum is a legacy key similar to the add index that is used
743
        // only in the kv implementation.
744
        invoiceNum []byte
745

746
        // invoice is the invoice that we're updating. As a side effect of the
747
        // update this invoice will be mutated.
748
        invoice *invpkg.Invoice
749

750
        // updatedAmpHtlcs holds the set of AMP HTLCs that were added or
751
        // cancelled as part of this update.
752
        updatedAmpHtlcs ampHTLCsMap
753

754
        // settledSetIDs holds the set IDs that are settled with this update.
755
        settledSetIDs map[invpkg.SetID]struct{}
756
}
757

758
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
759
func (k *kvInvoiceUpdater) AddHtlc(_ models.CircuitKey,
760
        _ *invpkg.InvoiceHTLC) error {
3✔
761

3✔
762
        return nil
3✔
763
}
3✔
764

765
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
766
func (k *kvInvoiceUpdater) ResolveHtlc(_ models.CircuitKey, _ invpkg.HtlcState,
767
        _ time.Time) error {
3✔
768

3✔
769
        return nil
3✔
770
}
3✔
771

772
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
773
func (k *kvInvoiceUpdater) AddAmpHtlcPreimage(_ [32]byte, _ models.CircuitKey,
774
        _ lntypes.Preimage) error {
3✔
775

3✔
776
        return nil
3✔
777
}
3✔
778

779
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
780
func (k *kvInvoiceUpdater) UpdateInvoiceState(_ invpkg.ContractState,
781
        _ *lntypes.Preimage) error {
3✔
782

3✔
783
        return nil
3✔
784
}
3✔
785

786
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
787
func (k *kvInvoiceUpdater) UpdateInvoiceAmtPaid(_ lnwire.MilliSatoshi) error {
3✔
788
        return nil
3✔
789
}
3✔
790

791
// UpdateAmpState updates the state of the AMP invoice identified by the setID.
792
func (k *kvInvoiceUpdater) UpdateAmpState(setID [32]byte,
793
        state invpkg.InvoiceStateAMP, circuitKey models.CircuitKey) error {
3✔
794

3✔
795
        if _, ok := k.updatedAmpHtlcs[setID]; !ok {
6✔
796
                switch state.State {
3✔
797
                case invpkg.HtlcStateAccepted:
3✔
798
                        // If we're just now creating the HTLCs for this set
3✔
799
                        // then we'll also pull in the existing HTLCs that are
3✔
800
                        // part of this set, so we can write them all to disk
3✔
801
                        // together (same value)
3✔
802
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
3✔
803
                                &setID, invpkg.HtlcStateAccepted,
3✔
804
                        )
3✔
805

UNCOV
806
                case invpkg.HtlcStateCanceled:
×
UNCOV
807
                        // Only HTLCs in the accepted state, can be cancelled,
×
UNCOV
808
                        // but we also want to merge that with HTLCs that may be
×
UNCOV
809
                        // canceled as well since it can be cancelled one by
×
UNCOV
810
                        // one.
×
UNCOV
811
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
×
UNCOV
812
                                &setID, invpkg.HtlcStateAccepted,
×
UNCOV
813
                        )
×
UNCOV
814

×
UNCOV
815
                        cancelledHtlcs := k.invoice.HTLCSet(
×
UNCOV
816
                                &setID, invpkg.HtlcStateCanceled,
×
UNCOV
817
                        )
×
UNCOV
818
                        for htlcKey, htlc := range cancelledHtlcs {
×
UNCOV
819
                                k.updatedAmpHtlcs[setID][htlcKey] = htlc
×
UNCOV
820
                        }
×
821

UNCOV
822
                case invpkg.HtlcStateSettled:
×
UNCOV
823
                        k.updatedAmpHtlcs[setID] = make(
×
UNCOV
824
                                map[models.CircuitKey]*invpkg.InvoiceHTLC,
×
UNCOV
825
                        )
×
826
                }
827
        }
828

829
        if state.State == invpkg.HtlcStateSettled {
6✔
830
                // Add the set ID to the set that was settled in this invoice
3✔
831
                // update. We'll use this later to update the settle index.
3✔
832
                k.settledSetIDs[setID] = struct{}{}
3✔
833
        }
3✔
834

835
        k.updatedAmpHtlcs[setID][circuitKey] = k.invoice.Htlcs[circuitKey]
3✔
836

3✔
837
        return nil
3✔
838
}
839

840
// Finalize finalizes the update before it is written to the database.
841
func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error {
3✔
842
        switch updateType {
3✔
843
        case invpkg.AddHTLCsUpdate:
3✔
844
                return k.storeAddHtlcsUpdate()
3✔
845

846
        case invpkg.CancelHTLCsUpdate:
3✔
847
                return k.storeCancelHtlcsUpdate()
3✔
848

849
        case invpkg.SettleHodlInvoiceUpdate:
3✔
850
                return k.storeSettleHodlInvoiceUpdate()
3✔
851

852
        case invpkg.CancelInvoiceUpdate:
3✔
853
                // Persist all changes which where made when cancelling the
3✔
854
                // invoice. All HTLCs which were accepted are now canceled, so
3✔
855
                // we persist this state.
3✔
856
                return k.storeCancelHtlcsUpdate()
3✔
857
        }
858

859
        return fmt.Errorf("unknown update type: %v", updateType)
×
860
}
861

862
// storeCancelHtlcsUpdate updates the invoice in the database after cancelling a
863
// set of HTLCs.
864
func (k *kvInvoiceUpdater) storeCancelHtlcsUpdate() error {
3✔
865
        err := k.serializeAndStoreInvoice()
3✔
866
        if err != nil {
3✔
867
                return err
×
868
        }
×
869

870
        // If this is an AMP invoice, then we'll actually store the rest
871
        // of the HTLCs in-line with the invoice, using the invoice ID
872
        // as a prefix, and the AMP key as a suffix: invoiceNum ||
873
        // setID.
874
        if k.invoice.IsAMP() {
3✔
UNCOV
875
                return k.updateAMPInvoices()
×
UNCOV
876
        }
×
877

878
        return nil
3✔
879
}
880

881
// storeAddHtlcsUpdate updates the invoice in the database after adding a set of
882
// HTLCs.
883
func (k *kvInvoiceUpdater) storeAddHtlcsUpdate() error {
3✔
884
        invoiceIsAMP := k.invoice.IsAMP()
3✔
885

3✔
886
        for htlcSetID := range k.updatedAmpHtlcs {
6✔
887
                // Check if this SetID already exist.
3✔
888
                setIDInvNum := k.setIDIndexBucket.Get(htlcSetID[:])
3✔
889

3✔
890
                if setIDInvNum == nil {
6✔
891
                        err := k.setIDIndexBucket.Put(
3✔
892
                                htlcSetID[:], k.invoiceNum,
3✔
893
                        )
3✔
894
                        if err != nil {
3✔
895
                                return err
×
896
                        }
×
897
                } else if !bytes.Equal(setIDInvNum, k.invoiceNum) {
3✔
UNCOV
898
                        return invpkg.ErrDuplicateSetID{
×
UNCOV
899
                                SetID: htlcSetID,
×
UNCOV
900
                        }
×
UNCOV
901
                }
×
902
        }
903

904
        // If this is a non-AMP invoice, then the state can eventually go to
905
        // ContractSettled, so we pass in nil value as part of
906
        // setSettleMetaFields.
907
        if !invoiceIsAMP && k.invoice.State == invpkg.ContractSettled {
6✔
908
                err := k.setSettleMetaFields(nil)
3✔
909
                if err != nil {
3✔
910
                        return err
×
911
                }
×
912
        }
913

914
        // As we don't update the settle index above for AMP invoices, we'll do
915
        // it here for each sub-AMP invoice that was settled.
916
        for settledSetID := range k.settledSetIDs {
6✔
917
                settledSetID := settledSetID
3✔
918
                err := k.setSettleMetaFields(&settledSetID)
3✔
919
                if err != nil {
3✔
920
                        return err
×
921
                }
×
922
        }
923

924
        err := k.serializeAndStoreInvoice()
3✔
925
        if err != nil {
3✔
926
                return err
×
927
        }
×
928

929
        // If this is an AMP invoice, then we'll actually store the rest of the
930
        // HTLCs in-line with the invoice, using the invoice ID as a prefix,
931
        // and the AMP key as a suffix: invoiceNum || setID.
932
        if invoiceIsAMP {
6✔
933
                return k.updateAMPInvoices()
3✔
934
        }
3✔
935

936
        return nil
3✔
937
}
938

939
// storeSettleHodlInvoiceUpdate updates the invoice in the database after
940
// settling a hodl invoice.
941
func (k *kvInvoiceUpdater) storeSettleHodlInvoiceUpdate() error {
3✔
942
        err := k.setSettleMetaFields(nil)
3✔
943
        if err != nil {
3✔
944
                return err
×
945
        }
×
946

947
        return k.serializeAndStoreInvoice()
3✔
948
}
949

950
// setSettleMetaFields updates the metadata associated with settlement of an
951
// invoice. If a non-nil setID is passed in, then the value will be append to
952
// the invoice number as well, in order to allow us to detect repeated payments
953
// to the same AMP invoices "across time".
954
func (k *kvInvoiceUpdater) setSettleMetaFields(setID *invpkg.SetID) error {
3✔
955
        // Now that we know the invoice hasn't already been settled, we'll
3✔
956
        // update the settle index so we can place this settle event in the
3✔
957
        // proper location within our time series.
3✔
958
        nextSettleSeqNo, err := k.settleIndexBucket.NextSequence()
3✔
959
        if err != nil {
3✔
960
                return err
×
961
        }
×
962

963
        // Make a new byte array on the stack that can potentially store the 4
964
        // byte invoice number along w/ the 32 byte set ID. We capture valueLen
965
        // here which is the number of bytes copied so we can only store the 4
966
        // bytes if this is a non-AMP invoice.
967
        var indexKey [invoiceSetIDKeyLen]byte
3✔
968
        valueLen := copy(indexKey[:], k.invoiceNum)
3✔
969

3✔
970
        if setID != nil {
6✔
971
                valueLen += copy(indexKey[valueLen:], setID[:])
3✔
972
        }
3✔
973

974
        var seqNoBytes [8]byte
3✔
975
        byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
3✔
976
        err = k.settleIndexBucket.Put(seqNoBytes[:], indexKey[:valueLen])
3✔
977
        if err != nil {
3✔
978
                return err
×
979
        }
×
980

981
        // If the setID is nil, then this means that this is a non-AMP settle,
982
        // so we'll update the invoice settle index directly.
983
        if setID == nil {
6✔
984
                k.invoice.SettleDate = k.updateTime
3✔
985
                k.invoice.SettleIndex = nextSettleSeqNo
3✔
986
        } else {
6✔
987
                // If the set ID isn't blank, we'll update the AMP state map
3✔
988
                // which tracks when each of the setIDs associated with a given
3✔
989
                // AMP invoice are settled.
3✔
990
                ampState := k.invoice.AMPState[*setID]
3✔
991

3✔
992
                ampState.SettleDate = k.updateTime
3✔
993
                ampState.SettleIndex = nextSettleSeqNo
3✔
994

3✔
995
                k.invoice.AMPState[*setID] = ampState
3✔
996
        }
3✔
997

998
        return nil
3✔
999
}
1000

1001
// updateAMPInvoices updates the set of AMP invoices in-place. For AMP, rather
1002
// then continually write the invoices to the end of the invoice value, we
1003
// instead write the invoices into a new key preifx that follows the main
1004
// invoice number. This ensures that we don't need to continually decode a
1005
// potentially massive HTLC set, and also allows us to quickly find the HLTCs
1006
// associated with a particular HTLC set.
1007
func (k *kvInvoiceUpdater) updateAMPInvoices() error {
3✔
1008
        for setID, htlcSet := range k.updatedAmpHtlcs {
6✔
1009
                // First write out the set of HTLCs including all the relevant
3✔
1010
                // TLV values.
3✔
1011
                var b bytes.Buffer
3✔
1012
                if err := serializeHtlcs(&b, htlcSet); err != nil {
3✔
1013
                        return err
×
1014
                }
×
1015

1016
                // Next store each HTLC in-line, using a prefix based off the
1017
                // invoice number.
1018
                invoiceSetIDKey := makeInvoiceSetIDKey(k.invoiceNum, setID[:])
3✔
1019

3✔
1020
                err := k.invoicesBucket.Put(invoiceSetIDKey[:], b.Bytes())
3✔
1021
                if err != nil {
3✔
1022
                        return err
×
1023
                }
×
1024
        }
1025

1026
        return nil
3✔
1027
}
1028

1029
// serializeAndStoreInvoice is a helper function used to store invoices.
1030
func (k *kvInvoiceUpdater) serializeAndStoreInvoice() error {
3✔
1031
        var buf bytes.Buffer
3✔
1032
        if err := serializeInvoice(&buf, k.invoice); err != nil {
3✔
1033
                return err
×
1034
        }
×
1035

1036
        return k.invoicesBucket.Put(k.invoiceNum, buf.Bytes())
3✔
1037
}
1038

1039
// InvoicesSettledSince can be used by callers to catch up any settled invoices
1040
// they missed within the settled invoice time series. We'll return all known
1041
// settled invoice that have a settle index higher than the passed
1042
// sinceSettleIndex.
1043
//
1044
// NOTE: The index starts from 1, as a result. We enforce that specifying a
1045
// value below the starting index value is a noop.
1046
func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
1047
        []invpkg.Invoice, error) {
3✔
1048

3✔
1049
        var settledInvoices []invpkg.Invoice
3✔
1050

3✔
1051
        // If an index of zero was specified, then in order to maintain
3✔
1052
        // backwards compat, we won't send out any new invoices.
3✔
1053
        if sinceSettleIndex == 0 {
6✔
1054
                return settledInvoices, nil
3✔
1055
        }
3✔
1056

1057
        var startIndex [8]byte
3✔
1058
        byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
3✔
1059

3✔
1060
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
1061
                invoices := tx.ReadBucket(invoiceBucket)
3✔
1062
                if invoices == nil {
3✔
1063
                        return nil
×
1064
                }
×
1065

1066
                settleIndex := invoices.NestedReadBucket(settleIndexBucket)
3✔
1067
                if settleIndex == nil {
3✔
1068
                        return nil
×
1069
                }
×
1070

1071
                // We'll now run through each entry in the add index starting
1072
                // at our starting index. We'll continue until we reach the
1073
                // very end of the current key space.
1074
                invoiceCursor := settleIndex.ReadCursor()
3✔
1075

3✔
1076
                // We'll seek to the starting index, then manually advance the
3✔
1077
                // cursor in order to skip the entry with the since add index.
3✔
1078
                invoiceCursor.Seek(startIndex[:])
3✔
1079
                seqNo, indexValue := invoiceCursor.Next()
3✔
1080

3✔
1081
                for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, indexValue = invoiceCursor.Next() {
6✔
1082
                        // Depending on the length of the index value, this may
3✔
1083
                        // or may not be an AMP invoice, so we'll extract the
3✔
1084
                        // invoice value into two components: the invoice num,
3✔
1085
                        // and the setID (may not be there).
3✔
1086
                        var (
3✔
1087
                                invoiceKey [4]byte
3✔
1088
                                setID      *invpkg.SetID
3✔
1089
                        )
3✔
1090

3✔
1091
                        valueLen := copy(invoiceKey[:], indexValue)
3✔
1092
                        if len(indexValue) == invoiceSetIDKeyLen {
6✔
1093
                                setID = new(invpkg.SetID)
3✔
1094
                                copy(setID[:], indexValue[valueLen:])
3✔
1095
                        }
3✔
1096

1097
                        // For each key found, we'll look up the actual
1098
                        // invoice, then accumulate it into our return value.
1099
                        invoice, err := fetchInvoice(
3✔
1100
                                invoiceKey[:], invoices, []*invpkg.SetID{setID},
3✔
1101
                                true,
3✔
1102
                        )
3✔
1103
                        if err != nil {
3✔
1104
                                return err
×
1105
                        }
×
1106

1107
                        settledInvoices = append(settledInvoices, invoice)
3✔
1108
                }
1109

1110
                return nil
3✔
1111
        }, func() {
3✔
1112
                settledInvoices = nil
3✔
1113
        })
3✔
1114
        if err != nil {
3✔
1115
                return nil, err
×
1116
        }
×
1117

1118
        return settledInvoices, nil
3✔
1119
}
1120

1121
func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
1122
        i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
1123
        uint64, error) {
3✔
1124

3✔
1125
        // Create the invoice key which is just the big-endian representation
3✔
1126
        // of the invoice number.
3✔
1127
        var invoiceKey [4]byte
3✔
1128
        byteOrder.PutUint32(invoiceKey[:], invoiceNum)
3✔
1129

3✔
1130
        // Increment the num invoice counter index so the next invoice bares
3✔
1131
        // the proper ID.
3✔
1132
        var scratch [4]byte
3✔
1133
        invoiceCounter := invoiceNum + 1
3✔
1134
        byteOrder.PutUint32(scratch[:], invoiceCounter)
3✔
1135
        if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
3✔
1136
                return 0, err
×
1137
        }
×
1138

1139
        // Add the payment hash to the invoice index. This will let us quickly
1140
        // identify if we can settle an incoming payment, and also to possibly
1141
        // allow a single invoice to have multiple payment installations.
1142
        err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
3✔
1143
        if err != nil {
3✔
1144
                return 0, err
×
1145
        }
×
1146

1147
        // Add the invoice to the payment address index, but only if the invoice
1148
        // has a non-zero payment address. The all-zero payment address is still
1149
        // in use by legacy keysend, so we special-case here to avoid
1150
        // collisions.
1151
        if i.Terms.PaymentAddr != invpkg.BlankPayAddr {
6✔
1152
                err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
3✔
1153
                if err != nil {
3✔
1154
                        return 0, err
×
1155
                }
×
1156
        }
1157

1158
        // Next, we'll obtain the next add invoice index (sequence
1159
        // number), so we can properly place this invoice within this
1160
        // event stream.
1161
        nextAddSeqNo, err := addIndex.NextSequence()
3✔
1162
        if err != nil {
3✔
1163
                return 0, err
×
1164
        }
×
1165

1166
        // With the next sequence obtained, we'll updating the event series in
1167
        // the add index bucket to map this current add counter to the index of
1168
        // this new invoice.
1169
        var seqNoBytes [8]byte
3✔
1170
        byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
3✔
1171
        if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
3✔
1172
                return 0, err
×
1173
        }
×
1174

1175
        i.AddIndex = nextAddSeqNo
3✔
1176

3✔
1177
        // Finally, serialize the invoice itself to be written to the disk.
3✔
1178
        var buf bytes.Buffer
3✔
1179
        if err := serializeInvoice(&buf, i); err != nil {
3✔
1180
                return 0, err
×
1181
        }
×
1182

1183
        if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
3✔
1184
                return 0, err
×
1185
        }
×
1186

1187
        return nextAddSeqNo, nil
3✔
1188
}
1189

1190
// recordSize returns the amount of bytes this TLV record will occupy when
1191
// encoded.
1192
func ampRecordSize(a *invpkg.AMPInvoiceState) func() uint64 {
3✔
1193
        var (
3✔
1194
                b   bytes.Buffer
3✔
1195
                buf [8]byte
3✔
1196
        )
3✔
1197

3✔
1198
        // We know that encoding works since the tests pass in the build this
3✔
1199
        // file is checked into, so we'll simplify things and simply encode it
3✔
1200
        // ourselves then report the total amount of bytes used.
3✔
1201
        if err := ampStateEncoder(&b, a, &buf); err != nil {
3✔
1202
                // This should never error out, but we log it just in case it
×
1203
                // does.
×
1204
                log.Errorf("encoding the amp invoice state failed: %v", err)
×
1205
        }
×
1206

1207
        return func() uint64 {
6✔
1208
                return uint64(len(b.Bytes()))
3✔
1209
        }
3✔
1210
}
1211

1212
// serializeInvoice serializes an invoice to a writer.
1213
//
1214
// Note: this function is in use for a migration. Before making changes that
1215
// would modify the on disk format, make a copy of the original code and store
1216
// it with the migration.
1217
func serializeInvoice(w io.Writer, i *invpkg.Invoice) error {
3✔
1218
        creationDateBytes, err := i.CreationDate.MarshalBinary()
3✔
1219
        if err != nil {
3✔
1220
                return err
×
1221
        }
×
1222

1223
        settleDateBytes, err := i.SettleDate.MarshalBinary()
3✔
1224
        if err != nil {
3✔
1225
                return err
×
1226
        }
×
1227

1228
        var fb bytes.Buffer
3✔
1229
        err = i.Terms.Features.EncodeBase256(&fb)
3✔
1230
        if err != nil {
3✔
1231
                return err
×
1232
        }
×
1233
        featureBytes := fb.Bytes()
3✔
1234

3✔
1235
        preimage := [32]byte(invpkg.UnknownPreimage)
3✔
1236
        if i.Terms.PaymentPreimage != nil {
6✔
1237
                preimage = *i.Terms.PaymentPreimage
3✔
1238
                if preimage == invpkg.UnknownPreimage {
3✔
1239
                        return errors.New("cannot use all-zeroes preimage")
×
1240
                }
×
1241
        }
1242
        value := uint64(i.Terms.Value)
3✔
1243
        cltvDelta := uint32(i.Terms.FinalCltvDelta)
3✔
1244
        expiry := uint64(i.Terms.Expiry)
3✔
1245

3✔
1246
        amtPaid := uint64(i.AmtPaid)
3✔
1247
        state := uint8(i.State)
3✔
1248

3✔
1249
        var hodlInvoice uint8
3✔
1250
        if i.HodlInvoice {
6✔
1251
                hodlInvoice = 1
3✔
1252
        }
3✔
1253

1254
        tlvStream, err := tlv.NewStream(
3✔
1255
                // Memo and payreq.
3✔
1256
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
3✔
1257
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
3✔
1258

3✔
1259
                // Add/settle metadata.
3✔
1260
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
3✔
1261
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
3✔
1262
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
3✔
1263
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
3✔
1264

3✔
1265
                // Terms.
3✔
1266
                tlv.MakePrimitiveRecord(preimageType, &preimage),
3✔
1267
                tlv.MakePrimitiveRecord(valueType, &value),
3✔
1268
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
3✔
1269
                tlv.MakePrimitiveRecord(expiryType, &expiry),
3✔
1270
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
3✔
1271
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
3✔
1272

3✔
1273
                // Invoice state.
3✔
1274
                tlv.MakePrimitiveRecord(invStateType, &state),
3✔
1275
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
3✔
1276

3✔
1277
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
3✔
1278

3✔
1279
                // Invoice AMP state.
3✔
1280
                tlv.MakeDynamicRecord(
3✔
1281
                        invoiceAmpStateType, &i.AMPState,
3✔
1282
                        ampRecordSize(&i.AMPState),
3✔
1283
                        ampStateEncoder, ampStateDecoder,
3✔
1284
                ),
3✔
1285
        )
3✔
1286
        if err != nil {
3✔
1287
                return err
×
1288
        }
×
1289

1290
        var b bytes.Buffer
3✔
1291
        if err = tlvStream.Encode(&b); err != nil {
3✔
1292
                return err
×
1293
        }
×
1294

1295
        err = binary.Write(w, byteOrder, uint64(b.Len()))
3✔
1296
        if err != nil {
3✔
1297
                return err
×
1298
        }
×
1299

1300
        if _, err = w.Write(b.Bytes()); err != nil {
3✔
1301
                return err
×
1302
        }
×
1303

1304
        // Only if this is a _non_ AMP invoice do we serialize the HTLCs
1305
        // in-line with the rest of the invoice.
1306
        if i.IsAMP() {
6✔
1307
                return nil
3✔
1308
        }
3✔
1309

1310
        return serializeHtlcs(w, i.Htlcs)
3✔
1311
}
1312

1313
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
1314
// a writer.
1315
func serializeHtlcs(w io.Writer,
1316
        htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error {
3✔
1317

3✔
1318
        for key, htlc := range htlcs {
6✔
1319
                // Encode the htlc in a tlv stream.
3✔
1320
                chanID := key.ChanID.ToUint64()
3✔
1321
                amt := uint64(htlc.Amt)
3✔
1322
                mppTotalAmt := uint64(htlc.MppTotalAmt)
3✔
1323
                acceptTime := putNanoTime(htlc.AcceptTime)
3✔
1324
                resolveTime := putNanoTime(htlc.ResolveTime)
3✔
1325
                state := uint8(htlc.State)
3✔
1326

3✔
1327
                var records []tlv.Record
3✔
1328
                records = append(records,
3✔
1329
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
3✔
1330
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
3✔
1331
                        tlv.MakePrimitiveRecord(amtType, &amt),
3✔
1332
                        tlv.MakePrimitiveRecord(
3✔
1333
                                acceptHeightType, &htlc.AcceptHeight,
3✔
1334
                        ),
3✔
1335
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
3✔
1336
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
3✔
1337
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
3✔
1338
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
3✔
1339
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
3✔
1340
                )
3✔
1341

3✔
1342
                if htlc.AMP != nil {
6✔
1343
                        setIDRecord := tlv.MakeDynamicRecord(
3✔
1344
                                htlcAMPType, &htlc.AMP.Record,
3✔
1345
                                htlc.AMP.Record.PayloadSize,
3✔
1346
                                record.AMPEncoder, record.AMPDecoder,
3✔
1347
                        )
3✔
1348
                        records = append(records, setIDRecord)
3✔
1349

3✔
1350
                        hash32 := [32]byte(htlc.AMP.Hash)
3✔
1351
                        hashRecord := tlv.MakePrimitiveRecord(
3✔
1352
                                htlcHashType, &hash32,
3✔
1353
                        )
3✔
1354
                        records = append(records, hashRecord)
3✔
1355

3✔
1356
                        if htlc.AMP.Preimage != nil {
6✔
1357
                                preimage32 := [32]byte(*htlc.AMP.Preimage)
3✔
1358
                                preimageRecord := tlv.MakePrimitiveRecord(
3✔
1359
                                        htlcPreimageType, &preimage32,
3✔
1360
                                )
3✔
1361
                                records = append(records, preimageRecord)
3✔
1362
                        }
3✔
1363
                }
1364

1365
                // Convert the custom records to tlv.Record types that are ready
1366
                // for serialization.
1367
                customRecords := tlv.MapToRecords(htlc.CustomRecords)
3✔
1368

3✔
1369
                // Append the custom records. Their ids are in the experimental
3✔
1370
                // range and sorted, so there is no need to sort again.
3✔
1371
                records = append(records, customRecords...)
3✔
1372

3✔
1373
                tlvStream, err := tlv.NewStream(records...)
3✔
1374
                if err != nil {
3✔
1375
                        return err
×
1376
                }
×
1377

1378
                var b bytes.Buffer
3✔
1379
                if err := tlvStream.Encode(&b); err != nil {
3✔
1380
                        return err
×
1381
                }
×
1382

1383
                // Write the length of the tlv stream followed by the stream
1384
                // bytes.
1385
                err = binary.Write(w, byteOrder, uint64(b.Len()))
3✔
1386
                if err != nil {
3✔
1387
                        return err
×
1388
                }
×
1389

1390
                if _, err := w.Write(b.Bytes()); err != nil {
3✔
1391
                        return err
×
1392
                }
×
1393
        }
1394

1395
        return nil
3✔
1396
}
1397

1398
// putNanoTime converts t to nanoseconds since the Unix epoch. The zero
// time.Time is encoded as 0 instead, because UnixNano's result is undefined
// for instants that far outside the int64 nanosecond range.
func putNanoTime(t time.Time) uint64 {
	var ns uint64
	if !t.IsZero() {
		ns = uint64(t.UnixNano())
	}

	return ns
}
1407

1408
// getNanoTime is the inverse of putNanoTime: it turns a count of nanoseconds
// since the Unix epoch back into a time.Time. A value of 0 yields the
// zero-value time.Time rather than the Unix epoch.
func getNanoTime(ns uint64) time.Time {
	var ts time.Time
	if ns != 0 {
		ts = time.Unix(0, int64(ns))
	}

	return ts
}
1416

1417
// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices
1418
// identified by the setID value.
1419
func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1420
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1421
        error) {
3✔
1422

3✔
1423
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
1424
        for _, setID := range setIDs {
6✔
1425
                invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:])
3✔
1426

3✔
1427
                htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:])
3✔
1428
                if htlcSetBytes == nil {
6✔
1429
                        // A set ID was passed in, but we don't have this
3✔
1430
                        // stored yet, meaning that the setID is being added
3✔
1431
                        // for the first time.
3✔
1432
                        return htlcs, invpkg.ErrInvoiceNotFound
3✔
1433
                }
3✔
1434

1435
                htlcSetReader := bytes.NewReader(htlcSetBytes)
3✔
1436
                htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
3✔
1437
                if err != nil {
3✔
1438
                        return nil, err
×
1439
                }
×
1440

1441
                for key, htlc := range htlcsBySetID {
6✔
1442
                        htlcs[key] = htlc
3✔
1443
                }
3✔
1444
        }
1445

1446
        return htlcs, nil
3✔
1447
}
1448

1449
// forEachAMPInvoice is a helper function that attempts to iterate over each of
1450
// the HTLC sets (based on their set ID) for the given AMP invoice identified
1451
// by its invoiceNum. The callback closure is called for each key within the
1452
// prefix range.
1453
func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1454
        callback func(key, htlcSet []byte) error) error {
3✔
1455

3✔
1456
        invoiceCursor := invoiceBucket.ReadCursor()
3✔
1457

3✔
1458
        // Seek to the first key that includes the invoice data itself.
3✔
1459
        invoiceCursor.Seek(invoiceNum)
3✔
1460

3✔
1461
        // Advance to the very first key _after_ the invoice data, as this is
3✔
1462
        // where we'll encounter our first HTLC (if any are present).
3✔
1463
        cursorKey, htlcSet := invoiceCursor.Next()
3✔
1464

3✔
1465
        // If at this point, the cursor key doesn't match the invoice num
3✔
1466
        // prefix, then we know that this HTLC doesn't have any set ID HTLCs
3✔
1467
        // associated with it.
3✔
1468
        if !bytes.HasPrefix(cursorKey, invoiceNum) {
6✔
1469
                return nil
3✔
1470
        }
3✔
1471

1472
        // Otherwise continue to iterate until we no longer match the prefix,
1473
        // executing the call back at each step.
1474
        for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() {
6✔
1475
                err := callback(cursorKey, htlcSet)
3✔
1476
                if err != nil {
3✔
1477
                        return err
×
1478
                }
×
1479
        }
1480

1481
        return nil
3✔
1482
}
1483

1484
// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix  within the
1485
// AMP bucket to find all the individual HTLCs (by setID) associated with a
1486
// given invoice. If a list of set IDs are specified, then only HTLCs
1487
// associated with that setID will be retrieved.
1488
func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1489
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1490
        error) {
3✔
1491

3✔
1492
        // If a set of setIDs was specified, then we can skip the cursor and
3✔
1493
        // just read out exactly what we need.
3✔
1494
        if len(setIDs) != 0 && setIDs[0] != nil {
6✔
1495
                return fetchFilteredAmpInvoices(
3✔
1496
                        invoiceBucket, invoiceNum, setIDs...,
3✔
1497
                )
3✔
1498
        }
3✔
1499

1500
        // Otherwise, iterate over all the htlc sets that are prefixed beside
1501
        // this invoice in the main invoice bucket.
1502
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
1503
        err := forEachAMPInvoice(invoiceBucket, invoiceNum,
3✔
1504
                func(key, htlcSet []byte) error {
6✔
1505
                        htlcSetReader := bytes.NewReader(htlcSet)
3✔
1506
                        htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
3✔
1507
                        if err != nil {
3✔
1508
                                return err
×
1509
                        }
×
1510

1511
                        for key, htlc := range htlcsBySetID {
6✔
1512
                                htlcs[key] = htlc
3✔
1513
                        }
3✔
1514

1515
                        return nil
3✔
1516
                },
1517
        )
1518

1519
        if err != nil {
3✔
1520
                return nil, err
×
1521
        }
×
1522

1523
        return htlcs, nil
3✔
1524
}
1525

1526
// fetchInvoice attempts to read out the relevant state for the invoice as
1527
// specified by the invoice number. If the setID fields are set, then only the
1528
// HTLC information pertaining to those set IDs is returned.
1529
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1530
        setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
3✔
1531

3✔
1532
        invoiceBytes := invoices.Get(invoiceNum)
3✔
1533
        if invoiceBytes == nil {
3✔
1534
                return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound
×
1535
        }
×
1536

1537
        invoiceReader := bytes.NewReader(invoiceBytes)
3✔
1538

3✔
1539
        invoice, err := deserializeInvoice(invoiceReader)
3✔
1540
        if err != nil {
3✔
1541
                return invpkg.Invoice{}, err
×
1542
        }
×
1543

1544
        // If this is an AMP invoice we'll also attempt to read out the set of
1545
        // HTLCs that were paid to prior set IDs, if needed.
1546
        if !invoice.IsAMP() {
6✔
1547
                return invoice, nil
3✔
1548
        }
3✔
1549

1550
        if shouldFetchAMPHTLCs(invoice, setIDs) {
6✔
1551
                invoice.Htlcs, err = fetchAmpSubInvoices(
3✔
1552
                        invoices, invoiceNum, setIDs...,
3✔
1553
                )
3✔
1554
                // TODO(positiveblue): we should fail when we are not able to
3✔
1555
                // fetch all the HTLCs for an AMP invoice. Multiple tests in
3✔
1556
                // the invoice and channeldb package break if we return this
3✔
1557
                // error. We need to update them when we migrate this logic to
3✔
1558
                // the sql implementation.
3✔
1559
                if err != nil {
6✔
1560
                        log.Errorf("unable to fetch amp htlcs for inv "+
3✔
1561
                                "%v and setIDs %v: %w", invoiceNum, setIDs, err)
3✔
1562
                }
3✔
1563

1564
                if filterAMPState {
6✔
1565
                        filterInvoiceAMPState(&invoice, setIDs...)
3✔
1566
                }
3✔
1567
        }
1568

1569
        return invoice, nil
3✔
1570
}
1571

1572
// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that
1573
// were paid to the relevant set IDs.
1574
func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool {
3✔
1575
        // For AMP invoice that already have HTLCs populated (created before
3✔
1576
        // recurring invoices), then we don't need to read from the prefix
3✔
1577
        // keyed section of the bucket.
3✔
1578
        if len(invoice.Htlcs) != 0 {
3✔
1579
                return false
×
1580
        }
×
1581

1582
        // If the "zero" setID was specified, then this means that no HTLC data
1583
        // should be returned alongside of it.
1584
        if len(setIDs) != 0 && setIDs[0] != nil &&
3✔
1585
                *setIDs[0] == invpkg.BlankPayAddr {
6✔
1586

3✔
1587
                return false
3✔
1588
        }
3✔
1589

1590
        return true
3✔
1591
}
1592

1593
// fetchInvoiceStateAMP retrieves the state of all the relevant sub-invoice for
1594
// an AMP invoice. This methods only decode the relevant state vs the entire
1595
// invoice.
1596
func fetchInvoiceStateAMP(invoiceNum []byte,
UNCOV
1597
        invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) {
×
UNCOV
1598

×
UNCOV
1599
        // Fetch the raw invoice bytes.
×
UNCOV
1600
        invoiceBytes := invoices.Get(invoiceNum)
×
UNCOV
1601
        if invoiceBytes == nil {
×
1602
                return nil, invpkg.ErrInvoiceNotFound
×
1603
        }
×
1604

UNCOV
1605
        r := bytes.NewReader(invoiceBytes)
×
UNCOV
1606

×
UNCOV
1607
        var bodyLen int64
×
UNCOV
1608
        err := binary.Read(r, byteOrder, &bodyLen)
×
UNCOV
1609
        if err != nil {
×
1610
                return nil, err
×
1611
        }
×
1612

1613
        // Next, we'll make a new TLV stream that only attempts to decode the
1614
        // bytes we actually need.
UNCOV
1615
        ampState := make(invpkg.AMPInvoiceState)
×
UNCOV
1616
        tlvStream, err := tlv.NewStream(
×
UNCOV
1617
                // Invoice AMP state.
×
UNCOV
1618
                tlv.MakeDynamicRecord(
×
UNCOV
1619
                        invoiceAmpStateType, &ampState, nil,
×
UNCOV
1620
                        ampStateEncoder, ampStateDecoder,
×
UNCOV
1621
                ),
×
UNCOV
1622
        )
×
UNCOV
1623
        if err != nil {
×
1624
                return nil, err
×
1625
        }
×
1626

UNCOV
1627
        invoiceReader := io.LimitReader(r, bodyLen)
×
UNCOV
1628
        if err = tlvStream.Decode(invoiceReader); err != nil {
×
1629
                return nil, err
×
1630
        }
×
1631

UNCOV
1632
        return ampState, nil
×
1633
}
1634

1635
func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) {
3✔
1636
        var (
3✔
1637
                preimageBytes [32]byte
3✔
1638
                value         uint64
3✔
1639
                cltvDelta     uint32
3✔
1640
                expiry        uint64
3✔
1641
                amtPaid       uint64
3✔
1642
                state         uint8
3✔
1643
                hodlInvoice   uint8
3✔
1644

3✔
1645
                creationDateBytes []byte
3✔
1646
                settleDateBytes   []byte
3✔
1647
                featureBytes      []byte
3✔
1648
        )
3✔
1649

3✔
1650
        var i invpkg.Invoice
3✔
1651
        i.AMPState = make(invpkg.AMPInvoiceState)
3✔
1652
        tlvStream, err := tlv.NewStream(
3✔
1653
                // Memo and payreq.
3✔
1654
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
3✔
1655
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
3✔
1656

3✔
1657
                // Add/settle metadata.
3✔
1658
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
3✔
1659
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
3✔
1660
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
3✔
1661
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
3✔
1662

3✔
1663
                // Terms.
3✔
1664
                tlv.MakePrimitiveRecord(preimageType, &preimageBytes),
3✔
1665
                tlv.MakePrimitiveRecord(valueType, &value),
3✔
1666
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
3✔
1667
                tlv.MakePrimitiveRecord(expiryType, &expiry),
3✔
1668
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
3✔
1669
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
3✔
1670

3✔
1671
                // Invoice state.
3✔
1672
                tlv.MakePrimitiveRecord(invStateType, &state),
3✔
1673
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
3✔
1674

3✔
1675
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
3✔
1676

3✔
1677
                // Invoice AMP state.
3✔
1678
                tlv.MakeDynamicRecord(
3✔
1679
                        invoiceAmpStateType, &i.AMPState, nil,
3✔
1680
                        ampStateEncoder, ampStateDecoder,
3✔
1681
                ),
3✔
1682
        )
3✔
1683
        if err != nil {
3✔
1684
                return i, err
×
1685
        }
×
1686

1687
        var bodyLen int64
3✔
1688
        err = binary.Read(r, byteOrder, &bodyLen)
3✔
1689
        if err != nil {
3✔
1690
                return i, err
×
1691
        }
×
1692

1693
        lr := io.LimitReader(r, bodyLen)
3✔
1694
        if err = tlvStream.Decode(lr); err != nil {
3✔
1695
                return i, err
×
1696
        }
×
1697

1698
        preimage := lntypes.Preimage(preimageBytes)
3✔
1699
        if preimage != invpkg.UnknownPreimage {
6✔
1700
                i.Terms.PaymentPreimage = &preimage
3✔
1701
        }
3✔
1702

1703
        i.Terms.Value = lnwire.MilliSatoshi(value)
3✔
1704
        i.Terms.FinalCltvDelta = int32(cltvDelta)
3✔
1705
        i.Terms.Expiry = time.Duration(expiry)
3✔
1706
        i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
3✔
1707
        i.State = invpkg.ContractState(state)
3✔
1708

3✔
1709
        if hodlInvoice != 0 {
6✔
1710
                i.HodlInvoice = true
3✔
1711
        }
3✔
1712

1713
        err = i.CreationDate.UnmarshalBinary(creationDateBytes)
3✔
1714
        if err != nil {
3✔
1715
                return i, err
×
1716
        }
×
1717

1718
        err = i.SettleDate.UnmarshalBinary(settleDateBytes)
3✔
1719
        if err != nil {
3✔
1720
                return i, err
×
1721
        }
×
1722

1723
        rawFeatures := lnwire.NewRawFeatureVector()
3✔
1724
        err = rawFeatures.DecodeBase256(
3✔
1725
                bytes.NewReader(featureBytes), len(featureBytes),
3✔
1726
        )
3✔
1727
        if err != nil {
3✔
1728
                return i, err
×
1729
        }
×
1730

1731
        i.Terms.Features = lnwire.NewFeatureVector(
3✔
1732
                rawFeatures, lnwire.Features,
3✔
1733
        )
3✔
1734

3✔
1735
        i.Htlcs, err = deserializeHtlcs(r)
3✔
1736
        return i, err
3✔
1737
}
1738

1739
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
3✔
1740
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
6✔
1741
                // We encode the set of circuit keys as a varint length prefix.
3✔
1742
                // followed by a series of fixed sized uint8 integers.
3✔
1743
                numKeys := uint64(len(*v))
3✔
1744

3✔
1745
                if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
3✔
1746
                        return err
×
1747
                }
×
1748

1749
                for key := range *v {
6✔
1750
                        scidInt := key.ChanID.ToUint64()
3✔
1751

3✔
1752
                        if err := tlv.EUint64(w, &scidInt, buf); err != nil {
3✔
1753
                                return err
×
1754
                        }
×
1755
                        if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
3✔
1756
                                return err
×
1757
                        }
×
1758
                }
1759

1760
                return nil
3✔
1761
        }
1762

1763
        return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
×
1764
}
1765

1766
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte,
1767
        l uint64) error {
3✔
1768

3✔
1769
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
6✔
1770
                // First, we'll read out the varint that encodes the number of
3✔
1771
                // circuit keys encoded.
3✔
1772
                numKeys, err := tlv.ReadVarInt(r, buf)
3✔
1773
                if err != nil {
3✔
1774
                        return err
×
1775
                }
×
1776

1777
                // Now that we know how many keys to expect, iterate reading
1778
                // each one until we're done.
1779
                for i := uint64(0); i < numKeys; i++ {
6✔
1780
                        var (
3✔
1781
                                key  models.CircuitKey
3✔
1782
                                scid uint64
3✔
1783
                        )
3✔
1784

3✔
1785
                        if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
3✔
1786
                                return err
×
1787
                        }
×
1788

1789
                        key.ChanID = lnwire.NewShortChanIDFromInt(scid)
3✔
1790

3✔
1791
                        err := tlv.DUint64(r, &key.HtlcID, buf, 8)
3✔
1792
                        if err != nil {
3✔
1793
                                return err
×
1794
                        }
×
1795

1796
                        (*v)[key] = struct{}{}
3✔
1797
                }
1798

1799
                return nil
3✔
1800
        }
1801

1802
        return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l)
×
1803
}
1804

1805
// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record.
1806
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
3✔
1807
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
6✔
1808
                // We'll encode the AMP state as a series of KV pairs on the
3✔
1809
                // wire with a length prefix.
3✔
1810
                numRecords := uint64(len(*v))
3✔
1811

3✔
1812
                // First, we'll write out the number of records as a var int.
3✔
1813
                if err := tlv.WriteVarInt(w, numRecords, buf); err != nil {
3✔
1814
                        return err
×
1815
                }
×
1816

1817
                // With that written out, we'll now encode the entries
1818
                // themselves as a sub-TLV record, which includes its _own_
1819
                // inner length prefix.
1820
                for setID, ampState := range *v {
6✔
1821
                        setID := [32]byte(setID)
3✔
1822
                        ampState := ampState
3✔
1823

3✔
1824
                        htlcState := uint8(ampState.State)
3✔
1825
                        settleDate := ampState.SettleDate
3✔
1826
                        settleDateBytes, err := settleDate.MarshalBinary()
3✔
1827
                        if err != nil {
3✔
1828
                                return err
×
1829
                        }
×
1830

1831
                        amtPaid := uint64(ampState.AmtPaid)
3✔
1832

3✔
1833
                        var ampStateTlvBytes bytes.Buffer
3✔
1834
                        tlvStream, err := tlv.NewStream(
3✔
1835
                                tlv.MakePrimitiveRecord(
3✔
1836
                                        ampStateSetIDType, &setID,
3✔
1837
                                ),
3✔
1838
                                tlv.MakePrimitiveRecord(
3✔
1839
                                        ampStateHtlcStateType, &htlcState,
3✔
1840
                                ),
3✔
1841
                                tlv.MakePrimitiveRecord(
3✔
1842
                                        ampStateSettleIndexType,
3✔
1843
                                        &ampState.SettleIndex,
3✔
1844
                                ),
3✔
1845
                                tlv.MakePrimitiveRecord(
3✔
1846
                                        ampStateSettleDateType,
3✔
1847
                                        &settleDateBytes,
3✔
1848
                                ),
3✔
1849
                                tlv.MakeDynamicRecord(
3✔
1850
                                        ampStateCircuitKeysType,
3✔
1851
                                        &ampState.InvoiceKeys,
3✔
1852
                                        func() uint64 {
6✔
1853
                                                // The record takes 8 bytes to
3✔
1854
                                                // encode the set of circuits,
3✔
1855
                                                // 8 bytes for the scid for the
3✔
1856
                                                // key, and 8 bytes for the HTLC
3✔
1857
                                                // index.
3✔
1858
                                                keys := ampState.InvoiceKeys
3✔
1859
                                                numKeys := uint64(len(keys))
3✔
1860
                                                size := tlv.VarIntSize(numKeys)
3✔
1861
                                                dataSize := (numKeys * 16)
3✔
1862

3✔
1863
                                                return size + dataSize
3✔
1864
                                        },
3✔
1865
                                        encodeCircuitKeys, decodeCircuitKeys,
1866
                                ),
1867
                                tlv.MakePrimitiveRecord(
1868
                                        ampStateAmtPaidType, &amtPaid,
1869
                                ),
1870
                        )
1871
                        if err != nil {
3✔
1872
                                return err
×
1873
                        }
×
1874

1875
                        err = tlvStream.Encode(&ampStateTlvBytes)
3✔
1876
                        if err != nil {
3✔
1877
                                return err
×
1878
                        }
×
1879

1880
                        // We encode the record with a varint length followed by
1881
                        // the _raw_ TLV bytes.
1882
                        tlvLen := uint64(len(ampStateTlvBytes.Bytes()))
3✔
1883
                        if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
3✔
1884
                                return err
×
1885
                        }
×
1886

1887
                        _, err = w.Write(ampStateTlvBytes.Bytes())
3✔
1888
                        if err != nil {
3✔
1889
                                return err
×
1890
                        }
×
1891
                }
1892

1893
                return nil
3✔
1894
        }
1895

1896
        return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
×
1897
}
1898

1899
// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record.
1900
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte,
1901
        l uint64) error {
3✔
1902

3✔
1903
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
6✔
1904
                // First, we'll decode the varint that encodes how many set IDs
3✔
1905
                // are encoded within the greater map.
3✔
1906
                numRecords, err := tlv.ReadVarInt(r, buf)
3✔
1907
                if err != nil {
3✔
1908
                        return err
×
1909
                }
×
1910

1911
                // Now that we know how many records we'll need to read, we can
1912
                // iterate and read them all out in series.
1913
                for i := uint64(0); i < numRecords; i++ {
6✔
1914
                        // Read out the varint that encodes the size of this
3✔
1915
                        // inner TLV record.
3✔
1916
                        stateRecordSize, err := tlv.ReadVarInt(r, buf)
3✔
1917
                        if err != nil {
3✔
1918
                                return err
×
1919
                        }
×
1920

1921
                        // Using this information, we'll create a new limited
1922
                        // reader that'll return an EOF once the end has been
1923
                        // reached so the stream stops consuming bytes.
1924
                        innerTlvReader := io.LimitedReader{
3✔
1925
                                R: r,
3✔
1926
                                N: int64(stateRecordSize),
3✔
1927
                        }
3✔
1928

3✔
1929
                        var (
3✔
1930
                                setID           [32]byte
3✔
1931
                                htlcState       uint8
3✔
1932
                                settleIndex     uint64
3✔
1933
                                settleDateBytes []byte
3✔
1934
                                invoiceKeys     = make(
3✔
1935
                                        map[models.CircuitKey]struct{},
3✔
1936
                                )
3✔
1937
                                amtPaid uint64
3✔
1938
                        )
3✔
1939
                        tlvStream, err := tlv.NewStream(
3✔
1940
                                tlv.MakePrimitiveRecord(
3✔
1941
                                        ampStateSetIDType, &setID,
3✔
1942
                                ),
3✔
1943
                                tlv.MakePrimitiveRecord(
3✔
1944
                                        ampStateHtlcStateType, &htlcState,
3✔
1945
                                ),
3✔
1946
                                tlv.MakePrimitiveRecord(
3✔
1947
                                        ampStateSettleIndexType, &settleIndex,
3✔
1948
                                ),
3✔
1949
                                tlv.MakePrimitiveRecord(
3✔
1950
                                        ampStateSettleDateType,
3✔
1951
                                        &settleDateBytes,
3✔
1952
                                ),
3✔
1953
                                tlv.MakeDynamicRecord(
3✔
1954
                                        ampStateCircuitKeysType,
3✔
1955
                                        &invoiceKeys, nil,
3✔
1956
                                        encodeCircuitKeys, decodeCircuitKeys,
3✔
1957
                                ),
3✔
1958
                                tlv.MakePrimitiveRecord(
3✔
1959
                                        ampStateAmtPaidType, &amtPaid,
3✔
1960
                                ),
3✔
1961
                        )
3✔
1962
                        if err != nil {
3✔
1963
                                return err
×
1964
                        }
×
1965

1966
                        err = tlvStream.Decode(&innerTlvReader)
3✔
1967
                        if err != nil {
3✔
1968
                                return err
×
1969
                        }
×
1970

1971
                        var settleDate time.Time
3✔
1972
                        err = settleDate.UnmarshalBinary(settleDateBytes)
3✔
1973
                        if err != nil {
3✔
1974
                                return err
×
1975
                        }
×
1976

1977
                        (*v)[setID] = invpkg.InvoiceStateAMP{
3✔
1978
                                State:       invpkg.HtlcState(htlcState),
3✔
1979
                                SettleIndex: settleIndex,
3✔
1980
                                SettleDate:  settleDate,
3✔
1981
                                InvoiceKeys: invoiceKeys,
3✔
1982
                                AmtPaid:     lnwire.MilliSatoshi(amtPaid),
3✔
1983
                        }
3✔
1984
                }
1985

1986
                return nil
3✔
1987
        }
1988

1989
        return tlv.NewTypeForDecodingErr(
×
1990
                val, "channeldb.AMPInvoiceState", l, l,
×
1991
        )
×
1992
}
1993

1994
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
1995
// as a map.
1996
func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1997
        error) {
3✔
1998

3✔
1999
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
2000
        for {
6✔
2001
                // Read the length of the tlv stream for this htlc.
3✔
2002
                var streamLen int64
3✔
2003
                if err := binary.Read(r, byteOrder, &streamLen); err != nil {
6✔
2004
                        if err == io.EOF {
6✔
2005
                                break
3✔
2006
                        }
2007

2008
                        return nil, err
×
2009
                }
2010

2011
                // Limit the reader so that it stops at the end of this htlc's
2012
                // stream.
2013
                htlcReader := io.LimitReader(r, streamLen)
3✔
2014

3✔
2015
                // Decode the contents into the htlc fields.
3✔
2016
                var (
3✔
2017
                        htlc                    invpkg.InvoiceHTLC
3✔
2018
                        key                     models.CircuitKey
3✔
2019
                        chanID                  uint64
3✔
2020
                        state                   uint8
3✔
2021
                        acceptTime, resolveTime uint64
3✔
2022
                        amt, mppTotalAmt        uint64
3✔
2023
                        amp                     = &record.AMP{}
3✔
2024
                        hash32                  = &[32]byte{}
3✔
2025
                        preimage32              = &[32]byte{}
3✔
2026
                )
3✔
2027
                tlvStream, err := tlv.NewStream(
3✔
2028
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
3✔
2029
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
3✔
2030
                        tlv.MakePrimitiveRecord(amtType, &amt),
3✔
2031
                        tlv.MakePrimitiveRecord(
3✔
2032
                                acceptHeightType, &htlc.AcceptHeight,
3✔
2033
                        ),
3✔
2034
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
3✔
2035
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
3✔
2036
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
3✔
2037
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
3✔
2038
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
3✔
2039
                        tlv.MakeDynamicRecord(
3✔
2040
                                htlcAMPType, amp, amp.PayloadSize,
3✔
2041
                                record.AMPEncoder, record.AMPDecoder,
3✔
2042
                        ),
3✔
2043
                        tlv.MakePrimitiveRecord(htlcHashType, hash32),
3✔
2044
                        tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
3✔
2045
                )
3✔
2046
                if err != nil {
3✔
2047
                        return nil, err
×
2048
                }
×
2049

2050
                parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
3✔
2051
                if err != nil {
3✔
2052
                        return nil, err
×
2053
                }
×
2054

2055
                if _, ok := parsedTypes[htlcAMPType]; !ok {
6✔
2056
                        amp = nil
3✔
2057
                }
3✔
2058

2059
                var preimage *lntypes.Preimage
3✔
2060
                if _, ok := parsedTypes[htlcPreimageType]; ok {
6✔
2061
                        pimg := lntypes.Preimage(*preimage32)
3✔
2062
                        preimage = &pimg
3✔
2063
                }
3✔
2064

2065
                var hash *lntypes.Hash
3✔
2066
                if _, ok := parsedTypes[htlcHashType]; ok {
6✔
2067
                        h := lntypes.Hash(*hash32)
3✔
2068
                        hash = &h
3✔
2069
                }
3✔
2070

2071
                key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
3✔
2072
                htlc.AcceptTime = getNanoTime(acceptTime)
3✔
2073
                htlc.ResolveTime = getNanoTime(resolveTime)
3✔
2074
                htlc.State = invpkg.HtlcState(state)
3✔
2075
                htlc.Amt = lnwire.MilliSatoshi(amt)
3✔
2076
                htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
3✔
2077
                if amp != nil && hash != nil {
6✔
2078
                        htlc.AMP = &invpkg.InvoiceHtlcAMPData{
3✔
2079
                                Record:   *amp,
3✔
2080
                                Hash:     *hash,
3✔
2081
                                Preimage: preimage,
3✔
2082
                        }
3✔
2083
                }
3✔
2084

2085
                // Reconstruct the custom records fields from the parsed types
2086
                // map return from the tlv parser.
2087
                htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
3✔
2088

3✔
2089
                htlcs[key] = &htlc
3✔
2090
        }
2091

2092
        return htlcs, nil
3✔
2093
}
2094

2095
// invoiceSetIDKeyLen is the length of the key under which the HTLCs of an
// individual AMP set are stored, alongside the main invoice, in the invoice
// bucket. The key is the 4-byte invoice number followed by the 32-byte set
// ID.
const invoiceSetIDKeyLen = 4 + 32
2100

2101
// makeInvoiceSetIDKey returns the prefix key, based on the set ID and invoice
2102
// number where the HTLCs for this setID will be stored udner.
2103
func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte {
3✔
2104
        // Construct the prefix key we need to obtain the invoice information:
3✔
2105
        // invoiceNum || setID.
3✔
2106
        var invoiceSetIDKey [invoiceSetIDKeyLen]byte
3✔
2107
        copy(invoiceSetIDKey[:], invoiceNum)
3✔
2108
        copy(invoiceSetIDKey[len(invoiceNum):], setID)
3✔
2109

3✔
2110
        return invoiceSetIDKey
3✔
2111
}
3✔
2112

2113
// delAMPInvoices attempts to delete all the "sub" invoices associated with a
2114
// greater AMP invoices. We do this by deleting the set of keys that share the
2115
// invoice number as a prefix.
UNCOV
2116
func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error {
×
UNCOV
2117
        // Since it isn't safe to delete using an active cursor, we'll use the
×
UNCOV
2118
        // cursor simply to collect the set of keys we need to delete, _then_
×
UNCOV
2119
        // delete them in another pass.
×
UNCOV
2120
        var keysToDel [][]byte
×
UNCOV
2121
        err := forEachAMPInvoice(
×
UNCOV
2122
                invoiceBucket, invoiceNum,
×
UNCOV
2123
                func(cursorKey, v []byte) error {
×
UNCOV
2124
                        keysToDel = append(keysToDel, cursorKey)
×
UNCOV
2125
                        return nil
×
UNCOV
2126
                },
×
2127
        )
UNCOV
2128
        if err != nil {
×
2129
                return err
×
2130
        }
×
2131

2132
        // In this next phase, we'll then delete all the relevant invoices.
UNCOV
2133
        for _, keyToDel := range keysToDel {
×
UNCOV
2134
                if err := invoiceBucket.Delete(keyToDel); err != nil {
×
2135
                        return err
×
2136
                }
×
2137
        }
2138

UNCOV
2139
        return nil
×
2140
}
2141

2142
// delAMPSettleIndex removes all the entries in the settle index associated
2143
// with a given AMP invoice.
2144
func delAMPSettleIndex(invoiceNum []byte, invoices,
UNCOV
2145
        settleIndex kvdb.RwBucket) error {
×
UNCOV
2146

×
UNCOV
2147
        // First, we need to grab the AMP invoice state to see if there's
×
UNCOV
2148
        // anything that we even need to delete.
×
UNCOV
2149
        ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices)
×
UNCOV
2150
        if err != nil {
×
2151
                return err
×
2152
        }
×
2153

2154
        // If there's no AMP state at all (non-AMP invoice), then we can return
2155
        // early.
UNCOV
2156
        if len(ampState) == 0 {
×
UNCOV
2157
                return nil
×
UNCOV
2158
        }
×
2159

2160
        // Otherwise, we'll need to iterate and delete each settle index within
2161
        // the set of returned entries.
UNCOV
2162
        var settleIndexKey [8]byte
×
UNCOV
2163
        for _, subState := range ampState {
×
UNCOV
2164
                byteOrder.PutUint64(
×
UNCOV
2165
                        settleIndexKey[:], subState.SettleIndex,
×
UNCOV
2166
                )
×
UNCOV
2167

×
UNCOV
2168
                if err := settleIndex.Delete(settleIndexKey[:]); err != nil {
×
2169
                        return err
×
2170
                }
×
2171
        }
2172

UNCOV
2173
        return nil
×
2174
}
2175

2176
// DeleteCanceledInvoices deletes all canceled invoices from the database.
UNCOV
2177
func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
×
UNCOV
2178
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
×
UNCOV
2179
                invoices := tx.ReadWriteBucket(invoiceBucket)
×
UNCOV
2180
                if invoices == nil {
×
2181
                        return nil
×
2182
                }
×
2183

UNCOV
2184
                invoiceIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2185
                        invoiceIndexBucket,
×
UNCOV
2186
                )
×
UNCOV
2187
                if invoiceIndex == nil {
×
UNCOV
2188
                        return nil
×
UNCOV
2189
                }
×
2190

UNCOV
2191
                invoiceAddIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2192
                        addIndexBucket,
×
UNCOV
2193
                )
×
UNCOV
2194
                if invoiceAddIndex == nil {
×
2195
                        return nil
×
2196
                }
×
2197

UNCOV
2198
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
×
UNCOV
2199

×
UNCOV
2200
                return invoiceIndex.ForEach(func(k, v []byte) error {
×
UNCOV
2201
                        // Skip the special numInvoicesKey as that does not
×
UNCOV
2202
                        // point to a valid invoice.
×
UNCOV
2203
                        if bytes.Equal(k, numInvoicesKey) {
×
UNCOV
2204
                                return nil
×
UNCOV
2205
                        }
×
2206

2207
                        // Skip sub-buckets.
UNCOV
2208
                        if v == nil {
×
2209
                                return nil
×
2210
                        }
×
2211

UNCOV
2212
                        invoice, err := fetchInvoice(v, invoices, nil, false)
×
UNCOV
2213
                        if err != nil {
×
2214
                                return err
×
2215
                        }
×
2216

UNCOV
2217
                        if invoice.State != invpkg.ContractCanceled {
×
UNCOV
2218
                                return nil
×
UNCOV
2219
                        }
×
2220

2221
                        // Delete the payment hash from the invoice index.
UNCOV
2222
                        err = invoiceIndex.Delete(k)
×
UNCOV
2223
                        if err != nil {
×
2224
                                return err
×
2225
                        }
×
2226

2227
                        // Delete payment address index reference if there's a
2228
                        // valid payment address.
UNCOV
2229
                        if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
×
UNCOV
2230
                                // To ensure consistency check that the already
×
UNCOV
2231
                                // fetched invoice key matches the one in the
×
UNCOV
2232
                                // payment address index.
×
UNCOV
2233
                                key := payAddrIndex.Get(
×
UNCOV
2234
                                        invoice.Terms.PaymentAddr[:],
×
UNCOV
2235
                                )
×
UNCOV
2236
                                if bytes.Equal(key, k) {
×
2237
                                        // Delete from the payment address
×
2238
                                        // index.
×
2239
                                        if err := payAddrIndex.Delete(
×
2240
                                                invoice.Terms.PaymentAddr[:],
×
2241
                                        ); err != nil {
×
2242
                                                return err
×
2243
                                        }
×
2244
                                }
2245
                        }
2246

2247
                        // Remove from the add index.
UNCOV
2248
                        var addIndexKey [8]byte
×
UNCOV
2249
                        byteOrder.PutUint64(addIndexKey[:], invoice.AddIndex)
×
UNCOV
2250
                        err = invoiceAddIndex.Delete(addIndexKey[:])
×
UNCOV
2251
                        if err != nil {
×
2252
                                return err
×
2253
                        }
×
2254

2255
                        // Note that we don't need to delete the invoice from
2256
                        // the settle index as it is not added until the
2257
                        // invoice is settled.
2258

2259
                        // Now remove all sub invoices.
UNCOV
2260
                        err = delAMPInvoices(k, invoices)
×
UNCOV
2261
                        if err != nil {
×
2262
                                return err
×
2263
                        }
×
2264

2265
                        // Finally remove the serialized invoice from the
2266
                        // invoice bucket.
UNCOV
2267
                        return invoices.Delete(k)
×
2268
                })
UNCOV
2269
        }, func() {})
×
2270
}
2271

2272
// DeleteInvoice attempts to delete the passed invoices from the database in
2273
// one transaction. The passed delete references hold all keys required to
2274
// delete the invoices without also needing to deserialize them.
2275
func (d *DB) DeleteInvoice(_ context.Context,
UNCOV
2276
        invoicesToDelete []invpkg.InvoiceDeleteRef) error {
×
UNCOV
2277

×
UNCOV
2278
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
×
UNCOV
2279
                invoices := tx.ReadWriteBucket(invoiceBucket)
×
UNCOV
2280
                if invoices == nil {
×
2281
                        return invpkg.ErrNoInvoicesCreated
×
2282
                }
×
2283

UNCOV
2284
                invoiceIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2285
                        invoiceIndexBucket,
×
UNCOV
2286
                )
×
UNCOV
2287
                if invoiceIndex == nil {
×
2288
                        return invpkg.ErrNoInvoicesCreated
×
2289
                }
×
2290

UNCOV
2291
                invoiceAddIndex := invoices.NestedReadWriteBucket(
×
UNCOV
2292
                        addIndexBucket,
×
UNCOV
2293
                )
×
UNCOV
2294
                if invoiceAddIndex == nil {
×
2295
                        return invpkg.ErrNoInvoicesCreated
×
2296
                }
×
2297

2298
                // settleIndex can be nil, as the bucket is created lazily
2299
                // when the first invoice is settled.
UNCOV
2300
                settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket)
×
UNCOV
2301

×
UNCOV
2302
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
×
UNCOV
2303

×
UNCOV
2304
                for _, ref := range invoicesToDelete {
×
UNCOV
2305
                        // Fetch the invoice key for using it to check for
×
UNCOV
2306
                        // consistency and also to delete from the invoice
×
UNCOV
2307
                        // index.
×
UNCOV
2308
                        invoiceKey := invoiceIndex.Get(ref.PayHash[:])
×
UNCOV
2309
                        if invoiceKey == nil {
×
UNCOV
2310
                                return invpkg.ErrInvoiceNotFound
×
UNCOV
2311
                        }
×
2312

UNCOV
2313
                        err := invoiceIndex.Delete(ref.PayHash[:])
×
UNCOV
2314
                        if err != nil {
×
2315
                                return err
×
2316
                        }
×
2317

2318
                        // Delete payment address index reference if there's a
2319
                        // valid payment address passed.
UNCOV
2320
                        if ref.PayAddr != nil {
×
UNCOV
2321
                                // To ensure consistency check that the already
×
UNCOV
2322
                                // fetched invoice key matches the one in the
×
UNCOV
2323
                                // payment address index.
×
UNCOV
2324
                                key := payAddrIndex.Get(ref.PayAddr[:])
×
UNCOV
2325
                                if bytes.Equal(key, invoiceKey) {
×
UNCOV
2326
                                        // Delete from the payment address
×
UNCOV
2327
                                        // index. Note that since the payment
×
UNCOV
2328
                                        // address index has been introduced
×
UNCOV
2329
                                        // with an empty migration it may be
×
UNCOV
2330
                                        // possible that the index doesn't have
×
UNCOV
2331
                                        // an entry for this invoice.
×
UNCOV
2332
                                        // ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127
×
UNCOV
2333
                                        if err := payAddrIndex.Delete(
×
UNCOV
2334
                                                ref.PayAddr[:],
×
UNCOV
2335
                                        ); err != nil {
×
2336
                                                return err
×
2337
                                        }
×
2338
                                }
2339
                        }
2340

UNCOV
2341
                        var addIndexKey [8]byte
×
UNCOV
2342
                        byteOrder.PutUint64(addIndexKey[:], ref.AddIndex)
×
UNCOV
2343

×
UNCOV
2344
                        // To ensure consistency check that the key stored in
×
UNCOV
2345
                        // the add index also matches the previously fetched
×
UNCOV
2346
                        // invoice key.
×
UNCOV
2347
                        key := invoiceAddIndex.Get(addIndexKey[:])
×
UNCOV
2348
                        if !bytes.Equal(key, invoiceKey) {
×
UNCOV
2349
                                return fmt.Errorf("unknown invoice " +
×
UNCOV
2350
                                        "in add index")
×
UNCOV
2351
                        }
×
2352

2353
                        // Remove from the add index.
UNCOV
2354
                        err = invoiceAddIndex.Delete(addIndexKey[:])
×
UNCOV
2355
                        if err != nil {
×
2356
                                return err
×
2357
                        }
×
2358

2359
                        // Remove from the settle index if available and
2360
                        // if the invoice is settled.
UNCOV
2361
                        if settleIndex != nil && ref.SettleIndex > 0 {
×
UNCOV
2362
                                var settleIndexKey [8]byte
×
UNCOV
2363
                                byteOrder.PutUint64(
×
UNCOV
2364
                                        settleIndexKey[:], ref.SettleIndex,
×
UNCOV
2365
                                )
×
UNCOV
2366

×
UNCOV
2367
                                // To ensure consistency check that the already
×
UNCOV
2368
                                // fetched invoice key matches the one in the
×
UNCOV
2369
                                // settle index
×
UNCOV
2370
                                key := settleIndex.Get(settleIndexKey[:])
×
UNCOV
2371
                                if !bytes.Equal(key, invoiceKey) {
×
UNCOV
2372
                                        return fmt.Errorf("unknown invoice " +
×
UNCOV
2373
                                                "in settle index")
×
UNCOV
2374
                                }
×
2375

UNCOV
2376
                                err = settleIndex.Delete(settleIndexKey[:])
×
UNCOV
2377
                                if err != nil {
×
2378
                                        return err
×
2379
                                }
×
2380
                        }
2381

2382
                        // In addition to deleting the main invoice state, if
2383
                        // this is an AMP invoice, then we'll also need to
2384
                        // delete the set HTLC set stored as a key prefix. For
2385
                        // non-AMP invoices, this'll be a noop.
UNCOV
2386
                        err = delAMPSettleIndex(
×
UNCOV
2387
                                invoiceKey, invoices, settleIndex,
×
UNCOV
2388
                        )
×
UNCOV
2389
                        if err != nil {
×
2390
                                return err
×
2391
                        }
×
UNCOV
2392
                        err = delAMPInvoices(invoiceKey, invoices)
×
UNCOV
2393
                        if err != nil {
×
2394
                                return err
×
2395
                        }
×
2396

2397
                        // Finally remove the serialized invoice from the
2398
                        // invoice bucket.
UNCOV
2399
                        err = invoices.Delete(invoiceKey)
×
UNCOV
2400
                        if err != nil {
×
2401
                                return err
×
2402
                        }
×
2403
                }
2404

UNCOV
2405
                return nil
×
UNCOV
2406
        }, func() {})
×
2407

UNCOV
2408
        return err
×
2409
}
2410

2411
// SetInvoiceBucketTombstone sets the tombstone key in the invoice bucket to
2412
// mark the bucket as permanently closed. This prevents it from being reopened
2413
// in the future.
UNCOV
2414
func (d *DB) SetInvoiceBucketTombstone() error {
×
UNCOV
2415
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
×
UNCOV
2416
                // Access the top-level invoice bucket.
×
UNCOV
2417
                invoices := tx.ReadWriteBucket(invoiceBucket)
×
UNCOV
2418
                if invoices == nil {
×
2419
                        return fmt.Errorf("invoice bucket does not exist")
×
2420
                }
×
2421

2422
                // Add the tombstone key to the invoice bucket.
UNCOV
2423
                err := invoices.Put(invoiceBucketTombstone, []byte("1"))
×
UNCOV
2424
                if err != nil {
×
2425
                        return fmt.Errorf("failed to set tombstone: %w", err)
×
2426
                }
×
2427

UNCOV
2428
                return nil
×
UNCOV
2429
        }, func() {})
×
2430
}
2431

2432
// GetInvoiceBucketTombstone checks if the tombstone key exists in the invoice
2433
// bucket. It returns true if the tombstone is present and false otherwise.
2434
func (d *DB) GetInvoiceBucketTombstone() (bool, error) {
3✔
2435
        var tombstoneExists bool
3✔
2436

3✔
2437
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
2438
                // Access the top-level invoice bucket.
3✔
2439
                invoices := tx.ReadBucket(invoiceBucket)
3✔
2440
                if invoices == nil {
3✔
2441
                        return fmt.Errorf("invoice bucket does not exist")
×
2442
                }
×
2443

2444
                // Check if the tombstone key exists.
2445
                tombstone := invoices.Get(invoiceBucketTombstone)
3✔
2446
                tombstoneExists = tombstone != nil
3✔
2447

3✔
2448
                return nil
3✔
2449
        }, func() {})
3✔
2450
        if err != nil {
3✔
2451
                return false, err
×
2452
        }
×
2453

2454
        return tombstoneExists, nil
3✔
2455
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc