• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 11216766535

07 Oct 2024 01:37PM UTC coverage: 57.817% (-1.0%) from 58.817%
11216766535

Pull #9148

github

ProofOfKeags
lnwire: remove kickoff feerate from propose/commit
Pull Request #9148: DynComms [2/n]: lnwire: add authenticated wire messages for Dyn*

571 of 879 new or added lines in 16 files covered. (64.96%)

23253 existing lines in 251 files now uncovered.

99022 of 171268 relevant lines covered (57.82%)

38420.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.32
/channeldb/invoices.go
1
package channeldb
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/binary"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "time"
11

12
        "github.com/lightningnetwork/lnd/channeldb/models"
13
        "github.com/lightningnetwork/lnd/htlcswitch/hop"
14
        invpkg "github.com/lightningnetwork/lnd/invoices"
15
        "github.com/lightningnetwork/lnd/kvdb"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/tlv"
20
)
21

22
var (
	// invoiceBucket is the name of the bucket within the database that
	// stores all data related to invoices no matter their final state.
	// Within the invoice bucket, each invoice is keyed by its invoice ID
	// which is a monotonically increasing uint32.
	invoiceBucket = []byte("invoices")

	// invoiceIndexBucket is the name of the sub-bucket within the
	// invoiceBucket which indexes all invoices by their payment hash. The
	// payment hash is the sha256 of the invoice's payment preimage. This
	// index is used to detect duplicates, and also to provide a fast path
	// for looking up incoming HTLCs to determine if we're able to settle
	// them fully.
	//
	// maps: payHash => invoiceKey
	invoiceIndexBucket = []byte("paymenthashes")

	// payAddrIndexBucket is the name of the top-level bucket that maps
	// payment addresses to their invoice number. This can be used
	// to efficiently query or update non-legacy invoices. Note that legacy
	// invoices will not be included in this index since they all have the
	// same, all-zero payment address, however all newly generated invoices
	// will end up in this index.
	//
	// maps: payAddr => invoiceKey
	payAddrIndexBucket = []byte("pay-addr-index")

	// setIDIndexBucket is the name of the top-level bucket that maps set
	// ids to their invoice number. This can be used to efficiently query or
	// update AMP invoices. Note that legacy or MPP invoices will not be
	// included in this index, since their HTLCs do not have a set id.
	//
	// maps: setID => invoiceKey
	setIDIndexBucket = []byte("set-id-index")

	// numInvoicesKey is the name of key which houses the auto-incrementing
	// invoice ID which is essentially used as a primary key. With each
	// invoice inserted, the primary key is incremented by one. This key is
	// stored within the invoiceIndexBucket. Within the invoiceBucket
	// invoices are uniquely identified by the invoice ID.
	numInvoicesKey = []byte("nik")

	// addIndexBucket is an index bucket, nested within invoiceBucket, that
	// we'll use to create a monotonically increasing set of add indexes.
	// Each time we add a new invoice, this sequence number will be
	// incremented and then populated within the new invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   addIndexNo => invoiceKey
	addIndexBucket = []byte("invoice-add-index")

	// settleIndexBucket is an index bucket, nested within invoiceBucket,
	// that we'll use to create a monotonically increasing integer for
	// tracking a "settle index". Each time an invoice is settled, this
	// sequence number will be incremented and populated within the newly
	// settled invoice.
	//
	// In addition to this sequence number, we map:
	//
	//   settleIndexNo => invoiceKey
	settleIndexBucket = []byte("invoice-settle-index")
)
84

85
const (
86
        // A set of tlv type definitions used to serialize invoice htlcs to the
87
        // database.
88
        //
89
        // NOTE: A migration should be added whenever this list changes. This
90
        // prevents against the database being rolled back to an older
91
        // format where the surrounding logic might assume a different set of
92
        // fields are known.
93
        chanIDType       tlv.Type = 1
94
        htlcIDType       tlv.Type = 3
95
        amtType          tlv.Type = 5
96
        acceptHeightType tlv.Type = 7
97
        acceptTimeType   tlv.Type = 9
98
        resolveTimeType  tlv.Type = 11
99
        expiryHeightType tlv.Type = 13
100
        htlcStateType    tlv.Type = 15
101
        mppTotalAmtType  tlv.Type = 17
102
        htlcAMPType      tlv.Type = 19
103
        htlcHashType     tlv.Type = 21
104
        htlcPreimageType tlv.Type = 23
105

106
        // A set of tlv type definitions used to serialize invoice bodiees.
107
        //
108
        // NOTE: A migration should be added whenever this list changes. This
109
        // prevents against the database being rolled back to an older
110
        // format where the surrounding logic might assume a different set of
111
        // fields are known.
112
        memoType            tlv.Type = 0
113
        payReqType          tlv.Type = 1
114
        createTimeType      tlv.Type = 2
115
        settleTimeType      tlv.Type = 3
116
        addIndexType        tlv.Type = 4
117
        settleIndexType     tlv.Type = 5
118
        preimageType        tlv.Type = 6
119
        valueType           tlv.Type = 7
120
        cltvDeltaType       tlv.Type = 8
121
        expiryType          tlv.Type = 9
122
        paymentAddrType     tlv.Type = 10
123
        featuresType        tlv.Type = 11
124
        invStateType        tlv.Type = 12
125
        amtPaidType         tlv.Type = 13
126
        hodlInvoiceType     tlv.Type = 14
127
        invoiceAmpStateType tlv.Type = 15
128

129
        // A set of tlv type definitions used to serialize the invoice AMP
130
        // state along-side the main invoice body.
131
        ampStateSetIDType       tlv.Type = 0
132
        ampStateHtlcStateType   tlv.Type = 1
133
        ampStateSettleIndexType tlv.Type = 2
134
        ampStateSettleDateType  tlv.Type = 3
135
        ampStateCircuitKeysType tlv.Type = 4
136
        ampStateAmtPaidType     tlv.Type = 5
137
)
138

139
// AddInvoice inserts the targeted invoice into the database. If the invoice has
140
// *any* payment hashes which already exists within the database, then the
141
// insertion will be aborted and rejected due to the strict policy banning any
142
// duplicate payment hashes. A side effect of this function is that it sets
143
// AddIndex on newInvoice.
144
func (d *DB) AddInvoice(_ context.Context, newInvoice *invpkg.Invoice,
145
        paymentHash lntypes.Hash) (uint64, error) {
1,055✔
146

1,055✔
147
        if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
1,057✔
148
                return 0, err
2✔
149
        }
2✔
150

151
        var invoiceAddIndex uint64
1,053✔
152
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
2,106✔
153
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
1,053✔
154
                if err != nil {
1,053✔
155
                        return err
×
156
                }
×
157

158
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
1,053✔
159
                        invoiceIndexBucket,
1,053✔
160
                )
1,053✔
161
                if err != nil {
1,053✔
162
                        return err
×
163
                }
×
164
                addIndex, err := invoices.CreateBucketIfNotExists(
1,053✔
165
                        addIndexBucket,
1,053✔
166
                )
1,053✔
167
                if err != nil {
1,053✔
168
                        return err
×
169
                }
×
170

171
                // Ensure that an invoice an identical payment hash doesn't
172
                // already exist within the index.
173
                if invoiceIndex.Get(paymentHash[:]) != nil {
1,056✔
174
                        return invpkg.ErrDuplicateInvoice
3✔
175
                }
3✔
176

177
                // Check that we aren't inserting an invoice with a duplicate
178
                // payment address. The all-zeros payment address is
179
                // special-cased to support legacy keysend invoices which don't
180
                // assign one. This is safe since later we also will avoid
181
                // indexing them and avoid collisions.
182
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
1,050✔
183
                if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
1,979✔
184
                        paymentAddr := newInvoice.Terms.PaymentAddr[:]
929✔
185
                        if payAddrIndex.Get(paymentAddr) != nil {
934✔
186
                                return invpkg.ErrDuplicatePayAddr
5✔
187
                        }
5✔
188
                }
189

190
                // If the current running payment ID counter hasn't yet been
191
                // created, then create it now.
192
                var invoiceNum uint32
1,045✔
193
                invoiceCounter := invoiceIndex.Get(numInvoicesKey)
1,045✔
194
                if invoiceCounter == nil {
1,241✔
195
                        var scratch [4]byte
196✔
196
                        byteOrder.PutUint32(scratch[:], invoiceNum)
196✔
197
                        err := invoiceIndex.Put(numInvoicesKey, scratch[:])
196✔
198
                        if err != nil {
196✔
199
                                return err
×
200
                        }
×
201
                } else {
849✔
202
                        invoiceNum = byteOrder.Uint32(invoiceCounter)
849✔
203
                }
849✔
204

205
                newIndex, err := putInvoice(
1,045✔
206
                        invoices, invoiceIndex, payAddrIndex, addIndex,
1,045✔
207
                        newInvoice, invoiceNum, paymentHash,
1,045✔
208
                )
1,045✔
209
                if err != nil {
1,045✔
210
                        return err
×
211
                }
×
212

213
                invoiceAddIndex = newIndex
1,045✔
214
                return nil
1,045✔
215
        }, func() {
1,053✔
216
                invoiceAddIndex = 0
1,053✔
217
        })
1,053✔
218
        if err != nil {
1,061✔
219
                return 0, err
8✔
220
        }
8✔
221

222
        return invoiceAddIndex, err
1,045✔
223
}
224

225
// InvoicesAddedSince can be used by callers to seek into the event time series
226
// of all the invoices added in the database. The specified sinceAddIndex
227
// should be the highest add index that the caller knows of. This method will
228
// return all invoices with an add index greater than the specified
229
// sinceAddIndex.
230
//
231
// NOTE: The index starts from 1, as a result. We enforce that specifying a
232
// value below the starting index value is a noop.
233
func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
234
        []invpkg.Invoice, error) {
20✔
235

20✔
236
        var newInvoices []invpkg.Invoice
20✔
237

20✔
238
        // If an index of zero was specified, then in order to maintain
20✔
239
        // backwards compat, we won't send out any new invoices.
20✔
240
        if sinceAddIndex == 0 {
37✔
241
                return newInvoices, nil
17✔
242
        }
17✔
243

244
        var startIndex [8]byte
3✔
245
        byteOrder.PutUint64(startIndex[:], sinceAddIndex)
3✔
246

3✔
247
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
248
                invoices := tx.ReadBucket(invoiceBucket)
3✔
249
                if invoices == nil {
3✔
250
                        return nil
×
251
                }
×
252

253
                addIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
254
                if addIndex == nil {
3✔
255
                        return nil
×
256
                }
×
257

258
                // We'll now run through each entry in the add index starting
259
                // at our starting index. We'll continue until we reach the
260
                // very end of the current key space.
261
                invoiceCursor := addIndex.ReadCursor()
3✔
262

3✔
263
                // We'll seek to the starting index, then manually advance the
3✔
264
                // cursor in order to skip the entry with the since add index.
3✔
265
                invoiceCursor.Seek(startIndex[:])
3✔
266
                addSeqNo, invoiceKey := invoiceCursor.Next()
3✔
267

3✔
268
                for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
32✔
269

29✔
270
                        // For each key found, we'll look up the actual
29✔
271
                        // invoice, then accumulate it into our return value.
29✔
272
                        invoice, err := fetchInvoice(
29✔
273
                                invoiceKey, invoices, nil, false,
29✔
274
                        )
29✔
275
                        if err != nil {
29✔
276
                                return err
×
277
                        }
×
278

279
                        newInvoices = append(newInvoices, invoice)
29✔
280
                }
281

282
                return nil
3✔
283
        }, func() {
3✔
284
                newInvoices = nil
3✔
285
        })
3✔
286
        if err != nil {
3✔
287
                return nil, err
×
288
        }
×
289

290
        return newInvoices, nil
3✔
291
}
292

293
// LookupInvoice attempts to look up an invoice according to its 32 byte
294
// payment hash. If an invoice which can settle the HTLC identified by the
295
// passed payment hash isn't found, then an error is returned. Otherwise, the
296
// full invoice is returned. Before setting the incoming HTLC, the values
297
// SHOULD be checked to ensure the payer meets the agreed upon contractual
298
// terms of the payment.
299
func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
300
        invpkg.Invoice, error) {
1,058✔
301

1,058✔
302
        var invoice invpkg.Invoice
1,058✔
303
        err := kvdb.View(d, func(tx kvdb.RTx) error {
2,116✔
304
                invoices := tx.ReadBucket(invoiceBucket)
1,058✔
305
                if invoices == nil {
1,058✔
306
                        return invpkg.ErrNoInvoicesCreated
×
307
                }
×
308
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
1,058✔
309
                if invoiceIndex == nil {
1,071✔
310
                        return invpkg.ErrNoInvoicesCreated
13✔
311
                }
13✔
312
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
1,045✔
313
                setIDIndex := tx.ReadBucket(setIDIndexBucket)
1,045✔
314

1,045✔
315
                // Retrieve the invoice number for this invoice using
1,045✔
316
                // the provided invoice reference.
1,045✔
317
                invoiceNum, err := fetchInvoiceNumByRef(
1,045✔
318
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
1,045✔
319
                )
1,045✔
320
                if err != nil {
1,053✔
321
                        return err
8✔
322
                }
8✔
323

324
                var setID *invpkg.SetID
1,037✔
325
                switch {
1,037✔
326
                // If this is a payment address ref, and the blank modified was
327
                // specified, then we'll use the zero set ID to indicate that
328
                // we won't want any HTLCs returned.
329
                case ref.PayAddr() != nil &&
330
                        ref.Modifier() == invpkg.HtlcSetBlankModifier:
1✔
331

1✔
332
                        var zeroSetID invpkg.SetID
1✔
333
                        setID = &zeroSetID
1✔
334

335
                // If this is a set ID ref, and the htlc set only modified was
336
                // specified, then we'll pass through the specified setID so
337
                // only that will be returned.
338
                case ref.SetID() != nil &&
339
                        ref.Modifier() == invpkg.HtlcSetOnlyModifier:
4✔
340

4✔
341
                        setID = (*invpkg.SetID)(ref.SetID())
4✔
342
                }
343

344
                // An invoice was found, retrieve the remainder of the invoice
345
                // body.
346
                i, err := fetchInvoice(
1,037✔
347
                        invoiceNum, invoices, []*invpkg.SetID{setID}, true,
1,037✔
348
                )
1,037✔
349
                if err != nil {
1,037✔
350
                        return err
×
351
                }
×
352
                invoice = i
1,037✔
353

1,037✔
354
                return nil
1,037✔
355
        }, func() {})
1,058✔
356
        if err != nil {
1,079✔
357
                return invoice, err
21✔
358
        }
21✔
359

360
        return invoice, nil
1,037✔
361
}
362

363
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
364
// reference. The payment address will be treated as the primary key, falling
365
// back to the payment hash if nothing is found for the payment address. An
366
// error is returned if the invoice is not found.
367
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
368
        ref invpkg.InvoiceRef) ([]byte, error) {
2,536✔
369

2,536✔
370
        // If the set id is present, we only consult the set id index for this
2,536✔
371
        // invoice. This type of query is only used to facilitate user-facing
2,536✔
372
        // requests to lookup, settle or cancel an AMP invoice.
2,536✔
373
        setID := ref.SetID()
2,536✔
374
        if setID != nil {
2,543✔
375
                invoiceNumBySetID := setIDIndex.Get(setID[:])
7✔
376
                if invoiceNumBySetID == nil {
8✔
377
                        return nil, invpkg.ErrInvoiceNotFound
1✔
378
                }
1✔
379

380
                return invoiceNumBySetID, nil
6✔
381
        }
382

383
        payHash := ref.PayHash()
2,529✔
384
        payAddr := ref.PayAddr()
2,529✔
385

2,529✔
386
        getInvoiceNumByHash := func() []byte {
5,058✔
387
                if payHash != nil {
5,034✔
388
                        return invoiceIndex.Get(payHash[:])
2,505✔
389
                }
2,505✔
390
                return nil
24✔
391
        }
392

393
        getInvoiceNumByAddr := func() []byte {
5,058✔
394
                if payAddr != nil {
3,010✔
395
                        // Only allow lookups for payment address if it is not a
481✔
396
                        // blank payment address, which is a special-cased value
481✔
397
                        // for legacy keysend invoices.
481✔
398
                        if *payAddr != invpkg.BlankPayAddr {
536✔
399
                                return payAddrIndex.Get(payAddr[:])
55✔
400
                        }
55✔
401
                }
402
                return nil
2,474✔
403
        }
404

405
        invoiceNumByHash := getInvoiceNumByHash()
2,529✔
406
        invoiceNumByAddr := getInvoiceNumByAddr()
2,529✔
407
        switch {
2,529✔
408
        // If payment address and payment hash both reference an existing
409
        // invoice, ensure they reference the _same_ invoice.
410
        case invoiceNumByAddr != nil && invoiceNumByHash != nil:
31✔
411
                if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
33✔
412
                        return nil, invpkg.ErrInvRefEquivocation
2✔
413
                }
2✔
414

415
                return invoiceNumByAddr, nil
29✔
416

417
        // Return invoices by payment addr only.
418
        //
419
        // NOTE: We constrain this lookup to only apply if the invoice ref does
420
        // not contain a payment hash. Legacy and MPP payments depend on the
421
        // payment hash index to enforce that the HTLCs payment hash matches the
422
        // payment hash for the invoice, without this check we would
423
        // inadvertently assume the invoice contains the correct preimage for
424
        // the HTLC, which we only enforce via the lookup by the invoice index.
425
        case invoiceNumByAddr != nil && payHash == nil:
23✔
426
                return invoiceNumByAddr, nil
23✔
427

428
        // If we were only able to reference the invoice by hash, return the
429
        // corresponding invoice number. This can happen when no payment address
430
        // was provided, or if it didn't match anything in our records.
431
        case invoiceNumByHash != nil:
2,466✔
432
                return invoiceNumByHash, nil
2,466✔
433

434
        // Otherwise we don't know of the target invoice.
435
        default:
9✔
436
                return nil, invpkg.ErrInvoiceNotFound
9✔
437
        }
438
}
439

440
// FetchPendingInvoices returns all invoices that have not yet been settled or
441
// canceled. The returned map is keyed by the payment hash of each respective
442
// invoice.
443
func (d *DB) FetchPendingInvoices(_ context.Context) (
444
        map[lntypes.Hash]invpkg.Invoice, error) {
386✔
445

386✔
446
        result := make(map[lntypes.Hash]invpkg.Invoice)
386✔
447

386✔
448
        err := kvdb.View(d, func(tx kvdb.RTx) error {
772✔
449
                invoices := tx.ReadBucket(invoiceBucket)
386✔
450
                if invoices == nil {
386✔
451
                        return nil
×
452
                }
×
453

454
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
386✔
455
                if invoiceIndex == nil {
769✔
456
                        // Mask the error if there's no invoice
383✔
457
                        // index as that simply means there are no
383✔
458
                        // invoices added yet to the DB. In this case
383✔
459
                        // we simply return an empty list.
383✔
460
                        return nil
383✔
461
                }
383✔
462

463
                return invoiceIndex.ForEach(func(k, v []byte) error {
41✔
464
                        // Skip the special numInvoicesKey as that does not
38✔
465
                        // point to a valid invoice.
38✔
466
                        if bytes.Equal(k, numInvoicesKey) {
41✔
467
                                return nil
3✔
468
                        }
3✔
469

470
                        // Skip sub-buckets.
471
                        if v == nil {
35✔
472
                                return nil
×
473
                        }
×
474

475
                        invoice, err := fetchInvoice(v, invoices, nil, false)
35✔
476
                        if err != nil {
35✔
477
                                return err
×
478
                        }
×
479

480
                        if invoice.IsPending() {
55✔
481
                                var paymentHash lntypes.Hash
20✔
482
                                copy(paymentHash[:], k)
20✔
483
                                result[paymentHash] = invoice
20✔
484
                        }
20✔
485

486
                        return nil
35✔
487
                })
488
        }, func() {
386✔
489
                result = make(map[lntypes.Hash]invpkg.Invoice)
386✔
490
        })
386✔
491

492
        if err != nil {
386✔
493
                return nil, err
×
494
        }
×
495

496
        return result, nil
386✔
497
}
498

499
// QueryInvoices allows a caller to query the invoice database for invoices
500
// within the specified add index range.
501
func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
502
        invpkg.InvoiceSlice, error) {
43✔
503

43✔
504
        var resp invpkg.InvoiceSlice
43✔
505

43✔
506
        err := kvdb.View(d, func(tx kvdb.RTx) error {
86✔
507
                // If the bucket wasn't found, then there aren't any invoices
43✔
508
                // within the database yet, so we can simply exit.
43✔
509
                invoices := tx.ReadBucket(invoiceBucket)
43✔
510
                if invoices == nil {
43✔
511
                        return invpkg.ErrNoInvoicesCreated
×
512
                }
×
513

514
                // Get the add index bucket which we will use to iterate through
515
                // our indexed invoices.
516
                invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
43✔
517
                if invoiceAddIndex == nil {
43✔
UNCOV
518
                        return invpkg.ErrNoInvoicesCreated
×
UNCOV
519
                }
×
520

521
                // Create a paginator which reads from our add index bucket with
522
                // the parameters provided by the invoice query.
523
                paginator := newPaginator(
43✔
524
                        invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
43✔
525
                        q.NumMaxInvoices,
43✔
526
                )
43✔
527

43✔
528
                // accumulateInvoices looks up an invoice based on the index we
43✔
529
                // are given, adds it to our set of invoices if it has the right
43✔
530
                // characteristics for our query and returns the number of items
43✔
531
                // we have added to our set of invoices.
43✔
532
                accumulateInvoices := func(_, indexValue []byte) (bool, error) {
982✔
533
                        invoice, err := fetchInvoice(
939✔
534
                                indexValue, invoices, nil, false,
939✔
535
                        )
939✔
536
                        if err != nil {
939✔
537
                                return false, err
×
538
                        }
×
539

540
                        // Skip any settled or canceled invoices if the caller
541
                        // is only interested in pending ones.
542
                        if q.PendingOnly && !invoice.IsPending() {
1,094✔
543
                                return false, nil
155✔
544
                        }
155✔
545

546
                        // Get the creation time in Unix seconds, this always
547
                        // rounds down the nanoseconds to full seconds.
548
                        createTime := invoice.CreationDate.Unix()
784✔
549

784✔
550
                        // Skip any invoices that were created before the
784✔
551
                        // specified time.
784✔
552
                        if createTime < q.CreationDateStart {
881✔
553
                                return false, nil
97✔
554
                        }
97✔
555

556
                        // Skip any invoices that were created after the
557
                        // specified time.
558
                        if q.CreationDateEnd != 0 &&
687✔
559
                                createTime > q.CreationDateEnd {
865✔
560

178✔
561
                                return false, nil
178✔
562
                        }
178✔
563

564
                        // At this point, we've exhausted the offset, so we'll
565
                        // begin collecting invoices found within the range.
566
                        resp.Invoices = append(resp.Invoices, invoice)
509✔
567

509✔
568
                        return true, nil
509✔
569
                }
570

571
                // Query our paginator using accumulateInvoices to build up a
572
                // set of invoices.
573
                if err := paginator.query(accumulateInvoices); err != nil {
43✔
574
                        return err
×
575
                }
×
576

577
                // If we iterated through the add index in reverse order, then
578
                // we'll need to reverse the slice of invoices to return them in
579
                // forward order.
580
                if q.Reversed {
57✔
581
                        numInvoices := len(resp.Invoices)
14✔
582
                        for i := 0; i < numInvoices/2; i++ {
83✔
583
                                reverse := numInvoices - i - 1
69✔
584
                                resp.Invoices[i], resp.Invoices[reverse] =
69✔
585
                                        resp.Invoices[reverse], resp.Invoices[i]
69✔
586
                        }
69✔
587
                }
588

589
                return nil
43✔
590
        }, func() {
43✔
591
                resp = invpkg.InvoiceSlice{
43✔
592
                        InvoiceQuery: q,
43✔
593
                }
43✔
594
        })
43✔
595
        if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) {
43✔
596
                return resp, err
×
597
        }
×
598

599
        // Finally, record the indexes of the first and last invoices returned
600
        // so that the caller can resume from this point later on.
601
        if len(resp.Invoices) > 0 {
81✔
602
                resp.FirstIndexOffset = resp.Invoices[0].AddIndex
38✔
603
                lastIdx := len(resp.Invoices) - 1
38✔
604
                resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex
38✔
605
        }
38✔
606

607
        return resp, nil
43✔
608
}
609

610
// UpdateInvoice attempts to update an invoice corresponding to the passed
611
// payment hash. If an invoice matching the passed payment hash doesn't exist
612
// within the database, then the action will fail with a "not found" error.
613
//
614
// The update is performed inside the same database transaction that fetches the
615
// invoice and is therefore atomic. The fields to update are controlled by the
616
// supplied callback.  When updating an invoice, the update itself happens
617
// in-memory on a copy of the invoice. Once it is written successfully to the
618
// database, the in-memory copy is returned to the caller.
619
func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
620
        setIDHint *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
621
        *invpkg.Invoice, error) {
1,491✔
622

1,491✔
623
        var updatedInvoice *invpkg.Invoice
1,491✔
624
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
2,982✔
625
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
1,491✔
626
                if err != nil {
1,491✔
627
                        return err
×
628
                }
×
629
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
1,491✔
630
                        invoiceIndexBucket,
1,491✔
631
                )
1,491✔
632
                if err != nil {
1,491✔
633
                        return err
×
634
                }
×
635
                settleIndex, err := invoices.CreateBucketIfNotExists(
1,491✔
636
                        settleIndexBucket,
1,491✔
637
                )
1,491✔
638
                if err != nil {
1,491✔
639
                        return err
×
640
                }
×
641
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
1,491✔
642
                setIDIndex := tx.ReadWriteBucket(setIDIndexBucket)
1,491✔
643

1,491✔
644
                // Retrieve the invoice number for this invoice using the
1,491✔
645
                // provided invoice reference.
1,491✔
646
                invoiceNum, err := fetchInvoiceNumByRef(
1,491✔
647
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
1,491✔
648
                )
1,491✔
649
                if err != nil {
1,495✔
650
                        return err
4✔
651
                }
4✔
652

653
                // If the set ID hint is non-nil, then we'll use that to filter
654
                // out the HTLCs for AMP invoice so we don't need to read them
655
                // all out to satisfy the invoice callback below. If it's nil,
656
                // then we pass in the zero set ID which means no HTLCs will be
657
                // read out.
658
                var invSetID invpkg.SetID
1,487✔
659

1,487✔
660
                if setIDHint != nil {
1,518✔
661
                        invSetID = *setIDHint
31✔
662
                }
31✔
663
                invoice, err := fetchInvoice(
1,487✔
664
                        invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false,
1,487✔
665
                )
1,487✔
666
                if err != nil {
1,487✔
667
                        return err
×
668
                }
×
669

670
                now := d.clock.Now()
1,487✔
671
                updater := &kvInvoiceUpdater{
1,487✔
672
                        db:                d,
1,487✔
673
                        invoicesBucket:    invoices,
1,487✔
674
                        settleIndexBucket: settleIndex,
1,487✔
675
                        setIDIndexBucket:  setIDIndex,
1,487✔
676
                        updateTime:        now,
1,487✔
677
                        invoiceNum:        invoiceNum,
1,487✔
678
                        invoice:           &invoice,
1,487✔
679
                        updatedAmpHtlcs:   make(ampHTLCsMap),
1,487✔
680
                        settledSetIDs:     make(map[invpkg.SetID]struct{}),
1,487✔
681
                }
1,487✔
682

1,487✔
683
                payHash := ref.PayHash()
1,487✔
684
                updatedInvoice, err = invpkg.UpdateInvoice(
1,487✔
685
                        payHash, updater.invoice, now, callback, updater,
1,487✔
686
                )
1,487✔
687
                if err != nil {
1,499✔
688
                        return err
12✔
689
                }
12✔
690

691
                // If this is an AMP update, then limit the returned AMP state
692
                // to only the requested set ID.
693
                if setIDHint != nil {
1,503✔
694
                        filterInvoiceAMPState(updatedInvoice, &invSetID)
28✔
695
                }
28✔
696

697
                return nil
1,475✔
698
        }, func() {
1,491✔
699
                updatedInvoice = nil
1,491✔
700
        })
1,491✔
701

702
        return updatedInvoice, err
1,491✔
703
}
704

705
// filterInvoiceAMPState filters the AMP state of the invoice to only include
706
// state for the specified set IDs.
707
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
52✔
708
        filteredAMPState := make(invpkg.AMPInvoiceState)
52✔
709

52✔
710
        for _, setID := range setIDs {
104✔
711
                if setID == nil {
70✔
712
                        return
18✔
713
                }
18✔
714

715
                ampState, ok := invoice.AMPState[*setID]
34✔
716
                if ok {
67✔
717
                        filteredAMPState[*setID] = ampState
33✔
718
                }
33✔
719
        }
720

721
        invoice.AMPState = filteredAMPState
34✔
722
}
723

724
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
725
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
726

727
// kvInvoiceUpdater is an implementation of the InvoiceUpdater interface that
728
// is used with the kv implementation of the invoice database. Note that this
729
// updater is not concurrency safe and synchronizaton is expected to be handled
730
// on the DB level.
731
type kvInvoiceUpdater struct {
732
        db                *DB
733
        invoicesBucket    kvdb.RwBucket
734
        settleIndexBucket kvdb.RwBucket
735
        setIDIndexBucket  kvdb.RwBucket
736

737
        // updateTime is the timestamp for the update.
738
        updateTime time.Time
739

740
        // invoiceNum is a legacy key similar to the add index that is used
741
        // only in the kv implementation.
742
        invoiceNum []byte
743

744
        // invoice is the invoice that we're updating. As a side effect of the
745
        // update this invoice will be mutated.
746
        invoice *invpkg.Invoice
747

748
        // updatedAmpHtlcs holds the set of AMP HTLCs that were added or
749
        // cancelled as part of this update.
750
        updatedAmpHtlcs ampHTLCsMap
751

752
        // settledSetIDs holds the set IDs that are settled with this update.
753
        settledSetIDs map[invpkg.SetID]struct{}
754
}
755

756
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
757
func (k *kvInvoiceUpdater) AddHtlc(_ models.CircuitKey,
758
        _ *invpkg.InvoiceHTLC) error {
924✔
759

924✔
760
        return nil
924✔
761
}
924✔
762

763
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
764
func (k *kvInvoiceUpdater) ResolveHtlc(_ models.CircuitKey, _ invpkg.HtlcState,
765
        _ time.Time) error {
918✔
766

918✔
767
        return nil
918✔
768
}
918✔
769

770
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
771
func (k *kvInvoiceUpdater) AddAmpHtlcPreimage(_ [32]byte, _ models.CircuitKey,
772
        _ lntypes.Preimage) error {
6✔
773

6✔
774
        return nil
6✔
775
}
6✔
776

777
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
778
func (k *kvInvoiceUpdater) UpdateInvoiceState(_ invpkg.ContractState,
779
        _ *lntypes.Preimage) error {
1,316✔
780

1,316✔
781
        return nil
1,316✔
782
}
1,316✔
783

784
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
785
func (k *kvInvoiceUpdater) UpdateInvoiceAmtPaid(_ lnwire.MilliSatoshi) error {
1,423✔
786
        return nil
1,423✔
787
}
1,423✔
788

789
// UpdateAmpState updates the state of the AMP invoice identified by the setID.
790
func (k *kvInvoiceUpdater) UpdateAmpState(setID [32]byte,
791
        state invpkg.InvoiceStateAMP, circuitKey models.CircuitKey) error {
33✔
792

33✔
793
        if _, ok := k.updatedAmpHtlcs[setID]; !ok {
61✔
794
                switch state.State {
28✔
795
                case invpkg.HtlcStateAccepted:
19✔
796
                        // If we're just now creating the HTLCs for this set
19✔
797
                        // then we'll also pull in the existing HTLCs that are
19✔
798
                        // part of this set, so we can write them all to disk
19✔
799
                        // together (same value)
19✔
800
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
19✔
801
                                &setID, invpkg.HtlcStateAccepted,
19✔
802
                        )
19✔
803

804
                case invpkg.HtlcStateCanceled:
3✔
805
                        // Only HTLCs in the accepted state, can be cancelled,
3✔
806
                        // but we also want to merge that with HTLCs that may be
3✔
807
                        // canceled as well since it can be cancelled one by
3✔
808
                        // one.
3✔
809
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
3✔
810
                                &setID, invpkg.HtlcStateAccepted,
3✔
811
                        )
3✔
812

3✔
813
                        cancelledHtlcs := k.invoice.HTLCSet(
3✔
814
                                &setID, invpkg.HtlcStateCanceled,
3✔
815
                        )
3✔
816
                        for htlcKey, htlc := range cancelledHtlcs {
7✔
817
                                k.updatedAmpHtlcs[setID][htlcKey] = htlc
4✔
818
                        }
4✔
819

820
                case invpkg.HtlcStateSettled:
6✔
821
                        k.updatedAmpHtlcs[setID] = make(
6✔
822
                                map[models.CircuitKey]*invpkg.InvoiceHTLC,
6✔
823
                        )
6✔
824
                }
825
        }
826

827
        if state.State == invpkg.HtlcStateSettled {
44✔
828
                // Add the set ID to the set that was settled in this invoice
11✔
829
                // update. We'll use this later to update the settle index.
11✔
830
                k.settledSetIDs[setID] = struct{}{}
11✔
831
        }
11✔
832

833
        k.updatedAmpHtlcs[setID][circuitKey] = k.invoice.Htlcs[circuitKey]
33✔
834

33✔
835
        return nil
33✔
836
}
837

838
// Finalize finalizes the update before it is written to the database.
839
func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error {
1,457✔
840
        switch updateType {
1,457✔
841
        case invpkg.AddHTLCsUpdate:
929✔
842
                return k.storeAddHtlcsUpdate()
929✔
843

844
        case invpkg.CancelHTLCsUpdate:
6✔
845
                return k.storeCancelHtlcsUpdate()
6✔
846

847
        case invpkg.SettleHodlInvoiceUpdate:
491✔
848
                return k.storeSettleHodlInvoiceUpdate()
491✔
849

850
        case invpkg.CancelInvoiceUpdate:
31✔
851
                return k.serializeAndStoreInvoice()
31✔
852
        }
853

854
        return fmt.Errorf("unknown update type: %v", updateType)
×
855
}
856

857
// storeCancelHtlcsUpdate updates the invoice in the database after cancelling a
858
// set of HTLCs.
859
func (k *kvInvoiceUpdater) storeCancelHtlcsUpdate() error {
6✔
860
        err := k.serializeAndStoreInvoice()
6✔
861
        if err != nil {
6✔
862
                return err
×
863
        }
×
864

865
        // If this is an AMP invoice, then we'll actually store the rest
866
        // of the HTLCs in-line with the invoice, using the invoice ID
867
        // as a prefix, and the AMP key as a suffix: invoiceNum ||
868
        // setID.
869
        if k.invoice.IsAMP() {
9✔
870
                return k.updateAMPInvoices()
3✔
871
        }
3✔
872

873
        return nil
3✔
874
}
875

876
// storeAddHtlcsUpdate updates the invoice in the database after adding a set of
877
// HTLCs.
878
func (k *kvInvoiceUpdater) storeAddHtlcsUpdate() error {
929✔
879
        invoiceIsAMP := k.invoice.IsAMP()
929✔
880

929✔
881
        for htlcSetID := range k.updatedAmpHtlcs {
954✔
882
                // Check if this SetID already exist.
25✔
883
                setIDInvNum := k.setIDIndexBucket.Get(htlcSetID[:])
25✔
884

25✔
885
                if setIDInvNum == nil {
38✔
886
                        err := k.setIDIndexBucket.Put(
13✔
887
                                htlcSetID[:], k.invoiceNum,
13✔
888
                        )
13✔
889
                        if err != nil {
13✔
890
                                return err
×
891
                        }
×
892
                } else if !bytes.Equal(setIDInvNum, k.invoiceNum) {
13✔
893
                        return invpkg.ErrDuplicateSetID{
1✔
894
                                SetID: htlcSetID,
1✔
895
                        }
1✔
896
                }
1✔
897
        }
898

899
        // If this is a non-AMP invoice, then the state can eventually go to
900
        // ContractSettled, so we pass in nil value as part of
901
        // setSettleMetaFields.
902
        if !invoiceIsAMP && k.invoice.State == invpkg.ContractSettled {
1,226✔
903
                err := k.setSettleMetaFields(nil)
298✔
904
                if err != nil {
298✔
905
                        return err
×
906
                }
×
907
        }
908

909
        // As we don't update the settle index above for AMP invoices, we'll do
910
        // it here for each sub-AMP invoice that was settled.
911
        for settledSetID := range k.settledSetIDs {
936✔
912
                settledSetID := settledSetID
8✔
913
                err := k.setSettleMetaFields(&settledSetID)
8✔
914
                if err != nil {
8✔
915
                        return err
×
916
                }
×
917
        }
918

919
        err := k.serializeAndStoreInvoice()
928✔
920
        if err != nil {
928✔
921
                return err
×
922
        }
×
923

924
        // If this is an AMP invoice, then we'll actually store the rest of the
925
        // HTLCs in-line with the invoice, using the invoice ID as a prefix,
926
        // and the AMP key as a suffix: invoiceNum || setID.
927
        if invoiceIsAMP {
952✔
928
                return k.updateAMPInvoices()
24✔
929
        }
24✔
930

931
        return nil
904✔
932
}
933

934
// storeSettleHodlInvoiceUpdate updates the invoice in the database after
935
// settling a hodl invoice.
936
func (k *kvInvoiceUpdater) storeSettleHodlInvoiceUpdate() error {
491✔
937
        err := k.setSettleMetaFields(nil)
491✔
938
        if err != nil {
491✔
939
                return err
×
940
        }
×
941

942
        return k.serializeAndStoreInvoice()
491✔
943
}
944

945
// setSettleMetaFields updates the metadata associated with settlement of an
946
// invoice. If a non-nil setID is passed in, then the value will be append to
947
// the invoice number as well, in order to allow us to detect repeated payments
948
// to the same AMP invoices "across time".
949
func (k *kvInvoiceUpdater) setSettleMetaFields(setID *invpkg.SetID) error {
797✔
950
        // Now that we know the invoice hasn't already been settled, we'll
797✔
951
        // update the settle index so we can place this settle event in the
797✔
952
        // proper location within our time series.
797✔
953
        nextSettleSeqNo, err := k.settleIndexBucket.NextSequence()
797✔
954
        if err != nil {
797✔
955
                return err
×
956
        }
×
957

958
        // Make a new byte array on the stack that can potentially store the 4
959
        // byte invoice number along w/ the 32 byte set ID. We capture valueLen
960
        // here which is the number of bytes copied so we can only store the 4
961
        // bytes if this is a non-AMP invoice.
962
        var indexKey [invoiceSetIDKeyLen]byte
797✔
963
        valueLen := copy(indexKey[:], k.invoiceNum)
797✔
964

797✔
965
        if setID != nil {
805✔
966
                valueLen += copy(indexKey[valueLen:], setID[:])
8✔
967
        }
8✔
968

969
        var seqNoBytes [8]byte
797✔
970
        byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
797✔
971
        err = k.settleIndexBucket.Put(seqNoBytes[:], indexKey[:valueLen])
797✔
972
        if err != nil {
797✔
973
                return err
×
974
        }
×
975

976
        // If the setID is nil, then this means that this is a non-AMP settle,
977
        // so we'll update the invoice settle index directly.
978
        if setID == nil {
1,586✔
979
                k.invoice.SettleDate = k.updateTime
789✔
980
                k.invoice.SettleIndex = nextSettleSeqNo
789✔
981
        } else {
797✔
982
                // If the set ID isn't blank, we'll update the AMP state map
8✔
983
                // which tracks when each of the setIDs associated with a given
8✔
984
                // AMP invoice are settled.
8✔
985
                ampState := k.invoice.AMPState[*setID]
8✔
986

8✔
987
                ampState.SettleDate = k.updateTime
8✔
988
                ampState.SettleIndex = nextSettleSeqNo
8✔
989

8✔
990
                k.invoice.AMPState[*setID] = ampState
8✔
991
        }
8✔
992

993
        return nil
797✔
994
}
995

996
// updateAMPInvoices updates the set of AMP invoices in-place. For AMP, rather
997
// then continually write the invoices to the end of the invoice value, we
998
// instead write the invoices into a new key preifx that follows the main
999
// invoice number. This ensures that we don't need to continually decode a
1000
// potentially massive HTLC set, and also allows us to quickly find the HLTCs
1001
// associated with a particular HTLC set.
1002
func (k *kvInvoiceUpdater) updateAMPInvoices() error {
27✔
1003
        for setID, htlcSet := range k.updatedAmpHtlcs {
54✔
1004
                // First write out the set of HTLCs including all the relevant
27✔
1005
                // TLV values.
27✔
1006
                var b bytes.Buffer
27✔
1007
                if err := serializeHtlcs(&b, htlcSet); err != nil {
27✔
1008
                        return err
×
1009
                }
×
1010

1011
                // Next store each HTLC in-line, using a prefix based off the
1012
                // invoice number.
1013
                invoiceSetIDKey := makeInvoiceSetIDKey(k.invoiceNum, setID[:])
27✔
1014

27✔
1015
                err := k.invoicesBucket.Put(invoiceSetIDKey[:], b.Bytes())
27✔
1016
                if err != nil {
27✔
1017
                        return err
×
1018
                }
×
1019
        }
1020

1021
        return nil
27✔
1022
}
1023

1024
// serializeAndStoreInvoice is a helper function used to store invoices.
1025
func (k *kvInvoiceUpdater) serializeAndStoreInvoice() error {
1,456✔
1026
        var buf bytes.Buffer
1,456✔
1027
        if err := serializeInvoice(&buf, k.invoice); err != nil {
1,456✔
1028
                return err
×
1029
        }
×
1030

1031
        return k.invoicesBucket.Put(k.invoiceNum, buf.Bytes())
1,456✔
1032
}
1033

1034
// InvoicesSettledSince can be used by callers to catch up any settled invoices
1035
// they missed within the settled invoice time series. We'll return all known
1036
// settled invoice that have a settle index higher than the passed
1037
// sinceSettleIndex.
1038
//
1039
// NOTE: The index starts from 1, as a result. We enforce that specifying a
1040
// value below the starting index value is a noop.
1041
func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
1042
        []invpkg.Invoice, error) {
21✔
1043

21✔
1044
        var settledInvoices []invpkg.Invoice
21✔
1045

21✔
1046
        // If an index of zero was specified, then in order to maintain
21✔
1047
        // backwards compat, we won't send out any new invoices.
21✔
1048
        if sinceSettleIndex == 0 {
39✔
1049
                return settledInvoices, nil
18✔
1050
        }
18✔
1051

1052
        var startIndex [8]byte
3✔
1053
        byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
3✔
1054

3✔
1055
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
1056
                invoices := tx.ReadBucket(invoiceBucket)
3✔
1057
                if invoices == nil {
3✔
1058
                        return nil
×
1059
                }
×
1060

1061
                settleIndex := invoices.NestedReadBucket(settleIndexBucket)
3✔
1062
                if settleIndex == nil {
3✔
1063
                        return nil
×
1064
                }
×
1065

1066
                // We'll now run through each entry in the add index starting
1067
                // at our starting index. We'll continue until we reach the
1068
                // very end of the current key space.
1069
                invoiceCursor := settleIndex.ReadCursor()
3✔
1070

3✔
1071
                // We'll seek to the starting index, then manually advance the
3✔
1072
                // cursor in order to skip the entry with the since add index.
3✔
1073
                invoiceCursor.Seek(startIndex[:])
3✔
1074
                seqNo, indexValue := invoiceCursor.Next()
3✔
1075

3✔
1076
                for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, indexValue = invoiceCursor.Next() {
14✔
1077
                        // Depending on the length of the index value, this may
11✔
1078
                        // or may not be an AMP invoice, so we'll extract the
11✔
1079
                        // invoice value into two components: the invoice num,
11✔
1080
                        // and the setID (may not be there).
11✔
1081
                        var (
11✔
1082
                                invoiceKey [4]byte
11✔
1083
                                setID      *invpkg.SetID
11✔
1084
                        )
11✔
1085

11✔
1086
                        valueLen := copy(invoiceKey[:], indexValue)
11✔
1087
                        if len(indexValue) == invoiceSetIDKeyLen {
13✔
1088
                                setID = new(invpkg.SetID)
2✔
1089
                                copy(setID[:], indexValue[valueLen:])
2✔
1090
                        }
2✔
1091

1092
                        // For each key found, we'll look up the actual
1093
                        // invoice, then accumulate it into our return value.
1094
                        invoice, err := fetchInvoice(
11✔
1095
                                invoiceKey[:], invoices, []*invpkg.SetID{setID},
11✔
1096
                                true,
11✔
1097
                        )
11✔
1098
                        if err != nil {
11✔
1099
                                return err
×
1100
                        }
×
1101

1102
                        settledInvoices = append(settledInvoices, invoice)
11✔
1103
                }
1104

1105
                return nil
3✔
1106
        }, func() {
3✔
1107
                settledInvoices = nil
3✔
1108
        })
3✔
1109
        if err != nil {
3✔
1110
                return nil, err
×
1111
        }
×
1112

1113
        return settledInvoices, nil
3✔
1114
}
1115

1116
func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
1117
        i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
1118
        uint64, error) {
1,045✔
1119

1,045✔
1120
        // Create the invoice key which is just the big-endian representation
1,045✔
1121
        // of the invoice number.
1,045✔
1122
        var invoiceKey [4]byte
1,045✔
1123
        byteOrder.PutUint32(invoiceKey[:], invoiceNum)
1,045✔
1124

1,045✔
1125
        // Increment the num invoice counter index so the next invoice bares
1,045✔
1126
        // the proper ID.
1,045✔
1127
        var scratch [4]byte
1,045✔
1128
        invoiceCounter := invoiceNum + 1
1,045✔
1129
        byteOrder.PutUint32(scratch[:], invoiceCounter)
1,045✔
1130
        if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
1,045✔
1131
                return 0, err
×
1132
        }
×
1133

1134
        // Add the payment hash to the invoice index. This will let us quickly
1135
        // identify if we can settle an incoming payment, and also to possibly
1136
        // allow a single invoice to have multiple payment installations.
1137
        err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
1,045✔
1138
        if err != nil {
1,045✔
1139
                return 0, err
×
1140
        }
×
1141

1142
        // Add the invoice to the payment address index, but only if the invoice
1143
        // has a non-zero payment address. The all-zero payment address is still
1144
        // in use by legacy keysend, so we special-case here to avoid
1145
        // collisions.
1146
        if i.Terms.PaymentAddr != invpkg.BlankPayAddr {
1,969✔
1147
                err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
924✔
1148
                if err != nil {
924✔
1149
                        return 0, err
×
1150
                }
×
1151
        }
1152

1153
        // Next, we'll obtain the next add invoice index (sequence
1154
        // number), so we can properly place this invoice within this
1155
        // event stream.
1156
        nextAddSeqNo, err := addIndex.NextSequence()
1,045✔
1157
        if err != nil {
1,045✔
1158
                return 0, err
×
1159
        }
×
1160

1161
        // With the next sequence obtained, we'll updating the event series in
1162
        // the add index bucket to map this current add counter to the index of
1163
        // this new invoice.
1164
        var seqNoBytes [8]byte
1,045✔
1165
        byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
1,045✔
1166
        if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
1,045✔
1167
                return 0, err
×
1168
        }
×
1169

1170
        i.AddIndex = nextAddSeqNo
1,045✔
1171

1,045✔
1172
        // Finally, serialize the invoice itself to be written to the disk.
1,045✔
1173
        var buf bytes.Buffer
1,045✔
1174
        if err := serializeInvoice(&buf, i); err != nil {
1,045✔
1175
                return 0, err
×
1176
        }
×
1177

1178
        if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
1,045✔
1179
                return 0, err
×
1180
        }
×
1181

1182
        return nextAddSeqNo, nil
1,045✔
1183
}
1184

1185
// recordSize returns the amount of bytes this TLV record will occupy when
1186
// encoded.
1187
func ampRecordSize(a *invpkg.AMPInvoiceState) func() uint64 {
2,502✔
1188
        var (
2,502✔
1189
                b   bytes.Buffer
2,502✔
1190
                buf [8]byte
2,502✔
1191
        )
2,502✔
1192

2,502✔
1193
        // We know that encoding works since the tests pass in the build this
2,502✔
1194
        // file is checked into, so we'll simplify things and simply encode it
2,502✔
1195
        // ourselves then report the total amount of bytes used.
2,502✔
1196
        if err := ampStateEncoder(&b, a, &buf); err != nil {
2,502✔
1197
                // This should never error out, but we log it just in case it
×
1198
                // does.
×
1199
                log.Errorf("encoding the amp invoice state failed: %v", err)
×
1200
        }
×
1201

1202
        return func() uint64 {
5,004✔
1203
                return uint64(len(b.Bytes()))
2,502✔
1204
        }
2,502✔
1205
}
1206

1207
// serializeInvoice serializes an invoice to a writer.
1208
//
1209
// Note: this function is in use for a migration. Before making changes that
1210
// would modify the on disk format, make a copy of the original code and store
1211
// it with the migration.
1212
func serializeInvoice(w io.Writer, i *invpkg.Invoice) error {
2,501✔
1213
        creationDateBytes, err := i.CreationDate.MarshalBinary()
2,501✔
1214
        if err != nil {
2,501✔
1215
                return err
×
1216
        }
×
1217

1218
        settleDateBytes, err := i.SettleDate.MarshalBinary()
2,501✔
1219
        if err != nil {
2,501✔
1220
                return err
×
1221
        }
×
1222

1223
        var fb bytes.Buffer
2,501✔
1224
        err = i.Terms.Features.EncodeBase256(&fb)
2,501✔
1225
        if err != nil {
2,501✔
1226
                return err
×
1227
        }
×
1228
        featureBytes := fb.Bytes()
2,501✔
1229

2,501✔
1230
        preimage := [32]byte(invpkg.UnknownPreimage)
2,501✔
1231
        if i.Terms.PaymentPreimage != nil {
4,002✔
1232
                preimage = *i.Terms.PaymentPreimage
1,501✔
1233
                if preimage == invpkg.UnknownPreimage {
1,501✔
1234
                        return errors.New("cannot use all-zeroes preimage")
×
1235
                }
×
1236
        }
1237
        value := uint64(i.Terms.Value)
2,501✔
1238
        cltvDelta := uint32(i.Terms.FinalCltvDelta)
2,501✔
1239
        expiry := uint64(i.Terms.Expiry)
2,501✔
1240

2,501✔
1241
        amtPaid := uint64(i.AmtPaid)
2,501✔
1242
        state := uint8(i.State)
2,501✔
1243

2,501✔
1244
        var hodlInvoice uint8
2,501✔
1245
        if i.HodlInvoice {
3,997✔
1246
                hodlInvoice = 1
1,496✔
1247
        }
1,496✔
1248

1249
        tlvStream, err := tlv.NewStream(
2,501✔
1250
                // Memo and payreq.
2,501✔
1251
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
2,501✔
1252
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
2,501✔
1253

2,501✔
1254
                // Add/settle metadata.
2,501✔
1255
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
2,501✔
1256
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
2,501✔
1257
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
2,501✔
1258
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
2,501✔
1259

2,501✔
1260
                // Terms.
2,501✔
1261
                tlv.MakePrimitiveRecord(preimageType, &preimage),
2,501✔
1262
                tlv.MakePrimitiveRecord(valueType, &value),
2,501✔
1263
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
2,501✔
1264
                tlv.MakePrimitiveRecord(expiryType, &expiry),
2,501✔
1265
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
2,501✔
1266
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
2,501✔
1267

2,501✔
1268
                // Invoice state.
2,501✔
1269
                tlv.MakePrimitiveRecord(invStateType, &state),
2,501✔
1270
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
2,501✔
1271

2,501✔
1272
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
2,501✔
1273

2,501✔
1274
                // Invoice AMP state.
2,501✔
1275
                tlv.MakeDynamicRecord(
2,501✔
1276
                        invoiceAmpStateType, &i.AMPState,
2,501✔
1277
                        ampRecordSize(&i.AMPState),
2,501✔
1278
                        ampStateEncoder, ampStateDecoder,
2,501✔
1279
                ),
2,501✔
1280
        )
2,501✔
1281
        if err != nil {
2,501✔
1282
                return err
×
1283
        }
×
1284

1285
        var b bytes.Buffer
2,501✔
1286
        if err = tlvStream.Encode(&b); err != nil {
2,501✔
1287
                return err
×
1288
        }
×
1289

1290
        err = binary.Write(w, byteOrder, uint64(b.Len()))
2,501✔
1291
        if err != nil {
2,501✔
1292
                return err
×
1293
        }
×
1294

1295
        if _, err = w.Write(b.Bytes()); err != nil {
2,501✔
1296
                return err
×
1297
        }
×
1298

1299
        // Only if this is a _non_ AMP invoice do we serialize the HTLCs
1300
        // in-line with the rest of the invoice.
1301
        if i.IsAMP() {
2,541✔
1302
                return nil
40✔
1303
        }
40✔
1304

1305
        return serializeHtlcs(w, i.Htlcs)
2,461✔
1306
}
1307

1308
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
1309
// a writer.
1310
func serializeHtlcs(w io.Writer,
1311
        htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error {
2,488✔
1312

2,488✔
1313
        for key, htlc := range htlcs {
4,041✔
1314
                // Encode the htlc in a tlv stream.
1,553✔
1315
                chanID := key.ChanID.ToUint64()
1,553✔
1316
                amt := uint64(htlc.Amt)
1,553✔
1317
                mppTotalAmt := uint64(htlc.MppTotalAmt)
1,553✔
1318
                acceptTime := putNanoTime(htlc.AcceptTime)
1,553✔
1319
                resolveTime := putNanoTime(htlc.ResolveTime)
1,553✔
1320
                state := uint8(htlc.State)
1,553✔
1321

1,553✔
1322
                var records []tlv.Record
1,553✔
1323
                records = append(records,
1,553✔
1324
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
1,553✔
1325
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
1,553✔
1326
                        tlv.MakePrimitiveRecord(amtType, &amt),
1,553✔
1327
                        tlv.MakePrimitiveRecord(
1,553✔
1328
                                acceptHeightType, &htlc.AcceptHeight,
1,553✔
1329
                        ),
1,553✔
1330
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
1,553✔
1331
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
1,553✔
1332
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
1,553✔
1333
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
1,553✔
1334
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
1,553✔
1335
                )
1,553✔
1336

1,553✔
1337
                if htlc.AMP != nil {
1,589✔
1338
                        setIDRecord := tlv.MakeDynamicRecord(
36✔
1339
                                htlcAMPType, &htlc.AMP.Record,
36✔
1340
                                htlc.AMP.Record.PayloadSize,
36✔
1341
                                record.AMPEncoder, record.AMPDecoder,
36✔
1342
                        )
36✔
1343
                        records = append(records, setIDRecord)
36✔
1344

36✔
1345
                        hash32 := [32]byte(htlc.AMP.Hash)
36✔
1346
                        hashRecord := tlv.MakePrimitiveRecord(
36✔
1347
                                htlcHashType, &hash32,
36✔
1348
                        )
36✔
1349
                        records = append(records, hashRecord)
36✔
1350

36✔
1351
                        if htlc.AMP.Preimage != nil {
63✔
1352
                                preimage32 := [32]byte(*htlc.AMP.Preimage)
27✔
1353
                                preimageRecord := tlv.MakePrimitiveRecord(
27✔
1354
                                        htlcPreimageType, &preimage32,
27✔
1355
                                )
27✔
1356
                                records = append(records, preimageRecord)
27✔
1357
                        }
27✔
1358
                }
1359

1360
                // Convert the custom records to tlv.Record types that are ready
1361
                // for serialization.
1362
                customRecords := tlv.MapToRecords(htlc.CustomRecords)
1,553✔
1363

1,553✔
1364
                // Append the custom records. Their ids are in the experimental
1,553✔
1365
                // range and sorted, so there is no need to sort again.
1,553✔
1366
                records = append(records, customRecords...)
1,553✔
1367

1,553✔
1368
                tlvStream, err := tlv.NewStream(records...)
1,553✔
1369
                if err != nil {
1,553✔
1370
                        return err
×
1371
                }
×
1372

1373
                var b bytes.Buffer
1,553✔
1374
                if err := tlvStream.Encode(&b); err != nil {
1,553✔
1375
                        return err
×
1376
                }
×
1377

1378
                // Write the length of the tlv stream followed by the stream
1379
                // bytes.
1380
                err = binary.Write(w, byteOrder, uint64(b.Len()))
1,553✔
1381
                if err != nil {
1,553✔
1382
                        return err
×
1383
                }
×
1384

1385
                if _, err := w.Write(b.Bytes()); err != nil {
1,553✔
1386
                        return err
×
1387
                }
×
1388
        }
1389

1390
        return nil
2,488✔
1391
}
1392

1393
// putNanoTime converts t into a Unix timestamp expressed in nanoseconds. The
// zero time.Time is encoded as 0, because calling UnixNano on it is undefined.
func putNanoTime(t time.Time) uint64 {
	var ns uint64
	if !t.IsZero() {
		ns = uint64(t.UnixNano())
	}

	return ns
}
1402

1403
// getNanoTime is the inverse of putNanoTime: it turns a Unix timestamp in
// nanoseconds into a time.Time. A value of zero yields the zero time.Time.
func getNanoTime(ns uint64) time.Time {
	var t time.Time
	if ns != 0 {
		t = time.Unix(0, int64(ns))
	}

	return t
}
1411

1412
// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices
1413
// identified by the setID value.
1414
func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1415
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1416
        error) {
37✔
1417

37✔
1418
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
37✔
1419
        for _, setID := range setIDs {
74✔
1420
                invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:])
37✔
1421

37✔
1422
                htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:])
37✔
1423
                if htlcSetBytes == nil {
51✔
1424
                        // A set ID was passed in, but we don't have this
14✔
1425
                        // stored yet, meaning that the setID is being added
14✔
1426
                        // for the first time.
14✔
1427
                        return htlcs, invpkg.ErrInvoiceNotFound
14✔
1428
                }
14✔
1429

1430
                htlcSetReader := bytes.NewReader(htlcSetBytes)
23✔
1431
                htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
23✔
1432
                if err != nil {
23✔
1433
                        return nil, err
×
1434
                }
×
1435

1436
                for key, htlc := range htlcsBySetID {
51✔
1437
                        htlcs[key] = htlc
28✔
1438
                }
28✔
1439
        }
1440

1441
        return htlcs, nil
23✔
1442
}
1443

1444
// forEachAMPInvoice is a helper function that attempts to iterate over each of
1445
// the HTLC sets (based on their set ID) for the given AMP invoice identified
1446
// by its invoiceNum. The callback closure is called for each key within the
1447
// prefix range.
1448
func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1449
        callback func(key, htlcSet []byte) error) error {
33✔
1450

33✔
1451
        invoiceCursor := invoiceBucket.ReadCursor()
33✔
1452

33✔
1453
        // Seek to the first key that includes the invoice data itself.
33✔
1454
        invoiceCursor.Seek(invoiceNum)
33✔
1455

33✔
1456
        // Advance to the very first key _after_ the invoice data, as this is
33✔
1457
        // where we'll encounter our first HTLC (if any are present).
33✔
1458
        cursorKey, htlcSet := invoiceCursor.Next()
33✔
1459

33✔
1460
        // If at this point, the cursor key doesn't match the invoice num
33✔
1461
        // prefix, then we know that this HTLC doesn't have any set ID HTLCs
33✔
1462
        // associated with it.
33✔
1463
        if !bytes.HasPrefix(cursorKey, invoiceNum) {
51✔
1464
                return nil
18✔
1465
        }
18✔
1466

1467
        // Otherwise continue to iterate until we no longer match the prefix,
1468
        // executing the call back at each step.
1469
        for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() {
41✔
1470
                err := callback(cursorKey, htlcSet)
26✔
1471
                if err != nil {
26✔
1472
                        return err
×
1473
                }
×
1474
        }
1475

1476
        return nil
15✔
1477
}
1478

1479
// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix  within the
1480
// AMP bucket to find all the individual HTLCs (by setID) associated with a
1481
// given invoice. If a list of set IDs are specified, then only HTLCs
1482
// associated with that setID will be retrieved.
1483
func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1484
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1485
        error) {
55✔
1486

55✔
1487
        // If a set of setIDs was specified, then we can skip the cursor and
55✔
1488
        // just read out exactly what we need.
55✔
1489
        if len(setIDs) != 0 && setIDs[0] != nil {
92✔
1490
                return fetchFilteredAmpInvoices(
37✔
1491
                        invoiceBucket, invoiceNum, setIDs...,
37✔
1492
                )
37✔
1493
        }
37✔
1494

1495
        // Otherwise, iterate over all the htlc sets that are prefixed beside
1496
        // this invoice in the main invoice bucket.
1497
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
18✔
1498
        err := forEachAMPInvoice(invoiceBucket, invoiceNum,
18✔
1499
                func(key, htlcSet []byte) error {
41✔
1500
                        htlcSetReader := bytes.NewReader(htlcSet)
23✔
1501
                        htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
23✔
1502
                        if err != nil {
23✔
1503
                                return err
×
1504
                        }
×
1505

1506
                        for key, htlc := range htlcsBySetID {
55✔
1507
                                htlcs[key] = htlc
32✔
1508
                        }
32✔
1509

1510
                        return nil
23✔
1511
                },
1512
        )
1513

1514
        if err != nil {
18✔
1515
                return nil, err
×
1516
        }
×
1517

1518
        return htlcs, nil
18✔
1519
}
1520

1521
// fetchInvoice attempts to read out the relevant state for the invoice as
1522
// specified by the invoice number. If the setID fields are set, then only the
1523
// HTLC information pertaining to those set IDs is returned.
1524
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1525
        setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
3,553✔
1526

3,553✔
1527
        invoiceBytes := invoices.Get(invoiceNum)
3,553✔
1528
        if invoiceBytes == nil {
3,553✔
1529
                return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound
×
1530
        }
×
1531

1532
        invoiceReader := bytes.NewReader(invoiceBytes)
3,553✔
1533

3,553✔
1534
        invoice, err := deserializeInvoice(invoiceReader)
3,553✔
1535
        if err != nil {
3,553✔
1536
                return invpkg.Invoice{}, err
×
1537
        }
×
1538

1539
        // If this is an AMP invoice we'll also attempt to read out the set of
1540
        // HTLCs that were paid to prior set IDs, if needed.
1541
        if !invoice.IsAMP() {
7,048✔
1542
                return invoice, nil
3,495✔
1543
        }
3,495✔
1544

1545
        if shouldFetchAMPHTLCs(invoice, setIDs) {
113✔
1546
                invoice.Htlcs, err = fetchAmpSubInvoices(
55✔
1547
                        invoices, invoiceNum, setIDs...,
55✔
1548
                )
55✔
1549
                // TODO(positiveblue): we should fail when we are not able to
55✔
1550
                // fetch all the HTLCs for an AMP invoice. Multiple tests in
55✔
1551
                // the invoice and channeldb package break if we return this
55✔
1552
                // error. We need to update them when we migrate this logic to
55✔
1553
                // the sql implementation.
55✔
1554
                if err != nil {
69✔
1555
                        log.Errorf("unable to fetch amp htlcs for inv "+
14✔
1556
                                "%v and setIDs %v: %w", invoiceNum, setIDs, err)
14✔
1557
                }
14✔
1558

1559
                if filterAMPState {
79✔
1560
                        filterInvoiceAMPState(&invoice, setIDs...)
24✔
1561
                }
24✔
1562
        }
1563

1564
        return invoice, nil
58✔
1565
}
1566

1567
// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that
1568
// were paid to the relevant set IDs.
1569
func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool {
58✔
1570
        // For AMP invoice that already have HTLCs populated (created before
58✔
1571
        // recurring invoices), then we don't need to read from the prefix
58✔
1572
        // keyed section of the bucket.
58✔
1573
        if len(invoice.Htlcs) != 0 {
58✔
1574
                return false
×
1575
        }
×
1576

1577
        // If the "zero" setID was specified, then this means that no HTLC data
1578
        // should be returned alongside of it.
1579
        if len(setIDs) != 0 && setIDs[0] != nil &&
58✔
1580
                *setIDs[0] == invpkg.BlankPayAddr {
61✔
1581

3✔
1582
                return false
3✔
1583
        }
3✔
1584

1585
        return true
55✔
1586
}
1587

1588
// fetchInvoiceStateAMP retrieves the state of all the relevant sub-invoice for
1589
// an AMP invoice. This methods only decode the relevant state vs the entire
1590
// invoice.
1591
func fetchInvoiceStateAMP(invoiceNum []byte,
1592
        invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) {
8✔
1593

8✔
1594
        // Fetch the raw invoice bytes.
8✔
1595
        invoiceBytes := invoices.Get(invoiceNum)
8✔
1596
        if invoiceBytes == nil {
8✔
1597
                return nil, invpkg.ErrInvoiceNotFound
×
1598
        }
×
1599

1600
        r := bytes.NewReader(invoiceBytes)
8✔
1601

8✔
1602
        var bodyLen int64
8✔
1603
        err := binary.Read(r, byteOrder, &bodyLen)
8✔
1604
        if err != nil {
8✔
1605
                return nil, err
×
1606
        }
×
1607

1608
        // Next, we'll make a new TLV stream that only attempts to decode the
1609
        // bytes we actually need.
1610
        ampState := make(invpkg.AMPInvoiceState)
8✔
1611
        tlvStream, err := tlv.NewStream(
8✔
1612
                // Invoice AMP state.
8✔
1613
                tlv.MakeDynamicRecord(
8✔
1614
                        invoiceAmpStateType, &ampState, nil,
8✔
1615
                        ampStateEncoder, ampStateDecoder,
8✔
1616
                ),
8✔
1617
        )
8✔
1618
        if err != nil {
8✔
1619
                return nil, err
×
1620
        }
×
1621

1622
        invoiceReader := io.LimitReader(r, bodyLen)
8✔
1623
        if err = tlvStream.Decode(invoiceReader); err != nil {
8✔
1624
                return nil, err
×
1625
        }
×
1626

1627
        return ampState, nil
8✔
1628
}
1629

1630
func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) {
3,553✔
1631
        var (
3,553✔
1632
                preimageBytes [32]byte
3,553✔
1633
                value         uint64
3,553✔
1634
                cltvDelta     uint32
3,553✔
1635
                expiry        uint64
3,553✔
1636
                amtPaid       uint64
3,553✔
1637
                state         uint8
3,553✔
1638
                hodlInvoice   uint8
3,553✔
1639

3,553✔
1640
                creationDateBytes []byte
3,553✔
1641
                settleDateBytes   []byte
3,553✔
1642
                featureBytes      []byte
3,553✔
1643
        )
3,553✔
1644

3,553✔
1645
        var i invpkg.Invoice
3,553✔
1646
        i.AMPState = make(invpkg.AMPInvoiceState)
3,553✔
1647
        tlvStream, err := tlv.NewStream(
3,553✔
1648
                // Memo and payreq.
3,553✔
1649
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
3,553✔
1650
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
3,553✔
1651

3,553✔
1652
                // Add/settle metadata.
3,553✔
1653
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
3,553✔
1654
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
3,553✔
1655
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
3,553✔
1656
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
3,553✔
1657

3,553✔
1658
                // Terms.
3,553✔
1659
                tlv.MakePrimitiveRecord(preimageType, &preimageBytes),
3,553✔
1660
                tlv.MakePrimitiveRecord(valueType, &value),
3,553✔
1661
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
3,553✔
1662
                tlv.MakePrimitiveRecord(expiryType, &expiry),
3,553✔
1663
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
3,553✔
1664
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
3,553✔
1665

3,553✔
1666
                // Invoice state.
3,553✔
1667
                tlv.MakePrimitiveRecord(invStateType, &state),
3,553✔
1668
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
3,553✔
1669

3,553✔
1670
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
3,553✔
1671

3,553✔
1672
                // Invoice AMP state.
3,553✔
1673
                tlv.MakeDynamicRecord(
3,553✔
1674
                        invoiceAmpStateType, &i.AMPState, nil,
3,553✔
1675
                        ampStateEncoder, ampStateDecoder,
3,553✔
1676
                ),
3,553✔
1677
        )
3,553✔
1678
        if err != nil {
3,553✔
1679
                return i, err
×
1680
        }
×
1681

1682
        var bodyLen int64
3,553✔
1683
        err = binary.Read(r, byteOrder, &bodyLen)
3,553✔
1684
        if err != nil {
3,553✔
1685
                return i, err
×
1686
        }
×
1687

1688
        lr := io.LimitReader(r, bodyLen)
3,553✔
1689
        if err = tlvStream.Decode(lr); err != nil {
3,553✔
1690
                return i, err
×
1691
        }
×
1692

1693
        preimage := lntypes.Preimage(preimageBytes)
3,553✔
1694
        if preimage != invpkg.UnknownPreimage {
5,600✔
1695
                i.Terms.PaymentPreimage = &preimage
2,047✔
1696
        }
2,047✔
1697

1698
        i.Terms.Value = lnwire.MilliSatoshi(value)
3,553✔
1699
        i.Terms.FinalCltvDelta = int32(cltvDelta)
3,553✔
1700
        i.Terms.Expiry = time.Duration(expiry)
3,553✔
1701
        i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
3,553✔
1702
        i.State = invpkg.ContractState(state)
3,553✔
1703

3,553✔
1704
        if hodlInvoice != 0 {
5,083✔
1705
                i.HodlInvoice = true
1,530✔
1706
        }
1,530✔
1707

1708
        err = i.CreationDate.UnmarshalBinary(creationDateBytes)
3,553✔
1709
        if err != nil {
3,553✔
1710
                return i, err
×
1711
        }
×
1712

1713
        err = i.SettleDate.UnmarshalBinary(settleDateBytes)
3,553✔
1714
        if err != nil {
3,553✔
1715
                return i, err
×
1716
        }
×
1717

1718
        rawFeatures := lnwire.NewRawFeatureVector()
3,553✔
1719
        err = rawFeatures.DecodeBase256(
3,553✔
1720
                bytes.NewReader(featureBytes), len(featureBytes),
3,553✔
1721
        )
3,553✔
1722
        if err != nil {
3,553✔
1723
                return i, err
×
1724
        }
×
1725

1726
        i.Terms.Features = lnwire.NewFeatureVector(
3,553✔
1727
                rawFeatures, lnwire.Features,
3,553✔
1728
        )
3,553✔
1729

3,553✔
1730
        i.Htlcs, err = deserializeHtlcs(r)
3,553✔
1731
        return i, err
3,553✔
1732
}
1733

1734
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
98✔
1735
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
196✔
1736
                // We encode the set of circuit keys as a varint length prefix.
98✔
1737
                // followed by a series of fixed sized uint8 integers.
98✔
1738
                numKeys := uint64(len(*v))
98✔
1739

98✔
1740
                if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
98✔
1741
                        return err
×
1742
                }
×
1743

1744
                for key := range *v {
226✔
1745
                        scidInt := key.ChanID.ToUint64()
128✔
1746

128✔
1747
                        if err := tlv.EUint64(w, &scidInt, buf); err != nil {
128✔
1748
                                return err
×
1749
                        }
×
1750
                        if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
128✔
1751
                                return err
×
1752
                        }
×
1753
                }
1754

1755
                return nil
98✔
1756
        }
1757

1758
        return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
×
1759
}
1760

1761
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte,
1762
        l uint64) error {
88✔
1763

88✔
1764
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
176✔
1765
                // First, we'll read out the varint that encodes the number of
88✔
1766
                // circuit keys encoded.
88✔
1767
                numKeys, err := tlv.ReadVarInt(r, buf)
88✔
1768
                if err != nil {
88✔
1769
                        return err
×
1770
                }
×
1771

1772
                // Now that we know how many keys to expect, iterate reading
1773
                // each one until we're done.
1774
                for i := uint64(0); i < numKeys; i++ {
197✔
1775
                        var (
109✔
1776
                                key  models.CircuitKey
109✔
1777
                                scid uint64
109✔
1778
                        )
109✔
1779

109✔
1780
                        if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
109✔
1781
                                return err
×
1782
                        }
×
1783

1784
                        key.ChanID = lnwire.NewShortChanIDFromInt(scid)
109✔
1785

109✔
1786
                        err := tlv.DUint64(r, &key.HtlcID, buf, 8)
109✔
1787
                        if err != nil {
109✔
1788
                                return err
×
1789
                        }
×
1790

1791
                        (*v)[key] = struct{}{}
109✔
1792
                }
1793

1794
                return nil
88✔
1795
        }
1796

1797
        return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l)
×
1798
}
1799

1800
// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record.
1801
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
5,004✔
1802
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
10,008✔
1803
                // We'll encode the AMP state as a series of KV pairs on the
5,004✔
1804
                // wire with a length prefix.
5,004✔
1805
                numRecords := uint64(len(*v))
5,004✔
1806

5,004✔
1807
                // First, we'll write out the number of records as a var int.
5,004✔
1808
                if err := tlv.WriteVarInt(w, numRecords, buf); err != nil {
5,004✔
1809
                        return err
×
1810
                }
×
1811

1812
                // With that written out, we'll now encode the entries
1813
                // themselves as a sub-TLV record, which includes its _own_
1814
                // inner length prefix.
1815
                for setID, ampState := range *v {
5,102✔
1816
                        setID := [32]byte(setID)
98✔
1817
                        ampState := ampState
98✔
1818

98✔
1819
                        htlcState := uint8(ampState.State)
98✔
1820
                        settleDate := ampState.SettleDate
98✔
1821
                        settleDateBytes, err := settleDate.MarshalBinary()
98✔
1822
                        if err != nil {
98✔
1823
                                return err
×
1824
                        }
×
1825

1826
                        amtPaid := uint64(ampState.AmtPaid)
98✔
1827

98✔
1828
                        var ampStateTlvBytes bytes.Buffer
98✔
1829
                        tlvStream, err := tlv.NewStream(
98✔
1830
                                tlv.MakePrimitiveRecord(
98✔
1831
                                        ampStateSetIDType, &setID,
98✔
1832
                                ),
98✔
1833
                                tlv.MakePrimitiveRecord(
98✔
1834
                                        ampStateHtlcStateType, &htlcState,
98✔
1835
                                ),
98✔
1836
                                tlv.MakePrimitiveRecord(
98✔
1837
                                        ampStateSettleIndexType,
98✔
1838
                                        &ampState.SettleIndex,
98✔
1839
                                ),
98✔
1840
                                tlv.MakePrimitiveRecord(
98✔
1841
                                        ampStateSettleDateType,
98✔
1842
                                        &settleDateBytes,
98✔
1843
                                ),
98✔
1844
                                tlv.MakeDynamicRecord(
98✔
1845
                                        ampStateCircuitKeysType,
98✔
1846
                                        &ampState.InvoiceKeys,
98✔
1847
                                        func() uint64 {
196✔
1848
                                                // The record takes 8 bytes to
98✔
1849
                                                // encode the set of circuits,
98✔
1850
                                                // 8 bytes for the scid for the
98✔
1851
                                                // key, and 8 bytes for the HTLC
98✔
1852
                                                // index.
98✔
1853
                                                keys := ampState.InvoiceKeys
98✔
1854
                                                numKeys := uint64(len(keys))
98✔
1855
                                                size := tlv.VarIntSize(numKeys)
98✔
1856
                                                dataSize := (numKeys * 16)
98✔
1857

98✔
1858
                                                return size + dataSize
98✔
1859
                                        },
98✔
1860
                                        encodeCircuitKeys, decodeCircuitKeys,
1861
                                ),
1862
                                tlv.MakePrimitiveRecord(
1863
                                        ampStateAmtPaidType, &amtPaid,
1864
                                ),
1865
                        )
1866
                        if err != nil {
98✔
1867
                                return err
×
1868
                        }
×
1869

1870
                        err = tlvStream.Encode(&ampStateTlvBytes)
98✔
1871
                        if err != nil {
98✔
1872
                                return err
×
1873
                        }
×
1874

1875
                        // We encode the record with a varint length followed by
1876
                        // the _raw_ TLV bytes.
1877
                        tlvLen := uint64(len(ampStateTlvBytes.Bytes()))
98✔
1878
                        if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
98✔
1879
                                return err
×
1880
                        }
×
1881

1882
                        _, err = w.Write(ampStateTlvBytes.Bytes())
98✔
1883
                        if err != nil {
98✔
1884
                                return err
×
1885
                        }
×
1886
                }
1887

1888
                return nil
5,004✔
1889
        }
1890

1891
        return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
×
1892
}
1893

1894
// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record.
1895
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte,
1896
        l uint64) error {
3,562✔
1897

3,562✔
1898
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
7,124✔
1899
                // First, we'll decode the varint that encodes how many set IDs
3,562✔
1900
                // are encoded within the greater map.
3,562✔
1901
                numRecords, err := tlv.ReadVarInt(r, buf)
3,562✔
1902
                if err != nil {
3,562✔
1903
                        return err
×
1904
                }
×
1905

1906
                // Now that we know how many records we'll need to read, we can
1907
                // iterate and read them all out in series.
1908
                for i := uint64(0); i < numRecords; i++ {
3,650✔
1909
                        // Read out the varint that encodes the size of this
88✔
1910
                        // inner TLV record.
88✔
1911
                        stateRecordSize, err := tlv.ReadVarInt(r, buf)
88✔
1912
                        if err != nil {
88✔
1913
                                return err
×
1914
                        }
×
1915

1916
                        // Using this information, we'll create a new limited
1917
                        // reader that'll return an EOF once the end has been
1918
                        // reached so the stream stops consuming bytes.
1919
                        innerTlvReader := io.LimitedReader{
88✔
1920
                                R: r,
88✔
1921
                                N: int64(stateRecordSize),
88✔
1922
                        }
88✔
1923

88✔
1924
                        var (
88✔
1925
                                setID           [32]byte
88✔
1926
                                htlcState       uint8
88✔
1927
                                settleIndex     uint64
88✔
1928
                                settleDateBytes []byte
88✔
1929
                                invoiceKeys     = make(
88✔
1930
                                        map[models.CircuitKey]struct{},
88✔
1931
                                )
88✔
1932
                                amtPaid uint64
88✔
1933
                        )
88✔
1934
                        tlvStream, err := tlv.NewStream(
88✔
1935
                                tlv.MakePrimitiveRecord(
88✔
1936
                                        ampStateSetIDType, &setID,
88✔
1937
                                ),
88✔
1938
                                tlv.MakePrimitiveRecord(
88✔
1939
                                        ampStateHtlcStateType, &htlcState,
88✔
1940
                                ),
88✔
1941
                                tlv.MakePrimitiveRecord(
88✔
1942
                                        ampStateSettleIndexType, &settleIndex,
88✔
1943
                                ),
88✔
1944
                                tlv.MakePrimitiveRecord(
88✔
1945
                                        ampStateSettleDateType,
88✔
1946
                                        &settleDateBytes,
88✔
1947
                                ),
88✔
1948
                                tlv.MakeDynamicRecord(
88✔
1949
                                        ampStateCircuitKeysType,
88✔
1950
                                        &invoiceKeys, nil,
88✔
1951
                                        encodeCircuitKeys, decodeCircuitKeys,
88✔
1952
                                ),
88✔
1953
                                tlv.MakePrimitiveRecord(
88✔
1954
                                        ampStateAmtPaidType, &amtPaid,
88✔
1955
                                ),
88✔
1956
                        )
88✔
1957
                        if err != nil {
88✔
1958
                                return err
×
1959
                        }
×
1960

1961
                        err = tlvStream.Decode(&innerTlvReader)
88✔
1962
                        if err != nil {
88✔
1963
                                return err
×
1964
                        }
×
1965

1966
                        var settleDate time.Time
88✔
1967
                        err = settleDate.UnmarshalBinary(settleDateBytes)
88✔
1968
                        if err != nil {
88✔
1969
                                return err
×
1970
                        }
×
1971

1972
                        (*v)[setID] = invpkg.InvoiceStateAMP{
88✔
1973
                                State:       invpkg.HtlcState(htlcState),
88✔
1974
                                SettleIndex: settleIndex,
88✔
1975
                                SettleDate:  settleDate,
88✔
1976
                                InvoiceKeys: invoiceKeys,
88✔
1977
                                AmtPaid:     lnwire.MilliSatoshi(amtPaid),
88✔
1978
                        }
88✔
1979
                }
1980

1981
                return nil
3,562✔
1982
        }
1983

1984
        return tlv.NewTypeForDecodingErr(
×
1985
                val, "channeldb.AMPInvoiceState", l, l,
×
1986
        )
×
1987
}
1988

1989
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
1990
// as a map.
1991
func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1992
        error) {
3,599✔
1993

3,599✔
1994
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3,599✔
1995
        for {
8,725✔
1996
                // Read the length of the tlv stream for this htlc.
5,126✔
1997
                var streamLen int64
5,126✔
1998
                if err := binary.Read(r, byteOrder, &streamLen); err != nil {
8,725✔
1999
                        if err == io.EOF {
7,198✔
2000
                                break
3,599✔
2001
                        }
2002

2003
                        return nil, err
×
2004
                }
2005

2006
                // Limit the reader so that it stops at the end of this htlc's
2007
                // stream.
2008
                htlcReader := io.LimitReader(r, streamLen)
1,527✔
2009

1,527✔
2010
                // Decode the contents into the htlc fields.
1,527✔
2011
                var (
1,527✔
2012
                        htlc                    invpkg.InvoiceHTLC
1,527✔
2013
                        key                     models.CircuitKey
1,527✔
2014
                        chanID                  uint64
1,527✔
2015
                        state                   uint8
1,527✔
2016
                        acceptTime, resolveTime uint64
1,527✔
2017
                        amt, mppTotalAmt        uint64
1,527✔
2018
                        amp                     = &record.AMP{}
1,527✔
2019
                        hash32                  = &[32]byte{}
1,527✔
2020
                        preimage32              = &[32]byte{}
1,527✔
2021
                )
1,527✔
2022
                tlvStream, err := tlv.NewStream(
1,527✔
2023
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
1,527✔
2024
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
1,527✔
2025
                        tlv.MakePrimitiveRecord(amtType, &amt),
1,527✔
2026
                        tlv.MakePrimitiveRecord(
1,527✔
2027
                                acceptHeightType, &htlc.AcceptHeight,
1,527✔
2028
                        ),
1,527✔
2029
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
1,527✔
2030
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
1,527✔
2031
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
1,527✔
2032
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
1,527✔
2033
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
1,527✔
2034
                        tlv.MakeDynamicRecord(
1,527✔
2035
                                htlcAMPType, amp, amp.PayloadSize,
1,527✔
2036
                                record.AMPEncoder, record.AMPDecoder,
1,527✔
2037
                        ),
1,527✔
2038
                        tlv.MakePrimitiveRecord(htlcHashType, hash32),
1,527✔
2039
                        tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
1,527✔
2040
                )
1,527✔
2041
                if err != nil {
1,527✔
2042
                        return nil, err
×
2043
                }
×
2044

2045
                parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
1,527✔
2046
                if err != nil {
1,527✔
2047
                        return nil, err
×
2048
                }
×
2049

2050
                if _, ok := parsedTypes[htlcAMPType]; !ok {
2,994✔
2051
                        amp = nil
1,467✔
2052
                }
1,467✔
2053

2054
                var preimage *lntypes.Preimage
1,527✔
2055
                if _, ok := parsedTypes[htlcPreimageType]; ok {
1,566✔
2056
                        pimg := lntypes.Preimage(*preimage32)
39✔
2057
                        preimage = &pimg
39✔
2058
                }
39✔
2059

2060
                var hash *lntypes.Hash
1,527✔
2061
                if _, ok := parsedTypes[htlcHashType]; ok {
1,587✔
2062
                        h := lntypes.Hash(*hash32)
60✔
2063
                        hash = &h
60✔
2064
                }
60✔
2065

2066
                key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
1,527✔
2067
                htlc.AcceptTime = getNanoTime(acceptTime)
1,527✔
2068
                htlc.ResolveTime = getNanoTime(resolveTime)
1,527✔
2069
                htlc.State = invpkg.HtlcState(state)
1,527✔
2070
                htlc.Amt = lnwire.MilliSatoshi(amt)
1,527✔
2071
                htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
1,527✔
2072
                if amp != nil && hash != nil {
1,587✔
2073
                        htlc.AMP = &invpkg.InvoiceHtlcAMPData{
60✔
2074
                                Record:   *amp,
60✔
2075
                                Hash:     *hash,
60✔
2076
                                Preimage: preimage,
60✔
2077
                        }
60✔
2078
                }
60✔
2079

2080
                // Reconstruct the custom records fields from the parsed types
2081
                // map return from the tlv parser.
2082
                htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
1,527✔
2083

1,527✔
2084
                htlcs[key] = &htlc
1,527✔
2085
        }
2086

2087
        return htlcs, nil
3,599✔
2088
}
2089

2090
// invoiceSetIDKeyLen is the length of the key under which the HTLCs of a
// single AMP set are stored alongside the main invoice: 4 bytes for the
// invoice number followed by 32 bytes for the set ID.
const invoiceSetIDKeyLen = 4 + 32

// makeInvoiceSetIDKey returns the prefix key invoiceNum || setID under which
// the HTLCs for the given set ID are stored. The set ID is written directly
// after the bytes of invoiceNum, and any bytes that are not covered by either
// input are left zero.
func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte {
	var key [invoiceSetIDKeyLen]byte

	copy(key[:], invoiceNum)
	setIDStart := len(invoiceNum)
	copy(key[setIDStart:], setID)

	return key
}
2107

2108
// delAMPInvoices attempts to delete all the "sub" invoices associated with a
2109
// greater AMP invoices. We do this by deleting the set of keys that share the
2110
// invoice number as a prefix.
2111
func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error {
15✔
2112
        // Since it isn't safe to delete using an active cursor, we'll use the
15✔
2113
        // cursor simply to collect the set of keys we need to delete, _then_
15✔
2114
        // delete them in another pass.
15✔
2115
        var keysToDel [][]byte
15✔
2116
        err := forEachAMPInvoice(
15✔
2117
                invoiceBucket, invoiceNum,
15✔
2118
                func(cursorKey, v []byte) error {
18✔
2119
                        keysToDel = append(keysToDel, cursorKey)
3✔
2120
                        return nil
3✔
2121
                },
3✔
2122
        )
2123
        if err != nil {
15✔
2124
                return err
×
2125
        }
×
2126

2127
        // In this next phase, we'll then delete all the relevant invoices.
2128
        for _, keyToDel := range keysToDel {
18✔
2129
                if err := invoiceBucket.Delete(keyToDel); err != nil {
3✔
2130
                        return err
×
2131
                }
×
2132
        }
2133

2134
        return nil
15✔
2135
}
2136

2137
// delAMPSettleIndex removes all the entries in the settle index associated
2138
// with a given AMP invoice.
2139
func delAMPSettleIndex(invoiceNum []byte, invoices,
2140
        settleIndex kvdb.RwBucket) error {
8✔
2141

8✔
2142
        // First, we need to grab the AMP invoice state to see if there's
8✔
2143
        // anything that we even need to delete.
8✔
2144
        ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices)
8✔
2145
        if err != nil {
8✔
2146
                return err
×
2147
        }
×
2148

2149
        // If there's no AMP state at all (non-AMP invoice), then we can return
2150
        // early.
2151
        if len(ampState) == 0 {
15✔
2152
                return nil
7✔
2153
        }
7✔
2154

2155
        // Otherwise, we'll need to iterate and delete each settle index within
2156
        // the set of returned entries.
2157
        var settleIndexKey [8]byte
1✔
2158
        for _, subState := range ampState {
4✔
2159
                byteOrder.PutUint64(
3✔
2160
                        settleIndexKey[:], subState.SettleIndex,
3✔
2161
                )
3✔
2162

3✔
2163
                if err := settleIndex.Delete(settleIndexKey[:]); err != nil {
3✔
2164
                        return err
×
2165
                }
×
2166
        }
2167

2168
        return nil
1✔
2169
}
2170

2171
// DeleteCanceledInvoices deletes all canceled invoices from the database.
2172
func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
3✔
2173
        return kvdb.Update(d, func(tx kvdb.RwTx) error {
6✔
2174
                invoices := tx.ReadWriteBucket(invoiceBucket)
3✔
2175
                if invoices == nil {
3✔
2176
                        return nil
×
2177
                }
×
2178

2179
                invoiceIndex := invoices.NestedReadWriteBucket(
3✔
2180
                        invoiceIndexBucket,
3✔
2181
                )
3✔
2182
                if invoiceIndex == nil {
4✔
2183
                        return nil
1✔
2184
                }
1✔
2185

2186
                invoiceAddIndex := invoices.NestedReadWriteBucket(
2✔
2187
                        addIndexBucket,
2✔
2188
                )
2✔
2189
                if invoiceAddIndex == nil {
2✔
2190
                        return nil
×
2191
                }
×
2192

2193
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
2✔
2194

2✔
2195
                return invoiceIndex.ForEach(func(k, v []byte) error {
19✔
2196
                        // Skip the special numInvoicesKey as that does not
17✔
2197
                        // point to a valid invoice.
17✔
2198
                        if bytes.Equal(k, numInvoicesKey) {
19✔
2199
                                return nil
2✔
2200
                        }
2✔
2201

2202
                        // Skip sub-buckets.
2203
                        if v == nil {
15✔
2204
                                return nil
×
2205
                        }
×
2206

2207
                        invoice, err := fetchInvoice(v, invoices, nil, false)
15✔
2208
                        if err != nil {
15✔
2209
                                return err
×
2210
                        }
×
2211

2212
                        if invoice.State != invpkg.ContractCanceled {
23✔
2213
                                return nil
8✔
2214
                        }
8✔
2215

2216
                        // Delete the payment hash from the invoice index.
2217
                        err = invoiceIndex.Delete(k)
7✔
2218
                        if err != nil {
7✔
2219
                                return err
×
2220
                        }
×
2221

2222
                        // Delete payment address index reference if there's a
2223
                        // valid payment address.
2224
                        if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
14✔
2225
                                // To ensure consistency check that the already
7✔
2226
                                // fetched invoice key matches the one in the
7✔
2227
                                // payment address index.
7✔
2228
                                key := payAddrIndex.Get(
7✔
2229
                                        invoice.Terms.PaymentAddr[:],
7✔
2230
                                )
7✔
2231
                                if bytes.Equal(key, k) {
7✔
2232
                                        // Delete from the payment address
×
2233
                                        // index.
×
2234
                                        if err := payAddrIndex.Delete(
×
2235
                                                invoice.Terms.PaymentAddr[:],
×
2236
                                        ); err != nil {
×
2237
                                                return err
×
2238
                                        }
×
2239
                                }
2240
                        }
2241

2242
                        // Remove from the add index.
2243
                        var addIndexKey [8]byte
7✔
2244
                        byteOrder.PutUint64(addIndexKey[:], invoice.AddIndex)
7✔
2245
                        err = invoiceAddIndex.Delete(addIndexKey[:])
7✔
2246
                        if err != nil {
7✔
2247
                                return err
×
2248
                        }
×
2249

2250
                        // Note that we don't need to delete the invoice from
2251
                        // the settle index as it is not added until the
2252
                        // invoice is settled.
2253

2254
                        // Now remove all sub invoices.
2255
                        err = delAMPInvoices(k, invoices)
7✔
2256
                        if err != nil {
7✔
2257
                                return err
×
2258
                        }
×
2259

2260
                        // Finally remove the serialized invoice from the
2261
                        // invoice bucket.
2262
                        return invoices.Delete(k)
7✔
2263
                })
2264
        }, func() {})
3✔
2265
}
2266

2267
// DeleteInvoice attempts to delete the passed invoices from the database in
2268
// one transaction. The passed delete references hold all keys required to
2269
// delete the invoices without also needing to deserialize them.
2270
func (d *DB) DeleteInvoice(_ context.Context,
2271
        invoicesToDelete []invpkg.InvoiceDeleteRef) error {
6✔
2272

6✔
2273
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
12✔
2274
                invoices := tx.ReadWriteBucket(invoiceBucket)
6✔
2275
                if invoices == nil {
6✔
2276
                        return invpkg.ErrNoInvoicesCreated
×
2277
                }
×
2278

2279
                invoiceIndex := invoices.NestedReadWriteBucket(
6✔
2280
                        invoiceIndexBucket,
6✔
2281
                )
6✔
2282
                if invoiceIndex == nil {
6✔
2283
                        return invpkg.ErrNoInvoicesCreated
×
2284
                }
×
2285

2286
                invoiceAddIndex := invoices.NestedReadWriteBucket(
6✔
2287
                        addIndexBucket,
6✔
2288
                )
6✔
2289
                if invoiceAddIndex == nil {
6✔
2290
                        return invpkg.ErrNoInvoicesCreated
×
2291
                }
×
2292

2293
                // settleIndex can be nil, as the bucket is created lazily
2294
                // when the first invoice is settled.
2295
                settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket)
6✔
2296

6✔
2297
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
6✔
2298

6✔
2299
                for _, ref := range invoicesToDelete {
17✔
2300
                        // Fetch the invoice key for using it to check for
11✔
2301
                        // consistency and also to delete from the invoice
11✔
2302
                        // index.
11✔
2303
                        invoiceKey := invoiceIndex.Get(ref.PayHash[:])
11✔
2304
                        if invoiceKey == nil {
12✔
2305
                                return invpkg.ErrInvoiceNotFound
1✔
2306
                        }
1✔
2307

2308
                        err := invoiceIndex.Delete(ref.PayHash[:])
10✔
2309
                        if err != nil {
10✔
2310
                                return err
×
2311
                        }
×
2312

2313
                        // Delete payment address index reference if there's a
2314
                        // valid payment address passed.
2315
                        if ref.PayAddr != nil {
19✔
2316
                                // To ensure consistency check that the already
9✔
2317
                                // fetched invoice key matches the one in the
9✔
2318
                                // payment address index.
9✔
2319
                                key := payAddrIndex.Get(ref.PayAddr[:])
9✔
2320
                                if bytes.Equal(key, invoiceKey) {
18✔
2321
                                        // Delete from the payment address
9✔
2322
                                        // index. Note that since the payment
9✔
2323
                                        // address index has been introduced
9✔
2324
                                        // with an empty migration it may be
9✔
2325
                                        // possible that the index doesn't have
9✔
2326
                                        // an entry for this invoice.
9✔
2327
                                        // ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127
9✔
2328
                                        if err := payAddrIndex.Delete(
9✔
2329
                                                ref.PayAddr[:],
9✔
2330
                                        ); err != nil {
9✔
2331
                                                return err
×
2332
                                        }
×
2333
                                }
2334
                        }
2335

2336
                        var addIndexKey [8]byte
10✔
2337
                        byteOrder.PutUint64(addIndexKey[:], ref.AddIndex)
10✔
2338

10✔
2339
                        // To ensure consistency check that the key stored in
10✔
2340
                        // the add index also matches the previously fetched
10✔
2341
                        // invoice key.
10✔
2342
                        key := invoiceAddIndex.Get(addIndexKey[:])
10✔
2343
                        if !bytes.Equal(key, invoiceKey) {
11✔
2344
                                return fmt.Errorf("unknown invoice " +
1✔
2345
                                        "in add index")
1✔
2346
                        }
1✔
2347

2348
                        // Remove from the add index.
2349
                        err = invoiceAddIndex.Delete(addIndexKey[:])
9✔
2350
                        if err != nil {
9✔
2351
                                return err
×
2352
                        }
×
2353

2354
                        // Remove from the settle index if available and
2355
                        // if the invoice is settled.
2356
                        if settleIndex != nil && ref.SettleIndex > 0 {
12✔
2357
                                var settleIndexKey [8]byte
3✔
2358
                                byteOrder.PutUint64(
3✔
2359
                                        settleIndexKey[:], ref.SettleIndex,
3✔
2360
                                )
3✔
2361

3✔
2362
                                // To ensure consistency check that the already
3✔
2363
                                // fetched invoice key matches the one in the
3✔
2364
                                // settle index
3✔
2365
                                key := settleIndex.Get(settleIndexKey[:])
3✔
2366
                                if !bytes.Equal(key, invoiceKey) {
4✔
2367
                                        return fmt.Errorf("unknown invoice " +
1✔
2368
                                                "in settle index")
1✔
2369
                                }
1✔
2370

2371
                                err = settleIndex.Delete(settleIndexKey[:])
2✔
2372
                                if err != nil {
2✔
2373
                                        return err
×
2374
                                }
×
2375
                        }
2376

2377
                        // In addition to deleting the main invoice state, if
2378
                        // this is an AMP invoice, then we'll also need to
2379
                        // delete the set HTLC set stored as a key prefix. For
2380
                        // non-AMP invoices, this'll be a noop.
2381
                        err = delAMPSettleIndex(
8✔
2382
                                invoiceKey, invoices, settleIndex,
8✔
2383
                        )
8✔
2384
                        if err != nil {
8✔
2385
                                return err
×
2386
                        }
×
2387
                        err = delAMPInvoices(invoiceKey, invoices)
8✔
2388
                        if err != nil {
8✔
2389
                                return err
×
2390
                        }
×
2391

2392
                        // Finally remove the serialized invoice from the
2393
                        // invoice bucket.
2394
                        err = invoices.Delete(invoiceKey)
8✔
2395
                        if err != nil {
8✔
2396
                                return err
×
2397
                        }
×
2398
                }
2399

2400
                return nil
3✔
2401
        }, func() {})
6✔
2402

2403
        return err
6✔
2404
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc