• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 9915780197

13 Jul 2024 12:30AM UTC coverage: 49.268% (-9.1%) from 58.413%
9915780197

push

github

web-flow
Merge pull request #8653 from ProofOfKeags/fn-prim

DynComms [0/n]: `fn` package additions

92837 of 188433 relevant lines covered (49.27%)

1.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.34
/channeldb/invoices.go
1
package channeldb
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/binary"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "time"
11

12
        "github.com/lightningnetwork/lnd/channeldb/models"
13
        "github.com/lightningnetwork/lnd/htlcswitch/hop"
14
        invpkg "github.com/lightningnetwork/lnd/invoices"
15
        "github.com/lightningnetwork/lnd/kvdb"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/tlv"
20
)
21

22
var (
	// invoiceBucket is the top-level bucket that holds the data of every
	// invoice, regardless of its final state. Each invoice inside it is
	// keyed by its invoice ID, a monotonically increasing uint32.
	invoiceBucket = []byte("invoices")

	// invoiceIndexBucket is a sub-bucket of invoiceBucket that indexes
	// invoices by payment hash, i.e. the sha256 of the invoice's payment
	// preimage. The index is used to detect duplicate invoices and offers a
	// fast path when checking whether an incoming HTLC can be fully
	// settled.
	//
	// maps: payHash => invoiceKey
	invoiceIndexBucket = []byte("paymenthashes")

	// payAddrIndexBucket is the top-level bucket mapping payment addresses
	// to their invoice number, allowing non-legacy invoices to be queried or
	// updated efficiently. Legacy invoices all share the all-zero payment
	// address and are therefore left out of this index; every newly
	// generated invoice is added to it.
	//
	// maps: payAddr => invoiceKey
	payAddrIndexBucket = []byte("pay-addr-index")

	// setIDIndexBucket is the top-level bucket mapping set IDs to their
	// invoice number, allowing AMP invoices to be queried or updated
	// efficiently. Legacy and MPP invoices are not part of this index since
	// their HTLCs carry no set ID.
	//
	// maps: setID => invoiceKey
	setIDIndexBucket = []byte("set-id-index")

	// numInvoicesKey is the key, stored inside invoiceIndexBucket, under
	// which the auto-incrementing invoice ID lives. That ID acts as the
	// primary key: it is incremented by one for every inserted invoice, and
	// invoices within invoiceBucket are uniquely identified by it.
	numInvoicesKey = []byte("nik")

	// addIndexBucket is an index bucket backing a monotonically increasing
	// add index sequence. Each time a new invoice is added, the sequence is
	// incremented and the new value is recorded within the invoice.
	//
	// Alongside the sequence number, the bucket maps:
	//
	//   addIndexNo => invoiceKey
	addIndexBucket = []byte("invoice-add-index")

	// settleIndexBucket is an index bucket backing a monotonically
	// increasing "settle index". Each time an invoice is settled, the
	// sequence is incremented and the new value is recorded within the
	// newly settled invoice.
	//
	// Alongside the sequence number, the bucket maps:
	//
	//   settleIndexNo => invoiceKey
	settleIndexBucket = []byte("invoice-settle-index")
)
84

85
const (
86
        // A set of tlv type definitions used to serialize invoice htlcs to the
87
        // database.
88
        //
89
        // NOTE: A migration should be added whenever this list changes. This
90
        // prevents against the database being rolled back to an older
91
        // format where the surrounding logic might assume a different set of
92
        // fields are known.
93
        chanIDType       tlv.Type = 1
94
        htlcIDType       tlv.Type = 3
95
        amtType          tlv.Type = 5
96
        acceptHeightType tlv.Type = 7
97
        acceptTimeType   tlv.Type = 9
98
        resolveTimeType  tlv.Type = 11
99
        expiryHeightType tlv.Type = 13
100
        htlcStateType    tlv.Type = 15
101
        mppTotalAmtType  tlv.Type = 17
102
        htlcAMPType      tlv.Type = 19
103
        htlcHashType     tlv.Type = 21
104
        htlcPreimageType tlv.Type = 23
105

106
        // A set of tlv type definitions used to serialize invoice bodiees.
107
        //
108
        // NOTE: A migration should be added whenever this list changes. This
109
        // prevents against the database being rolled back to an older
110
        // format where the surrounding logic might assume a different set of
111
        // fields are known.
112
        memoType            tlv.Type = 0
113
        payReqType          tlv.Type = 1
114
        createTimeType      tlv.Type = 2
115
        settleTimeType      tlv.Type = 3
116
        addIndexType        tlv.Type = 4
117
        settleIndexType     tlv.Type = 5
118
        preimageType        tlv.Type = 6
119
        valueType           tlv.Type = 7
120
        cltvDeltaType       tlv.Type = 8
121
        expiryType          tlv.Type = 9
122
        paymentAddrType     tlv.Type = 10
123
        featuresType        tlv.Type = 11
124
        invStateType        tlv.Type = 12
125
        amtPaidType         tlv.Type = 13
126
        hodlInvoiceType     tlv.Type = 14
127
        invoiceAmpStateType tlv.Type = 15
128

129
        // A set of tlv type definitions used to serialize the invoice AMP
130
        // state along-side the main invoice body.
131
        ampStateSetIDType       tlv.Type = 0
132
        ampStateHtlcStateType   tlv.Type = 1
133
        ampStateSettleIndexType tlv.Type = 2
134
        ampStateSettleDateType  tlv.Type = 3
135
        ampStateCircuitKeysType tlv.Type = 4
136
        ampStateAmtPaidType     tlv.Type = 5
137
)
138

139
// AddInvoice inserts the targeted invoice into the database. If the invoice has
140
// *any* payment hashes which already exists within the database, then the
141
// insertion will be aborted and rejected due to the strict policy banning any
142
// duplicate payment hashes. A side effect of this function is that it sets
143
// AddIndex on newInvoice.
144
func (d *DB) AddInvoice(_ context.Context, newInvoice *invpkg.Invoice,
145
        paymentHash lntypes.Hash) (uint64, error) {
3✔
146

3✔
147
        if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
3✔
148
                return 0, err
×
149
        }
×
150

151
        var invoiceAddIndex uint64
3✔
152
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
6✔
153
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
3✔
154
                if err != nil {
3✔
155
                        return err
×
156
                }
×
157

158
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
3✔
159
                        invoiceIndexBucket,
3✔
160
                )
3✔
161
                if err != nil {
3✔
162
                        return err
×
163
                }
×
164
                addIndex, err := invoices.CreateBucketIfNotExists(
3✔
165
                        addIndexBucket,
3✔
166
                )
3✔
167
                if err != nil {
3✔
168
                        return err
×
169
                }
×
170

171
                // Ensure that an invoice an identical payment hash doesn't
172
                // already exist within the index.
173
                if invoiceIndex.Get(paymentHash[:]) != nil {
3✔
174
                        return invpkg.ErrDuplicateInvoice
×
175
                }
×
176

177
                // Check that we aren't inserting an invoice with a duplicate
178
                // payment address. The all-zeros payment address is
179
                // special-cased to support legacy keysend invoices which don't
180
                // assign one. This is safe since later we also will avoid
181
                // indexing them and avoid collisions.
182
                payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
3✔
183
                if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
6✔
184
                        paymentAddr := newInvoice.Terms.PaymentAddr[:]
3✔
185
                        if payAddrIndex.Get(paymentAddr) != nil {
6✔
186
                                return invpkg.ErrDuplicatePayAddr
3✔
187
                        }
3✔
188
                }
189

190
                // If the current running payment ID counter hasn't yet been
191
                // created, then create it now.
192
                var invoiceNum uint32
3✔
193
                invoiceCounter := invoiceIndex.Get(numInvoicesKey)
3✔
194
                if invoiceCounter == nil {
6✔
195
                        var scratch [4]byte
3✔
196
                        byteOrder.PutUint32(scratch[:], invoiceNum)
3✔
197
                        err := invoiceIndex.Put(numInvoicesKey, scratch[:])
3✔
198
                        if err != nil {
3✔
199
                                return err
×
200
                        }
×
201
                } else {
3✔
202
                        invoiceNum = byteOrder.Uint32(invoiceCounter)
3✔
203
                }
3✔
204

205
                newIndex, err := putInvoice(
3✔
206
                        invoices, invoiceIndex, payAddrIndex, addIndex,
3✔
207
                        newInvoice, invoiceNum, paymentHash,
3✔
208
                )
3✔
209
                if err != nil {
3✔
210
                        return err
×
211
                }
×
212

213
                invoiceAddIndex = newIndex
3✔
214
                return nil
3✔
215
        }, func() {
3✔
216
                invoiceAddIndex = 0
3✔
217
        })
3✔
218
        if err != nil {
6✔
219
                return 0, err
3✔
220
        }
3✔
221

222
        return invoiceAddIndex, err
3✔
223
}
224

225
// InvoicesAddedSince can be used by callers to seek into the event time series
226
// of all the invoices added in the database. The specified sinceAddIndex
227
// should be the highest add index that the caller knows of. This method will
228
// return all invoices with an add index greater than the specified
229
// sinceAddIndex.
230
//
231
// NOTE: The index starts from 1, as a result. We enforce that specifying a
232
// value below the starting index value is a noop.
233
func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
234
        []invpkg.Invoice, error) {
3✔
235

3✔
236
        var newInvoices []invpkg.Invoice
3✔
237

3✔
238
        // If an index of zero was specified, then in order to maintain
3✔
239
        // backwards compat, we won't send out any new invoices.
3✔
240
        if sinceAddIndex == 0 {
6✔
241
                return newInvoices, nil
3✔
242
        }
3✔
243

244
        var startIndex [8]byte
3✔
245
        byteOrder.PutUint64(startIndex[:], sinceAddIndex)
3✔
246

3✔
247
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
248
                invoices := tx.ReadBucket(invoiceBucket)
3✔
249
                if invoices == nil {
3✔
250
                        return nil
×
251
                }
×
252

253
                addIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
254
                if addIndex == nil {
3✔
255
                        return nil
×
256
                }
×
257

258
                // We'll now run through each entry in the add index starting
259
                // at our starting index. We'll continue until we reach the
260
                // very end of the current key space.
261
                invoiceCursor := addIndex.ReadCursor()
3✔
262

3✔
263
                // We'll seek to the starting index, then manually advance the
3✔
264
                // cursor in order to skip the entry with the since add index.
3✔
265
                invoiceCursor.Seek(startIndex[:])
3✔
266
                addSeqNo, invoiceKey := invoiceCursor.Next()
3✔
267

3✔
268
                for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
6✔
269

3✔
270
                        // For each key found, we'll look up the actual
3✔
271
                        // invoice, then accumulate it into our return value.
3✔
272
                        invoice, err := fetchInvoice(invoiceKey, invoices)
3✔
273
                        if err != nil {
3✔
274
                                return err
×
275
                        }
×
276

277
                        newInvoices = append(newInvoices, invoice)
3✔
278
                }
279

280
                return nil
3✔
281
        }, func() {
3✔
282
                newInvoices = nil
3✔
283
        })
3✔
284
        if err != nil {
3✔
285
                return nil, err
×
286
        }
×
287

288
        return newInvoices, nil
3✔
289
}
290

291
// LookupInvoice attempts to look up an invoice according to its 32 byte
292
// payment hash. If an invoice which can settle the HTLC identified by the
293
// passed payment hash isn't found, then an error is returned. Otherwise, the
294
// full invoice is returned. Before setting the incoming HTLC, the values
295
// SHOULD be checked to ensure the payer meets the agreed upon contractual
296
// terms of the payment.
297
func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
298
        invpkg.Invoice, error) {
3✔
299

3✔
300
        var invoice invpkg.Invoice
3✔
301
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
302
                invoices := tx.ReadBucket(invoiceBucket)
3✔
303
                if invoices == nil {
3✔
304
                        return invpkg.ErrNoInvoicesCreated
×
305
                }
×
306
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
3✔
307
                if invoiceIndex == nil {
6✔
308
                        return invpkg.ErrNoInvoicesCreated
3✔
309
                }
3✔
310
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
3✔
311
                setIDIndex := tx.ReadBucket(setIDIndexBucket)
3✔
312

3✔
313
                // Retrieve the invoice number for this invoice using
3✔
314
                // the provided invoice reference.
3✔
315
                invoiceNum, err := fetchInvoiceNumByRef(
3✔
316
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
3✔
317
                )
3✔
318
                if err != nil {
6✔
319
                        return err
3✔
320
                }
3✔
321

322
                var setID *invpkg.SetID
3✔
323
                switch {
3✔
324
                // If this is a payment address ref, and the blank modified was
325
                // specified, then we'll use the zero set ID to indicate that
326
                // we won't want any HTLCs returned.
327
                case ref.PayAddr() != nil &&
328
                        ref.Modifier() == invpkg.HtlcSetBlankModifier:
3✔
329

3✔
330
                        var zeroSetID invpkg.SetID
3✔
331
                        setID = &zeroSetID
3✔
332

333
                // If this is a set ID ref, and the htlc set only modified was
334
                // specified, then we'll pass through the specified setID so
335
                // only that will be returned.
336
                case ref.SetID() != nil &&
337
                        ref.Modifier() == invpkg.HtlcSetOnlyModifier:
3✔
338

3✔
339
                        setID = (*invpkg.SetID)(ref.SetID())
3✔
340
                }
341

342
                // An invoice was found, retrieve the remainder of the invoice
343
                // body.
344
                i, err := fetchInvoice(invoiceNum, invoices, setID)
3✔
345
                if err != nil {
3✔
346
                        return err
×
347
                }
×
348
                invoice = i
3✔
349

3✔
350
                return nil
3✔
351
        }, func() {})
3✔
352
        if err != nil {
6✔
353
                return invoice, err
3✔
354
        }
3✔
355

356
        return invoice, nil
3✔
357
}
358

359
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
360
// reference. The payment address will be treated as the primary key, falling
361
// back to the payment hash if nothing is found for the payment address. An
362
// error is returned if the invoice is not found.
363
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
364
        ref invpkg.InvoiceRef) ([]byte, error) {
3✔
365

3✔
366
        // If the set id is present, we only consult the set id index for this
3✔
367
        // invoice. This type of query is only used to facilitate user-facing
3✔
368
        // requests to lookup, settle or cancel an AMP invoice.
3✔
369
        setID := ref.SetID()
3✔
370
        if setID != nil {
6✔
371
                invoiceNumBySetID := setIDIndex.Get(setID[:])
3✔
372
                if invoiceNumBySetID == nil {
3✔
373
                        return nil, invpkg.ErrInvoiceNotFound
×
374
                }
×
375

376
                return invoiceNumBySetID, nil
3✔
377
        }
378

379
        payHash := ref.PayHash()
3✔
380
        payAddr := ref.PayAddr()
3✔
381

3✔
382
        getInvoiceNumByHash := func() []byte {
6✔
383
                if payHash != nil {
6✔
384
                        return invoiceIndex.Get(payHash[:])
3✔
385
                }
3✔
386
                return nil
3✔
387
        }
388

389
        getInvoiceNumByAddr := func() []byte {
6✔
390
                if payAddr != nil {
6✔
391
                        // Only allow lookups for payment address if it is not a
3✔
392
                        // blank payment address, which is a special-cased value
3✔
393
                        // for legacy keysend invoices.
3✔
394
                        if *payAddr != invpkg.BlankPayAddr {
6✔
395
                                return payAddrIndex.Get(payAddr[:])
3✔
396
                        }
3✔
397
                }
398
                return nil
3✔
399
        }
400

401
        invoiceNumByHash := getInvoiceNumByHash()
3✔
402
        invoiceNumByAddr := getInvoiceNumByAddr()
3✔
403
        switch {
3✔
404
        // If payment address and payment hash both reference an existing
405
        // invoice, ensure they reference the _same_ invoice.
406
        case invoiceNumByAddr != nil && invoiceNumByHash != nil:
3✔
407
                if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
3✔
408
                        return nil, invpkg.ErrInvRefEquivocation
×
409
                }
×
410

411
                return invoiceNumByAddr, nil
3✔
412

413
        // Return invoices by payment addr only.
414
        //
415
        // NOTE: We constrain this lookup to only apply if the invoice ref does
416
        // not contain a payment hash. Legacy and MPP payments depend on the
417
        // payment hash index to enforce that the HTLCs payment hash matches the
418
        // payment hash for the invoice, without this check we would
419
        // inadvertently assume the invoice contains the correct preimage for
420
        // the HTLC, which we only enforce via the lookup by the invoice index.
421
        case invoiceNumByAddr != nil && payHash == nil:
3✔
422
                return invoiceNumByAddr, nil
3✔
423

424
        // If we were only able to reference the invoice by hash, return the
425
        // corresponding invoice number. This can happen when no payment address
426
        // was provided, or if it didn't match anything in our records.
427
        case invoiceNumByHash != nil:
3✔
428
                return invoiceNumByHash, nil
3✔
429

430
        // Otherwise we don't know of the target invoice.
431
        default:
3✔
432
                return nil, invpkg.ErrInvoiceNotFound
3✔
433
        }
434
}
435

436
// FetchPendingInvoices returns all invoices that have not yet been settled or
437
// canceled. The returned map is keyed by the payment hash of each respective
438
// invoice.
439
func (d *DB) FetchPendingInvoices(_ context.Context) (
440
        map[lntypes.Hash]invpkg.Invoice, error) {
3✔
441

3✔
442
        result := make(map[lntypes.Hash]invpkg.Invoice)
3✔
443

3✔
444
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
445
                invoices := tx.ReadBucket(invoiceBucket)
3✔
446
                if invoices == nil {
3✔
447
                        return nil
×
448
                }
×
449

450
                invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
3✔
451
                if invoiceIndex == nil {
6✔
452
                        // Mask the error if there's no invoice
3✔
453
                        // index as that simply means there are no
3✔
454
                        // invoices added yet to the DB. In this case
3✔
455
                        // we simply return an empty list.
3✔
456
                        return nil
3✔
457
                }
3✔
458

459
                return invoiceIndex.ForEach(func(k, v []byte) error {
6✔
460
                        // Skip the special numInvoicesKey as that does not
3✔
461
                        // point to a valid invoice.
3✔
462
                        if bytes.Equal(k, numInvoicesKey) {
6✔
463
                                return nil
3✔
464
                        }
3✔
465

466
                        // Skip sub-buckets.
467
                        if v == nil {
3✔
468
                                return nil
×
469
                        }
×
470

471
                        invoice, err := fetchInvoice(v, invoices)
3✔
472
                        if err != nil {
3✔
473
                                return err
×
474
                        }
×
475

476
                        if invoice.IsPending() {
6✔
477
                                var paymentHash lntypes.Hash
3✔
478
                                copy(paymentHash[:], k)
3✔
479
                                result[paymentHash] = invoice
3✔
480
                        }
3✔
481

482
                        return nil
3✔
483
                })
484
        }, func() {
3✔
485
                result = make(map[lntypes.Hash]invpkg.Invoice)
3✔
486
        })
3✔
487

488
        if err != nil {
3✔
489
                return nil, err
×
490
        }
×
491

492
        return result, nil
3✔
493
}
494

495
// QueryInvoices allows a caller to query the invoice database for invoices
496
// within the specified add index range.
497
func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
498
        invpkg.InvoiceSlice, error) {
3✔
499

3✔
500
        var resp invpkg.InvoiceSlice
3✔
501

3✔
502
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
503
                // If the bucket wasn't found, then there aren't any invoices
3✔
504
                // within the database yet, so we can simply exit.
3✔
505
                invoices := tx.ReadBucket(invoiceBucket)
3✔
506
                if invoices == nil {
3✔
507
                        return invpkg.ErrNoInvoicesCreated
×
508
                }
×
509

510
                // Get the add index bucket which we will use to iterate through
511
                // our indexed invoices.
512
                invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
3✔
513
                if invoiceAddIndex == nil {
6✔
514
                        return invpkg.ErrNoInvoicesCreated
3✔
515
                }
3✔
516

517
                // Create a paginator which reads from our add index bucket with
518
                // the parameters provided by the invoice query.
519
                paginator := newPaginator(
3✔
520
                        invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
3✔
521
                        q.NumMaxInvoices,
3✔
522
                )
3✔
523

3✔
524
                // accumulateInvoices looks up an invoice based on the index we
3✔
525
                // are given, adds it to our set of invoices if it has the right
3✔
526
                // characteristics for our query and returns the number of items
3✔
527
                // we have added to our set of invoices.
3✔
528
                accumulateInvoices := func(_, indexValue []byte) (bool, error) {
6✔
529
                        invoice, err := fetchInvoice(indexValue, invoices)
3✔
530
                        if err != nil {
3✔
531
                                return false, err
×
532
                        }
×
533

534
                        // Skip any settled or canceled invoices if the caller
535
                        // is only interested in pending ones.
536
                        if q.PendingOnly && !invoice.IsPending() {
3✔
537
                                return false, nil
×
538
                        }
×
539

540
                        // Get the creation time in Unix seconds, this always
541
                        // rounds down the nanoseconds to full seconds.
542
                        createTime := invoice.CreationDate.Unix()
3✔
543

3✔
544
                        // Skip any invoices that were created before the
3✔
545
                        // specified time.
3✔
546
                        if createTime < q.CreationDateStart {
6✔
547
                                return false, nil
3✔
548
                        }
3✔
549

550
                        // Skip any invoices that were created after the
551
                        // specified time.
552
                        if q.CreationDateEnd != 0 &&
3✔
553
                                createTime > q.CreationDateEnd {
6✔
554

3✔
555
                                return false, nil
3✔
556
                        }
3✔
557

558
                        // At this point, we've exhausted the offset, so we'll
559
                        // begin collecting invoices found within the range.
560
                        resp.Invoices = append(resp.Invoices, invoice)
3✔
561

3✔
562
                        return true, nil
3✔
563
                }
564

565
                // Query our paginator using accumulateInvoices to build up a
566
                // set of invoices.
567
                if err := paginator.query(accumulateInvoices); err != nil {
3✔
568
                        return err
×
569
                }
×
570

571
                // If we iterated through the add index in reverse order, then
572
                // we'll need to reverse the slice of invoices to return them in
573
                // forward order.
574
                if q.Reversed {
3✔
575
                        numInvoices := len(resp.Invoices)
×
576
                        for i := 0; i < numInvoices/2; i++ {
×
577
                                reverse := numInvoices - i - 1
×
578
                                resp.Invoices[i], resp.Invoices[reverse] =
×
579
                                        resp.Invoices[reverse], resp.Invoices[i]
×
580
                        }
×
581
                }
582

583
                return nil
3✔
584
        }, func() {
3✔
585
                resp = invpkg.InvoiceSlice{
3✔
586
                        InvoiceQuery: q,
3✔
587
                }
3✔
588
        })
3✔
589
        if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) {
3✔
590
                return resp, err
×
591
        }
×
592

593
        // Finally, record the indexes of the first and last invoices returned
594
        // so that the caller can resume from this point later on.
595
        if len(resp.Invoices) > 0 {
6✔
596
                resp.FirstIndexOffset = resp.Invoices[0].AddIndex
3✔
597
                lastIdx := len(resp.Invoices) - 1
3✔
598
                resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex
3✔
599
        }
3✔
600

601
        return resp, nil
3✔
602
}
603

604
// UpdateInvoice attempts to update an invoice corresponding to the passed
605
// payment hash. If an invoice matching the passed payment hash doesn't exist
606
// within the database, then the action will fail with a "not found" error.
607
//
608
// The update is performed inside the same database transaction that fetches the
609
// invoice and is therefore atomic. The fields to update are controlled by the
610
// supplied callback.  When updating an invoice, the update itself happens
611
// in-memory on a copy of the invoice. Once it is written successfully to the
612
// database, the in-memory copy is returned to the caller.
613
func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
614
        setIDHint *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
615
        *invpkg.Invoice, error) {
3✔
616

3✔
617
        var updatedInvoice *invpkg.Invoice
3✔
618
        err := kvdb.Update(d, func(tx kvdb.RwTx) error {
6✔
619
                invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
3✔
620
                if err != nil {
3✔
621
                        return err
×
622
                }
×
623
                invoiceIndex, err := invoices.CreateBucketIfNotExists(
3✔
624
                        invoiceIndexBucket,
3✔
625
                )
3✔
626
                if err != nil {
3✔
627
                        return err
×
628
                }
×
629
                settleIndex, err := invoices.CreateBucketIfNotExists(
3✔
630
                        settleIndexBucket,
3✔
631
                )
3✔
632
                if err != nil {
3✔
633
                        return err
×
634
                }
×
635
                payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
3✔
636
                setIDIndex := tx.ReadWriteBucket(setIDIndexBucket)
3✔
637

3✔
638
                // Retrieve the invoice number for this invoice using the
3✔
639
                // provided invoice reference.
3✔
640
                invoiceNum, err := fetchInvoiceNumByRef(
3✔
641
                        invoiceIndex, payAddrIndex, setIDIndex, ref,
3✔
642
                )
3✔
643
                if err != nil {
6✔
644
                        return err
3✔
645
                }
3✔
646

647
                // If the set ID hint is non-nil, then we'll use that to filter
648
                // out the HTLCs for AMP invoice so we don't need to read them
649
                // all out to satisfy the invoice callback below. If it's nil,
650
                // then we pass in the zero set ID which means no HTLCs will be
651
                // read out.
652
                var invSetID invpkg.SetID
3✔
653

3✔
654
                if setIDHint != nil {
6✔
655
                        invSetID = *setIDHint
3✔
656
                }
3✔
657
                invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID)
3✔
658
                if err != nil {
3✔
659
                        return err
×
660
                }
×
661

662
                now := d.clock.Now()
3✔
663
                updater := &kvInvoiceUpdater{
3✔
664
                        db:                d,
3✔
665
                        invoicesBucket:    invoices,
3✔
666
                        settleIndexBucket: settleIndex,
3✔
667
                        setIDIndexBucket:  setIDIndex,
3✔
668
                        updateTime:        now,
3✔
669
                        invoiceNum:        invoiceNum,
3✔
670
                        invoice:           &invoice,
3✔
671
                        updatedAmpHtlcs:   make(ampHTLCsMap),
3✔
672
                        settledSetIDs:     make(map[invpkg.SetID]struct{}),
3✔
673
                }
3✔
674

3✔
675
                payHash := ref.PayHash()
3✔
676
                updatedInvoice, err = invpkg.UpdateInvoice(
3✔
677
                        payHash, updater.invoice, now, callback, updater,
3✔
678
                )
3✔
679

3✔
680
                return err
3✔
681
        }, func() {
3✔
682
                updatedInvoice = nil
3✔
683
        })
3✔
684

685
        return updatedInvoice, err
3✔
686
}
687

688
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
689
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
690

691
// kvInvoiceUpdater is an implementation of the InvoiceUpdater interface that
692
// is used with the kv implementation of the invoice database. Note that this
693
// updater is not concurrency safe and synchronizaton is expected to be handled
694
// on the DB level.
695
type kvInvoiceUpdater struct {
696
        db                *DB
697
        invoicesBucket    kvdb.RwBucket
698
        settleIndexBucket kvdb.RwBucket
699
        setIDIndexBucket  kvdb.RwBucket
700

701
        // updateTime is the timestamp for the update.
702
        updateTime time.Time
703

704
        // invoiceNum is a legacy key similar to the add index that is used
705
        // only in the kv implementation.
706
        invoiceNum []byte
707

708
        // invoice is the invoice that we're updating. As a side effect of the
709
        // update this invoice will be mutated.
710
        invoice *invpkg.Invoice
711

712
        // updatedAmpHtlcs holds the set of AMP HTLCs that were added or
713
        // cancelled as part of this update.
714
        updatedAmpHtlcs ampHTLCsMap
715

716
        // settledSetIDs holds the set IDs that are settled with this update.
717
        settledSetIDs map[invpkg.SetID]struct{}
718
}
719

720
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
721
func (k *kvInvoiceUpdater) AddHtlc(_ models.CircuitKey,
722
        _ *invpkg.InvoiceHTLC) error {
3✔
723

3✔
724
        return nil
3✔
725
}
3✔
726

727
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
728
func (k *kvInvoiceUpdater) ResolveHtlc(_ models.CircuitKey, _ invpkg.HtlcState,
729
        _ time.Time) error {
3✔
730

3✔
731
        return nil
3✔
732
}
3✔
733

734
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
735
func (k *kvInvoiceUpdater) AddAmpHtlcPreimage(_ [32]byte, _ models.CircuitKey,
736
        _ lntypes.Preimage) error {
3✔
737

3✔
738
        return nil
3✔
739
}
3✔
740

741
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
742
func (k *kvInvoiceUpdater) UpdateInvoiceState(_ invpkg.ContractState,
743
        _ *lntypes.Preimage) error {
3✔
744

3✔
745
        return nil
3✔
746
}
3✔
747

748
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
749
func (k *kvInvoiceUpdater) UpdateInvoiceAmtPaid(_ lnwire.MilliSatoshi) error {
3✔
750
        return nil
3✔
751
}
3✔
752

753
// UpdateAmpState updates the state of the AMP invoice identified by the setID.
754
func (k *kvInvoiceUpdater) UpdateAmpState(setID [32]byte,
755
        state invpkg.InvoiceStateAMP, circuitKey models.CircuitKey) error {
3✔
756

3✔
757
        if _, ok := k.updatedAmpHtlcs[setID]; !ok {
6✔
758
                switch state.State {
3✔
759
                case invpkg.HtlcStateAccepted:
3✔
760
                        // If we're just now creating the HTLCs for this set
3✔
761
                        // then we'll also pull in the existing HTLCs that are
3✔
762
                        // part of this set, so we can write them all to disk
3✔
763
                        // together (same value)
3✔
764
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
3✔
765
                                &setID, invpkg.HtlcStateAccepted,
3✔
766
                        )
3✔
767

768
                case invpkg.HtlcStateCanceled:
×
769
                        // Only HTLCs in the accepted state, can be cancelled,
×
770
                        // but we also want to merge that with HTLCs that may be
×
771
                        // canceled as well since it can be cancelled one by
×
772
                        // one.
×
773
                        k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
×
774
                                &setID, invpkg.HtlcStateAccepted,
×
775
                        )
×
776

×
777
                        cancelledHtlcs := k.invoice.HTLCSet(
×
778
                                &setID, invpkg.HtlcStateCanceled,
×
779
                        )
×
780
                        for htlcKey, htlc := range cancelledHtlcs {
×
781
                                k.updatedAmpHtlcs[setID][htlcKey] = htlc
×
782
                        }
×
783

784
                case invpkg.HtlcStateSettled:
×
785
                        k.updatedAmpHtlcs[setID] = make(
×
786
                                map[models.CircuitKey]*invpkg.InvoiceHTLC,
×
787
                        )
×
788
                }
789
        }
790

791
        if state.State == invpkg.HtlcStateSettled {
6✔
792
                // Add the set ID to the set that was settled in this invoice
3✔
793
                // update. We'll use this later to update the settle index.
3✔
794
                k.settledSetIDs[setID] = struct{}{}
3✔
795
        }
3✔
796

797
        k.updatedAmpHtlcs[setID][circuitKey] = k.invoice.Htlcs[circuitKey]
3✔
798

3✔
799
        return nil
3✔
800
}
801

802
// Finalize finalizes the update before it is written to the database.
803
func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error {
3✔
804
        switch updateType {
3✔
805
        case invpkg.AddHTLCsUpdate:
3✔
806
                return k.storeAddHtlcsUpdate()
3✔
807

808
        case invpkg.CancelHTLCsUpdate:
×
809
                return k.storeCancelHtlcsUpdate()
×
810

811
        case invpkg.SettleHodlInvoiceUpdate:
3✔
812
                return k.storeSettleHodlInvoiceUpdate()
3✔
813

814
        case invpkg.CancelInvoiceUpdate:
3✔
815
                return k.serializeAndStoreInvoice()
3✔
816
        }
817

818
        return fmt.Errorf("unknown update type: %v", updateType)
×
819
}
820

821
// storeCancelHtlcsUpdate updates the invoice in the database after cancelling a
822
// set of HTLCs.
823
func (k *kvInvoiceUpdater) storeCancelHtlcsUpdate() error {
×
824
        err := k.serializeAndStoreInvoice()
×
825
        if err != nil {
×
826
                return err
×
827
        }
×
828

829
        // If this is an AMP invoice, then we'll actually store the rest
830
        // of the HTLCs in-line with the invoice, using the invoice ID
831
        // as a prefix, and the AMP key as a suffix: invoiceNum ||
832
        // setID.
833
        if k.invoice.IsAMP() {
×
834
                return k.updateAMPInvoices()
×
835
        }
×
836

837
        return nil
×
838
}
839

840
// storeAddHtlcsUpdate updates the invoice in the database after adding a set of
841
// HTLCs.
842
func (k *kvInvoiceUpdater) storeAddHtlcsUpdate() error {
3✔
843
        invoiceIsAMP := k.invoice.IsAMP()
3✔
844

3✔
845
        for htlcSetID := range k.updatedAmpHtlcs {
6✔
846
                // Check if this SetID already exist.
3✔
847
                setIDInvNum := k.setIDIndexBucket.Get(htlcSetID[:])
3✔
848

3✔
849
                if setIDInvNum == nil {
6✔
850
                        err := k.setIDIndexBucket.Put(
3✔
851
                                htlcSetID[:], k.invoiceNum,
3✔
852
                        )
3✔
853
                        if err != nil {
3✔
854
                                return err
×
855
                        }
×
856
                } else if !bytes.Equal(setIDInvNum, k.invoiceNum) {
3✔
857
                        return invpkg.ErrDuplicateSetID{
×
858
                                SetID: htlcSetID,
×
859
                        }
×
860
                }
×
861
        }
862

863
        // If this is a non-AMP invoice, then the state can eventually go to
864
        // ContractSettled, so we pass in nil value as part of
865
        // setSettleMetaFields.
866
        if !invoiceIsAMP && k.invoice.State == invpkg.ContractSettled {
6✔
867
                err := k.setSettleMetaFields(nil)
3✔
868
                if err != nil {
3✔
869
                        return err
×
870
                }
×
871
        }
872

873
        // As we don't update the settle index above for AMP invoices, we'll do
874
        // it here for each sub-AMP invoice that was settled.
875
        for settledSetID := range k.settledSetIDs {
6✔
876
                settledSetID := settledSetID
3✔
877
                err := k.setSettleMetaFields(&settledSetID)
3✔
878
                if err != nil {
3✔
879
                        return err
×
880
                }
×
881
        }
882

883
        err := k.serializeAndStoreInvoice()
3✔
884
        if err != nil {
3✔
885
                return err
×
886
        }
×
887

888
        // If this is an AMP invoice, then we'll actually store the rest of the
889
        // HTLCs in-line with the invoice, using the invoice ID as a prefix,
890
        // and the AMP key as a suffix: invoiceNum || setID.
891
        if invoiceIsAMP {
6✔
892
                return k.updateAMPInvoices()
3✔
893
        }
3✔
894

895
        return nil
3✔
896
}
897

898
// storeSettleHodlInvoiceUpdate updates the invoice in the database after
899
// settling a hodl invoice.
900
func (k *kvInvoiceUpdater) storeSettleHodlInvoiceUpdate() error {
3✔
901
        err := k.setSettleMetaFields(nil)
3✔
902
        if err != nil {
3✔
903
                return err
×
904
        }
×
905

906
        return k.serializeAndStoreInvoice()
3✔
907
}
908

909
// setSettleMetaFields updates the metadata associated with settlement of an
910
// invoice. If a non-nil setID is passed in, then the value will be append to
911
// the invoice number as well, in order to allow us to detect repeated payments
912
// to the same AMP invoices "across time".
913
func (k *kvInvoiceUpdater) setSettleMetaFields(setID *invpkg.SetID) error {
3✔
914
        // Now that we know the invoice hasn't already been settled, we'll
3✔
915
        // update the settle index so we can place this settle event in the
3✔
916
        // proper location within our time series.
3✔
917
        nextSettleSeqNo, err := k.settleIndexBucket.NextSequence()
3✔
918
        if err != nil {
3✔
919
                return err
×
920
        }
×
921

922
        // Make a new byte array on the stack that can potentially store the 4
923
        // byte invoice number along w/ the 32 byte set ID. We capture valueLen
924
        // here which is the number of bytes copied so we can only store the 4
925
        // bytes if this is a non-AMP invoice.
926
        var indexKey [invoiceSetIDKeyLen]byte
3✔
927
        valueLen := copy(indexKey[:], k.invoiceNum)
3✔
928

3✔
929
        if setID != nil {
6✔
930
                valueLen += copy(indexKey[valueLen:], setID[:])
3✔
931
        }
3✔
932

933
        var seqNoBytes [8]byte
3✔
934
        byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
3✔
935
        err = k.settleIndexBucket.Put(seqNoBytes[:], indexKey[:valueLen])
3✔
936
        if err != nil {
3✔
937
                return err
×
938
        }
×
939

940
        // If the setID is nil, then this means that this is a non-AMP settle,
941
        // so we'll update the invoice settle index directly.
942
        if setID == nil {
6✔
943
                k.invoice.SettleDate = k.updateTime
3✔
944
                k.invoice.SettleIndex = nextSettleSeqNo
3✔
945
        } else {
6✔
946
                // If the set ID isn't blank, we'll update the AMP state map
3✔
947
                // which tracks when each of the setIDs associated with a given
3✔
948
                // AMP invoice are settled.
3✔
949
                ampState := k.invoice.AMPState[*setID]
3✔
950

3✔
951
                ampState.SettleDate = k.updateTime
3✔
952
                ampState.SettleIndex = nextSettleSeqNo
3✔
953

3✔
954
                k.invoice.AMPState[*setID] = ampState
3✔
955
        }
3✔
956

957
        return nil
3✔
958
}
959

960
// updateAMPInvoices updates the set of AMP invoices in-place. For AMP, rather
961
// then continually write the invoices to the end of the invoice value, we
962
// instead write the invoices into a new key preifx that follows the main
963
// invoice number. This ensures that we don't need to continually decode a
964
// potentially massive HTLC set, and also allows us to quickly find the HLTCs
965
// associated with a particular HTLC set.
966
func (k *kvInvoiceUpdater) updateAMPInvoices() error {
3✔
967
        for setID, htlcSet := range k.updatedAmpHtlcs {
6✔
968
                // First write out the set of HTLCs including all the relevant
3✔
969
                // TLV values.
3✔
970
                var b bytes.Buffer
3✔
971
                if err := serializeHtlcs(&b, htlcSet); err != nil {
3✔
972
                        return err
×
973
                }
×
974

975
                // Next store each HTLC in-line, using a prefix based off the
976
                // invoice number.
977
                invoiceSetIDKey := makeInvoiceSetIDKey(k.invoiceNum, setID[:])
3✔
978

3✔
979
                err := k.invoicesBucket.Put(invoiceSetIDKey[:], b.Bytes())
3✔
980
                if err != nil {
3✔
981
                        return err
×
982
                }
×
983
        }
984

985
        return nil
3✔
986
}
987

988
// serializeAndStoreInvoice is a helper function used to store invoices.
989
func (k *kvInvoiceUpdater) serializeAndStoreInvoice() error {
3✔
990
        var buf bytes.Buffer
3✔
991
        if err := serializeInvoice(&buf, k.invoice); err != nil {
3✔
992
                return err
×
993
        }
×
994

995
        return k.invoicesBucket.Put(k.invoiceNum, buf.Bytes())
3✔
996
}
997

998
// InvoicesSettledSince can be used by callers to catch up any settled invoices
999
// they missed within the settled invoice time series. We'll return all known
1000
// settled invoice that have a settle index higher than the passed
1001
// sinceSettleIndex.
1002
//
1003
// NOTE: The index starts from 1, as a result. We enforce that specifying a
1004
// value below the starting index value is a noop.
1005
func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
1006
        []invpkg.Invoice, error) {
3✔
1007

3✔
1008
        var settledInvoices []invpkg.Invoice
3✔
1009

3✔
1010
        // If an index of zero was specified, then in order to maintain
3✔
1011
        // backwards compat, we won't send out any new invoices.
3✔
1012
        if sinceSettleIndex == 0 {
6✔
1013
                return settledInvoices, nil
3✔
1014
        }
3✔
1015

1016
        var startIndex [8]byte
3✔
1017
        byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
3✔
1018

3✔
1019
        err := kvdb.View(d, func(tx kvdb.RTx) error {
6✔
1020
                invoices := tx.ReadBucket(invoiceBucket)
3✔
1021
                if invoices == nil {
3✔
1022
                        return nil
×
1023
                }
×
1024

1025
                settleIndex := invoices.NestedReadBucket(settleIndexBucket)
3✔
1026
                if settleIndex == nil {
3✔
1027
                        return nil
×
1028
                }
×
1029

1030
                // We'll now run through each entry in the add index starting
1031
                // at our starting index. We'll continue until we reach the
1032
                // very end of the current key space.
1033
                invoiceCursor := settleIndex.ReadCursor()
3✔
1034

3✔
1035
                // We'll seek to the starting index, then manually advance the
3✔
1036
                // cursor in order to skip the entry with the since add index.
3✔
1037
                invoiceCursor.Seek(startIndex[:])
3✔
1038
                seqNo, indexValue := invoiceCursor.Next()
3✔
1039

3✔
1040
                for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, indexValue = invoiceCursor.Next() {
6✔
1041
                        // Depending on the length of the index value, this may
3✔
1042
                        // or may not be an AMP invoice, so we'll extract the
3✔
1043
                        // invoice value into two components: the invoice num,
3✔
1044
                        // and the setID (may not be there).
3✔
1045
                        var (
3✔
1046
                                invoiceKey [4]byte
3✔
1047
                                setID      *invpkg.SetID
3✔
1048
                        )
3✔
1049

3✔
1050
                        valueLen := copy(invoiceKey[:], indexValue)
3✔
1051
                        if len(indexValue) == invoiceSetIDKeyLen {
6✔
1052
                                setID = new(invpkg.SetID)
3✔
1053
                                copy(setID[:], indexValue[valueLen:])
3✔
1054
                        }
3✔
1055

1056
                        // For each key found, we'll look up the actual
1057
                        // invoice, then accumulate it into our return value.
1058
                        invoice, err := fetchInvoice(
3✔
1059
                                invoiceKey[:], invoices, setID,
3✔
1060
                        )
3✔
1061
                        if err != nil {
3✔
1062
                                return err
×
1063
                        }
×
1064

1065
                        settledInvoices = append(settledInvoices, invoice)
3✔
1066
                }
1067

1068
                return nil
3✔
1069
        }, func() {
3✔
1070
                settledInvoices = nil
3✔
1071
        })
3✔
1072
        if err != nil {
3✔
1073
                return nil, err
×
1074
        }
×
1075

1076
        return settledInvoices, nil
3✔
1077
}
1078

1079
func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
1080
        i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
1081
        uint64, error) {
3✔
1082

3✔
1083
        // Create the invoice key which is just the big-endian representation
3✔
1084
        // of the invoice number.
3✔
1085
        var invoiceKey [4]byte
3✔
1086
        byteOrder.PutUint32(invoiceKey[:], invoiceNum)
3✔
1087

3✔
1088
        // Increment the num invoice counter index so the next invoice bares
3✔
1089
        // the proper ID.
3✔
1090
        var scratch [4]byte
3✔
1091
        invoiceCounter := invoiceNum + 1
3✔
1092
        byteOrder.PutUint32(scratch[:], invoiceCounter)
3✔
1093
        if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
3✔
1094
                return 0, err
×
1095
        }
×
1096

1097
        // Add the payment hash to the invoice index. This will let us quickly
1098
        // identify if we can settle an incoming payment, and also to possibly
1099
        // allow a single invoice to have multiple payment installations.
1100
        err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
3✔
1101
        if err != nil {
3✔
1102
                return 0, err
×
1103
        }
×
1104

1105
        // Add the invoice to the payment address index, but only if the invoice
1106
        // has a non-zero payment address. The all-zero payment address is still
1107
        // in use by legacy keysend, so we special-case here to avoid
1108
        // collisions.
1109
        if i.Terms.PaymentAddr != invpkg.BlankPayAddr {
6✔
1110
                err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
3✔
1111
                if err != nil {
3✔
1112
                        return 0, err
×
1113
                }
×
1114
        }
1115

1116
        // Next, we'll obtain the next add invoice index (sequence
1117
        // number), so we can properly place this invoice within this
1118
        // event stream.
1119
        nextAddSeqNo, err := addIndex.NextSequence()
3✔
1120
        if err != nil {
3✔
1121
                return 0, err
×
1122
        }
×
1123

1124
        // With the next sequence obtained, we'll updating the event series in
1125
        // the add index bucket to map this current add counter to the index of
1126
        // this new invoice.
1127
        var seqNoBytes [8]byte
3✔
1128
        byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
3✔
1129
        if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
3✔
1130
                return 0, err
×
1131
        }
×
1132

1133
        i.AddIndex = nextAddSeqNo
3✔
1134

3✔
1135
        // Finally, serialize the invoice itself to be written to the disk.
3✔
1136
        var buf bytes.Buffer
3✔
1137
        if err := serializeInvoice(&buf, i); err != nil {
3✔
1138
                return 0, err
×
1139
        }
×
1140

1141
        if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
3✔
1142
                return 0, err
×
1143
        }
×
1144

1145
        return nextAddSeqNo, nil
3✔
1146
}
1147

1148
// recordSize returns the amount of bytes this TLV record will occupy when
1149
// encoded.
1150
func ampRecordSize(a *invpkg.AMPInvoiceState) func() uint64 {
3✔
1151
        var (
3✔
1152
                b   bytes.Buffer
3✔
1153
                buf [8]byte
3✔
1154
        )
3✔
1155

3✔
1156
        // We know that encoding works since the tests pass in the build this
3✔
1157
        // file is checked into, so we'll simplify things and simply encode it
3✔
1158
        // ourselves then report the total amount of bytes used.
3✔
1159
        if err := ampStateEncoder(&b, a, &buf); err != nil {
3✔
1160
                // This should never error out, but we log it just in case it
×
1161
                // does.
×
1162
                log.Errorf("encoding the amp invoice state failed: %v", err)
×
1163
        }
×
1164

1165
        return func() uint64 {
6✔
1166
                return uint64(len(b.Bytes()))
3✔
1167
        }
3✔
1168
}
1169

1170
// serializeInvoice serializes an invoice to a writer.
1171
//
1172
// Note: this function is in use for a migration. Before making changes that
1173
// would modify the on disk format, make a copy of the original code and store
1174
// it with the migration.
1175
func serializeInvoice(w io.Writer, i *invpkg.Invoice) error {
3✔
1176
        creationDateBytes, err := i.CreationDate.MarshalBinary()
3✔
1177
        if err != nil {
3✔
1178
                return err
×
1179
        }
×
1180

1181
        settleDateBytes, err := i.SettleDate.MarshalBinary()
3✔
1182
        if err != nil {
3✔
1183
                return err
×
1184
        }
×
1185

1186
        var fb bytes.Buffer
3✔
1187
        err = i.Terms.Features.EncodeBase256(&fb)
3✔
1188
        if err != nil {
3✔
1189
                return err
×
1190
        }
×
1191
        featureBytes := fb.Bytes()
3✔
1192

3✔
1193
        preimage := [32]byte(invpkg.UnknownPreimage)
3✔
1194
        if i.Terms.PaymentPreimage != nil {
6✔
1195
                preimage = *i.Terms.PaymentPreimage
3✔
1196
                if preimage == invpkg.UnknownPreimage {
3✔
1197
                        return errors.New("cannot use all-zeroes preimage")
×
1198
                }
×
1199
        }
1200
        value := uint64(i.Terms.Value)
3✔
1201
        cltvDelta := uint32(i.Terms.FinalCltvDelta)
3✔
1202
        expiry := uint64(i.Terms.Expiry)
3✔
1203

3✔
1204
        amtPaid := uint64(i.AmtPaid)
3✔
1205
        state := uint8(i.State)
3✔
1206

3✔
1207
        var hodlInvoice uint8
3✔
1208
        if i.HodlInvoice {
6✔
1209
                hodlInvoice = 1
3✔
1210
        }
3✔
1211

1212
        tlvStream, err := tlv.NewStream(
3✔
1213
                // Memo and payreq.
3✔
1214
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
3✔
1215
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
3✔
1216

3✔
1217
                // Add/settle metadata.
3✔
1218
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
3✔
1219
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
3✔
1220
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
3✔
1221
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
3✔
1222

3✔
1223
                // Terms.
3✔
1224
                tlv.MakePrimitiveRecord(preimageType, &preimage),
3✔
1225
                tlv.MakePrimitiveRecord(valueType, &value),
3✔
1226
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
3✔
1227
                tlv.MakePrimitiveRecord(expiryType, &expiry),
3✔
1228
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
3✔
1229
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
3✔
1230

3✔
1231
                // Invoice state.
3✔
1232
                tlv.MakePrimitiveRecord(invStateType, &state),
3✔
1233
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
3✔
1234

3✔
1235
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
3✔
1236

3✔
1237
                // Invoice AMP state.
3✔
1238
                tlv.MakeDynamicRecord(
3✔
1239
                        invoiceAmpStateType, &i.AMPState,
3✔
1240
                        ampRecordSize(&i.AMPState),
3✔
1241
                        ampStateEncoder, ampStateDecoder,
3✔
1242
                ),
3✔
1243
        )
3✔
1244
        if err != nil {
3✔
1245
                return err
×
1246
        }
×
1247

1248
        var b bytes.Buffer
3✔
1249
        if err = tlvStream.Encode(&b); err != nil {
3✔
1250
                return err
×
1251
        }
×
1252

1253
        err = binary.Write(w, byteOrder, uint64(b.Len()))
3✔
1254
        if err != nil {
3✔
1255
                return err
×
1256
        }
×
1257

1258
        if _, err = w.Write(b.Bytes()); err != nil {
3✔
1259
                return err
×
1260
        }
×
1261

1262
        // Only if this is a _non_ AMP invoice do we serialize the HTLCs
1263
        // in-line with the rest of the invoice.
1264
        if i.IsAMP() {
6✔
1265
                return nil
3✔
1266
        }
3✔
1267

1268
        return serializeHtlcs(w, i.Htlcs)
3✔
1269
}
1270

1271
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
1272
// a writer.
1273
func serializeHtlcs(w io.Writer,
1274
        htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error {
3✔
1275

3✔
1276
        for key, htlc := range htlcs {
6✔
1277
                // Encode the htlc in a tlv stream.
3✔
1278
                chanID := key.ChanID.ToUint64()
3✔
1279
                amt := uint64(htlc.Amt)
3✔
1280
                mppTotalAmt := uint64(htlc.MppTotalAmt)
3✔
1281
                acceptTime := putNanoTime(htlc.AcceptTime)
3✔
1282
                resolveTime := putNanoTime(htlc.ResolveTime)
3✔
1283
                state := uint8(htlc.State)
3✔
1284

3✔
1285
                var records []tlv.Record
3✔
1286
                records = append(records,
3✔
1287
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
3✔
1288
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
3✔
1289
                        tlv.MakePrimitiveRecord(amtType, &amt),
3✔
1290
                        tlv.MakePrimitiveRecord(
3✔
1291
                                acceptHeightType, &htlc.AcceptHeight,
3✔
1292
                        ),
3✔
1293
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
3✔
1294
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
3✔
1295
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
3✔
1296
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
3✔
1297
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
3✔
1298
                )
3✔
1299

3✔
1300
                if htlc.AMP != nil {
6✔
1301
                        setIDRecord := tlv.MakeDynamicRecord(
3✔
1302
                                htlcAMPType, &htlc.AMP.Record,
3✔
1303
                                htlc.AMP.Record.PayloadSize,
3✔
1304
                                record.AMPEncoder, record.AMPDecoder,
3✔
1305
                        )
3✔
1306
                        records = append(records, setIDRecord)
3✔
1307

3✔
1308
                        hash32 := [32]byte(htlc.AMP.Hash)
3✔
1309
                        hashRecord := tlv.MakePrimitiveRecord(
3✔
1310
                                htlcHashType, &hash32,
3✔
1311
                        )
3✔
1312
                        records = append(records, hashRecord)
3✔
1313

3✔
1314
                        if htlc.AMP.Preimage != nil {
6✔
1315
                                preimage32 := [32]byte(*htlc.AMP.Preimage)
3✔
1316
                                preimageRecord := tlv.MakePrimitiveRecord(
3✔
1317
                                        htlcPreimageType, &preimage32,
3✔
1318
                                )
3✔
1319
                                records = append(records, preimageRecord)
3✔
1320
                        }
3✔
1321
                }
1322

1323
                // Convert the custom records to tlv.Record types that are ready
1324
                // for serialization.
1325
                customRecords := tlv.MapToRecords(htlc.CustomRecords)
3✔
1326

3✔
1327
                // Append the custom records. Their ids are in the experimental
3✔
1328
                // range and sorted, so there is no need to sort again.
3✔
1329
                records = append(records, customRecords...)
3✔
1330

3✔
1331
                tlvStream, err := tlv.NewStream(records...)
3✔
1332
                if err != nil {
3✔
1333
                        return err
×
1334
                }
×
1335

1336
                var b bytes.Buffer
3✔
1337
                if err := tlvStream.Encode(&b); err != nil {
3✔
1338
                        return err
×
1339
                }
×
1340

1341
                // Write the length of the tlv stream followed by the stream
1342
                // bytes.
1343
                err = binary.Write(w, byteOrder, uint64(b.Len()))
3✔
1344
                if err != nil {
3✔
1345
                        return err
×
1346
                }
×
1347

1348
                if _, err := w.Write(b.Bytes()); err != nil {
3✔
1349
                        return err
×
1350
                }
×
1351
        }
1352

1353
        return nil
3✔
1354
}
1355

1356
// putNanoTime converts t to Unix nanoseconds. The zero time maps to 0, because
// UnixNano is undefined for it.
func putNanoTime(t time.Time) uint64 {
	var ns uint64
	if !t.IsZero() {
		ns = uint64(t.UnixNano())
	}

	return ns
}
1365

1366
// getNanoTime converts Unix nanoseconds back to a time.Time. An input of 0
// yields the zero time.Time value.
func getNanoTime(ns uint64) time.Time {
	var t time.Time
	if ns != 0 {
		t = time.Unix(0, int64(ns))
	}

	return t
}
1374

1375
// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices
1376
// identified by the setID value.
1377
func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1378
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1379
        error) {
3✔
1380

3✔
1381
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
1382
        for _, setID := range setIDs {
6✔
1383
                invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:])
3✔
1384

3✔
1385
                htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:])
3✔
1386
                if htlcSetBytes == nil {
6✔
1387
                        // A set ID was passed in, but we don't have this
3✔
1388
                        // stored yet, meaning that the setID is being added
3✔
1389
                        // for the first time.
3✔
1390
                        return htlcs, invpkg.ErrInvoiceNotFound
3✔
1391
                }
3✔
1392

1393
                htlcSetReader := bytes.NewReader(htlcSetBytes)
3✔
1394
                htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
3✔
1395
                if err != nil {
3✔
1396
                        return nil, err
×
1397
                }
×
1398

1399
                for key, htlc := range htlcsBySetID {
6✔
1400
                        htlcs[key] = htlc
3✔
1401
                }
3✔
1402
        }
1403

1404
        return htlcs, nil
3✔
1405
}
1406

1407
// forEachAMPInvoice is a helper function that attempts to iterate over each of
1408
// the HTLC sets (based on their set ID) for the given AMP invoice identified
1409
// by its invoiceNum. The callback closure is called for each key within the
1410
// prefix range.
1411
func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1412
        callback func(key, htlcSet []byte) error) error {
3✔
1413

3✔
1414
        invoiceCursor := invoiceBucket.ReadCursor()
3✔
1415

3✔
1416
        // Seek to the first key that includes the invoice data itself.
3✔
1417
        invoiceCursor.Seek(invoiceNum)
3✔
1418

3✔
1419
        // Advance to the very first key _after_ the invoice data, as this is
3✔
1420
        // where we'll encounter our first HTLC (if any are present).
3✔
1421
        cursorKey, htlcSet := invoiceCursor.Next()
3✔
1422

3✔
1423
        // If at this point, the cursor key doesn't match the invoice num
3✔
1424
        // prefix, then we know that this HTLC doesn't have any set ID HTLCs
3✔
1425
        // associated with it.
3✔
1426
        if !bytes.HasPrefix(cursorKey, invoiceNum) {
6✔
1427
                return nil
3✔
1428
        }
3✔
1429

1430
        // Otherwise continue to iterate until we no longer match the prefix,
1431
        // executing the call back at each step.
1432
        for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() {
6✔
1433
                err := callback(cursorKey, htlcSet)
3✔
1434
                if err != nil {
3✔
1435
                        return err
×
1436
                }
×
1437
        }
1438

1439
        return nil
3✔
1440
}
1441

1442
// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix  within the
1443
// AMP bucket to find all the individual HTLCs (by setID) associated with a
1444
// given invoice. If a list of set IDs are specified, then only HTLCs
1445
// associated with that setID will be retrieved.
1446
func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
1447
        setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1448
        error) {
3✔
1449

3✔
1450
        // If a set of setIDs was specified, then we can skip the cursor and
3✔
1451
        // just read out exactly what we need.
3✔
1452
        if len(setIDs) != 0 && setIDs[0] != nil {
6✔
1453
                return fetchFilteredAmpInvoices(
3✔
1454
                        invoiceBucket, invoiceNum, setIDs...,
3✔
1455
                )
3✔
1456
        }
3✔
1457

1458
        // Otherwise, iterate over all the htlc sets that are prefixed beside
1459
        // this invoice in the main invoice bucket.
1460
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
1461
        err := forEachAMPInvoice(invoiceBucket, invoiceNum,
3✔
1462
                func(key, htlcSet []byte) error {
6✔
1463
                        htlcSetReader := bytes.NewReader(htlcSet)
3✔
1464
                        htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
3✔
1465
                        if err != nil {
3✔
1466
                                return err
×
1467
                        }
×
1468

1469
                        for key, htlc := range htlcsBySetID {
6✔
1470
                                htlcs[key] = htlc
3✔
1471
                        }
3✔
1472

1473
                        return nil
3✔
1474
                },
1475
        )
1476

1477
        if err != nil {
3✔
1478
                return nil, err
×
1479
        }
×
1480

1481
        return htlcs, nil
3✔
1482
}
1483

1484
// fetchInvoice attempts to read out the relevant state for the invoice as
1485
// specified by the invoice number. If the setID fields are set, then only the
1486
// HTLC information pertaining to those set IDs is returned.
1487
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1488
        setIDs ...*invpkg.SetID) (invpkg.Invoice, error) {
3✔
1489

3✔
1490
        invoiceBytes := invoices.Get(invoiceNum)
3✔
1491
        if invoiceBytes == nil {
3✔
1492
                return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound
×
1493
        }
×
1494

1495
        invoiceReader := bytes.NewReader(invoiceBytes)
3✔
1496

3✔
1497
        invoice, err := deserializeInvoice(invoiceReader)
3✔
1498
        if err != nil {
3✔
1499
                return invpkg.Invoice{}, err
×
1500
        }
×
1501

1502
        // If this is an AMP invoice we'll also attempt to read out the set of
1503
        // HTLCs that were paid to prior set IDs, if needed.
1504
        if !invoice.IsAMP() {
6✔
1505
                return invoice, nil
3✔
1506
        }
3✔
1507

1508
        if shouldFetchAMPHTLCs(invoice, setIDs) {
6✔
1509
                invoice.Htlcs, err = fetchAmpSubInvoices(
3✔
1510
                        invoices, invoiceNum, setIDs...,
3✔
1511
                )
3✔
1512
                // TODO(positiveblue): we should fail when we are not able to
3✔
1513
                // fetch all the HTLCs for an AMP invoice. Multiple tests in
3✔
1514
                // the invoice and channeldb package break if we return this
3✔
1515
                // error. We need to update them when we migrate this logic to
3✔
1516
                // the sql implementation.
3✔
1517
                if err != nil {
6✔
1518
                        log.Errorf("unable to fetch amp htlcs for inv "+
3✔
1519
                                "%v and setIDs %v: %w", invoiceNum, setIDs, err)
3✔
1520
                }
3✔
1521
        }
1522

1523
        return invoice, nil
3✔
1524
}
1525

1526
// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that
1527
// were paid to the relevant set IDs.
1528
func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool {
3✔
1529
        // For AMP invoice that already have HTLCs populated (created before
3✔
1530
        // recurring invoices), then we don't need to read from the prefix
3✔
1531
        // keyed section of the bucket.
3✔
1532
        if len(invoice.Htlcs) != 0 {
3✔
1533
                return false
×
1534
        }
×
1535

1536
        // If the "zero" setID was specified, then this means that no HTLC data
1537
        // should be returned alongside of it.
1538
        if len(setIDs) != 0 && setIDs[0] != nil &&
3✔
1539
                *setIDs[0] == invpkg.BlankPayAddr {
6✔
1540

3✔
1541
                return false
3✔
1542
        }
3✔
1543

1544
        return true
3✔
1545
}
1546

1547
// fetchInvoiceStateAMP retrieves the state of all the relevant sub-invoice for
1548
// an AMP invoice. This methods only decode the relevant state vs the entire
1549
// invoice.
1550
func fetchInvoiceStateAMP(invoiceNum []byte,
1551
        invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) {
×
1552

×
1553
        // Fetch the raw invoice bytes.
×
1554
        invoiceBytes := invoices.Get(invoiceNum)
×
1555
        if invoiceBytes == nil {
×
1556
                return nil, invpkg.ErrInvoiceNotFound
×
1557
        }
×
1558

1559
        r := bytes.NewReader(invoiceBytes)
×
1560

×
1561
        var bodyLen int64
×
1562
        err := binary.Read(r, byteOrder, &bodyLen)
×
1563
        if err != nil {
×
1564
                return nil, err
×
1565
        }
×
1566

1567
        // Next, we'll make a new TLV stream that only attempts to decode the
1568
        // bytes we actually need.
1569
        ampState := make(invpkg.AMPInvoiceState)
×
1570
        tlvStream, err := tlv.NewStream(
×
1571
                // Invoice AMP state.
×
1572
                tlv.MakeDynamicRecord(
×
1573
                        invoiceAmpStateType, &ampState, nil,
×
1574
                        ampStateEncoder, ampStateDecoder,
×
1575
                ),
×
1576
        )
×
1577
        if err != nil {
×
1578
                return nil, err
×
1579
        }
×
1580

1581
        invoiceReader := io.LimitReader(r, bodyLen)
×
1582
        if err = tlvStream.Decode(invoiceReader); err != nil {
×
1583
                return nil, err
×
1584
        }
×
1585

1586
        return ampState, nil
×
1587
}
1588

1589
func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) {
3✔
1590
        var (
3✔
1591
                preimageBytes [32]byte
3✔
1592
                value         uint64
3✔
1593
                cltvDelta     uint32
3✔
1594
                expiry        uint64
3✔
1595
                amtPaid       uint64
3✔
1596
                state         uint8
3✔
1597
                hodlInvoice   uint8
3✔
1598

3✔
1599
                creationDateBytes []byte
3✔
1600
                settleDateBytes   []byte
3✔
1601
                featureBytes      []byte
3✔
1602
        )
3✔
1603

3✔
1604
        var i invpkg.Invoice
3✔
1605
        i.AMPState = make(invpkg.AMPInvoiceState)
3✔
1606
        tlvStream, err := tlv.NewStream(
3✔
1607
                // Memo and payreq.
3✔
1608
                tlv.MakePrimitiveRecord(memoType, &i.Memo),
3✔
1609
                tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
3✔
1610

3✔
1611
                // Add/settle metadata.
3✔
1612
                tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
3✔
1613
                tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
3✔
1614
                tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
3✔
1615
                tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
3✔
1616

3✔
1617
                // Terms.
3✔
1618
                tlv.MakePrimitiveRecord(preimageType, &preimageBytes),
3✔
1619
                tlv.MakePrimitiveRecord(valueType, &value),
3✔
1620
                tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
3✔
1621
                tlv.MakePrimitiveRecord(expiryType, &expiry),
3✔
1622
                tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
3✔
1623
                tlv.MakePrimitiveRecord(featuresType, &featureBytes),
3✔
1624

3✔
1625
                // Invoice state.
3✔
1626
                tlv.MakePrimitiveRecord(invStateType, &state),
3✔
1627
                tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
3✔
1628

3✔
1629
                tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
3✔
1630

3✔
1631
                // Invoice AMP state.
3✔
1632
                tlv.MakeDynamicRecord(
3✔
1633
                        invoiceAmpStateType, &i.AMPState, nil,
3✔
1634
                        ampStateEncoder, ampStateDecoder,
3✔
1635
                ),
3✔
1636
        )
3✔
1637
        if err != nil {
3✔
1638
                return i, err
×
1639
        }
×
1640

1641
        var bodyLen int64
3✔
1642
        err = binary.Read(r, byteOrder, &bodyLen)
3✔
1643
        if err != nil {
3✔
1644
                return i, err
×
1645
        }
×
1646

1647
        lr := io.LimitReader(r, bodyLen)
3✔
1648
        if err = tlvStream.Decode(lr); err != nil {
3✔
1649
                return i, err
×
1650
        }
×
1651

1652
        preimage := lntypes.Preimage(preimageBytes)
3✔
1653
        if preimage != invpkg.UnknownPreimage {
6✔
1654
                i.Terms.PaymentPreimage = &preimage
3✔
1655
        }
3✔
1656

1657
        i.Terms.Value = lnwire.MilliSatoshi(value)
3✔
1658
        i.Terms.FinalCltvDelta = int32(cltvDelta)
3✔
1659
        i.Terms.Expiry = time.Duration(expiry)
3✔
1660
        i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
3✔
1661
        i.State = invpkg.ContractState(state)
3✔
1662

3✔
1663
        if hodlInvoice != 0 {
6✔
1664
                i.HodlInvoice = true
3✔
1665
        }
3✔
1666

1667
        err = i.CreationDate.UnmarshalBinary(creationDateBytes)
3✔
1668
        if err != nil {
3✔
1669
                return i, err
×
1670
        }
×
1671

1672
        err = i.SettleDate.UnmarshalBinary(settleDateBytes)
3✔
1673
        if err != nil {
3✔
1674
                return i, err
×
1675
        }
×
1676

1677
        rawFeatures := lnwire.NewRawFeatureVector()
3✔
1678
        err = rawFeatures.DecodeBase256(
3✔
1679
                bytes.NewReader(featureBytes), len(featureBytes),
3✔
1680
        )
3✔
1681
        if err != nil {
3✔
1682
                return i, err
×
1683
        }
×
1684

1685
        i.Terms.Features = lnwire.NewFeatureVector(
3✔
1686
                rawFeatures, lnwire.Features,
3✔
1687
        )
3✔
1688

3✔
1689
        i.Htlcs, err = deserializeHtlcs(r)
3✔
1690
        return i, err
3✔
1691
}
1692

1693
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
3✔
1694
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
6✔
1695
                // We encode the set of circuit keys as a varint length prefix.
3✔
1696
                // followed by a series of fixed sized uint8 integers.
3✔
1697
                numKeys := uint64(len(*v))
3✔
1698

3✔
1699
                if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
3✔
1700
                        return err
×
1701
                }
×
1702

1703
                for key := range *v {
6✔
1704
                        scidInt := key.ChanID.ToUint64()
3✔
1705

3✔
1706
                        if err := tlv.EUint64(w, &scidInt, buf); err != nil {
3✔
1707
                                return err
×
1708
                        }
×
1709
                        if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
3✔
1710
                                return err
×
1711
                        }
×
1712
                }
1713

1714
                return nil
3✔
1715
        }
1716

1717
        return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
×
1718
}
1719

1720
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte,
1721
        l uint64) error {
3✔
1722

3✔
1723
        if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
6✔
1724
                // First, we'll read out the varint that encodes the number of
3✔
1725
                // circuit keys encoded.
3✔
1726
                numKeys, err := tlv.ReadVarInt(r, buf)
3✔
1727
                if err != nil {
3✔
1728
                        return err
×
1729
                }
×
1730

1731
                // Now that we know how many keys to expect, iterate reading
1732
                // each one until we're done.
1733
                for i := uint64(0); i < numKeys; i++ {
6✔
1734
                        var (
3✔
1735
                                key  models.CircuitKey
3✔
1736
                                scid uint64
3✔
1737
                        )
3✔
1738

3✔
1739
                        if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
3✔
1740
                                return err
×
1741
                        }
×
1742

1743
                        key.ChanID = lnwire.NewShortChanIDFromInt(scid)
3✔
1744

3✔
1745
                        err := tlv.DUint64(r, &key.HtlcID, buf, 8)
3✔
1746
                        if err != nil {
3✔
1747
                                return err
×
1748
                        }
×
1749

1750
                        (*v)[key] = struct{}{}
3✔
1751
                }
1752

1753
                return nil
3✔
1754
        }
1755

1756
        return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l)
×
1757
}
1758

1759
// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record.
1760
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
3✔
1761
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
6✔
1762
                // We'll encode the AMP state as a series of KV pairs on the
3✔
1763
                // wire with a length prefix.
3✔
1764
                numRecords := uint64(len(*v))
3✔
1765

3✔
1766
                // First, we'll write out the number of records as a var int.
3✔
1767
                if err := tlv.WriteVarInt(w, numRecords, buf); err != nil {
3✔
1768
                        return err
×
1769
                }
×
1770

1771
                // With that written out, we'll now encode the entries
1772
                // themselves as a sub-TLV record, which includes its _own_
1773
                // inner length prefix.
1774
                for setID, ampState := range *v {
6✔
1775
                        setID := [32]byte(setID)
3✔
1776
                        ampState := ampState
3✔
1777

3✔
1778
                        htlcState := uint8(ampState.State)
3✔
1779
                        settleDate := ampState.SettleDate
3✔
1780
                        settleDateBytes, err := settleDate.MarshalBinary()
3✔
1781
                        if err != nil {
3✔
1782
                                return err
×
1783
                        }
×
1784

1785
                        amtPaid := uint64(ampState.AmtPaid)
3✔
1786

3✔
1787
                        var ampStateTlvBytes bytes.Buffer
3✔
1788
                        tlvStream, err := tlv.NewStream(
3✔
1789
                                tlv.MakePrimitiveRecord(
3✔
1790
                                        ampStateSetIDType, &setID,
3✔
1791
                                ),
3✔
1792
                                tlv.MakePrimitiveRecord(
3✔
1793
                                        ampStateHtlcStateType, &htlcState,
3✔
1794
                                ),
3✔
1795
                                tlv.MakePrimitiveRecord(
3✔
1796
                                        ampStateSettleIndexType,
3✔
1797
                                        &ampState.SettleIndex,
3✔
1798
                                ),
3✔
1799
                                tlv.MakePrimitiveRecord(
3✔
1800
                                        ampStateSettleDateType,
3✔
1801
                                        &settleDateBytes,
3✔
1802
                                ),
3✔
1803
                                tlv.MakeDynamicRecord(
3✔
1804
                                        ampStateCircuitKeysType,
3✔
1805
                                        &ampState.InvoiceKeys,
3✔
1806
                                        func() uint64 {
6✔
1807
                                                // The record takes 8 bytes to
3✔
1808
                                                // encode the set of circuits,
3✔
1809
                                                // 8 bytes for the scid for the
3✔
1810
                                                // key, and 8 bytes for the HTLC
3✔
1811
                                                // index.
3✔
1812
                                                keys := ampState.InvoiceKeys
3✔
1813
                                                numKeys := uint64(len(keys))
3✔
1814
                                                size := tlv.VarIntSize(numKeys)
3✔
1815
                                                dataSize := (numKeys * 16)
3✔
1816

3✔
1817
                                                return size + dataSize
3✔
1818
                                        },
3✔
1819
                                        encodeCircuitKeys, decodeCircuitKeys,
1820
                                ),
1821
                                tlv.MakePrimitiveRecord(
1822
                                        ampStateAmtPaidType, &amtPaid,
1823
                                ),
1824
                        )
1825
                        if err != nil {
3✔
1826
                                return err
×
1827
                        }
×
1828

1829
                        err = tlvStream.Encode(&ampStateTlvBytes)
3✔
1830
                        if err != nil {
3✔
1831
                                return err
×
1832
                        }
×
1833

1834
                        // We encode the record with a varint length followed by
1835
                        // the _raw_ TLV bytes.
1836
                        tlvLen := uint64(len(ampStateTlvBytes.Bytes()))
3✔
1837
                        if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
3✔
1838
                                return err
×
1839
                        }
×
1840

1841
                        _, err = w.Write(ampStateTlvBytes.Bytes())
3✔
1842
                        if err != nil {
3✔
1843
                                return err
×
1844
                        }
×
1845
                }
1846

1847
                return nil
3✔
1848
        }
1849

1850
        return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
×
1851
}
1852

1853
// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record.
1854
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte,
1855
        l uint64) error {
3✔
1856

3✔
1857
        if v, ok := val.(*invpkg.AMPInvoiceState); ok {
6✔
1858
                // First, we'll decode the varint that encodes how many set IDs
3✔
1859
                // are encoded within the greater map.
3✔
1860
                numRecords, err := tlv.ReadVarInt(r, buf)
3✔
1861
                if err != nil {
3✔
1862
                        return err
×
1863
                }
×
1864

1865
                // Now that we know how many records we'll need to read, we can
1866
                // iterate and read them all out in series.
1867
                for i := uint64(0); i < numRecords; i++ {
6✔
1868
                        // Read out the varint that encodes the size of this
3✔
1869
                        // inner TLV record.
3✔
1870
                        stateRecordSize, err := tlv.ReadVarInt(r, buf)
3✔
1871
                        if err != nil {
3✔
1872
                                return err
×
1873
                        }
×
1874

1875
                        // Using this information, we'll create a new limited
1876
                        // reader that'll return an EOF once the end has been
1877
                        // reached so the stream stops consuming bytes.
1878
                        innerTlvReader := io.LimitedReader{
3✔
1879
                                R: r,
3✔
1880
                                N: int64(stateRecordSize),
3✔
1881
                        }
3✔
1882

3✔
1883
                        var (
3✔
1884
                                setID           [32]byte
3✔
1885
                                htlcState       uint8
3✔
1886
                                settleIndex     uint64
3✔
1887
                                settleDateBytes []byte
3✔
1888
                                invoiceKeys     = make(
3✔
1889
                                        map[models.CircuitKey]struct{},
3✔
1890
                                )
3✔
1891
                                amtPaid uint64
3✔
1892
                        )
3✔
1893
                        tlvStream, err := tlv.NewStream(
3✔
1894
                                tlv.MakePrimitiveRecord(
3✔
1895
                                        ampStateSetIDType, &setID,
3✔
1896
                                ),
3✔
1897
                                tlv.MakePrimitiveRecord(
3✔
1898
                                        ampStateHtlcStateType, &htlcState,
3✔
1899
                                ),
3✔
1900
                                tlv.MakePrimitiveRecord(
3✔
1901
                                        ampStateSettleIndexType, &settleIndex,
3✔
1902
                                ),
3✔
1903
                                tlv.MakePrimitiveRecord(
3✔
1904
                                        ampStateSettleDateType,
3✔
1905
                                        &settleDateBytes,
3✔
1906
                                ),
3✔
1907
                                tlv.MakeDynamicRecord(
3✔
1908
                                        ampStateCircuitKeysType,
3✔
1909
                                        &invoiceKeys, nil,
3✔
1910
                                        encodeCircuitKeys, decodeCircuitKeys,
3✔
1911
                                ),
3✔
1912
                                tlv.MakePrimitiveRecord(
3✔
1913
                                        ampStateAmtPaidType, &amtPaid,
3✔
1914
                                ),
3✔
1915
                        )
3✔
1916
                        if err != nil {
3✔
1917
                                return err
×
1918
                        }
×
1919

1920
                        err = tlvStream.Decode(&innerTlvReader)
3✔
1921
                        if err != nil {
3✔
1922
                                return err
×
1923
                        }
×
1924

1925
                        var settleDate time.Time
3✔
1926
                        err = settleDate.UnmarshalBinary(settleDateBytes)
3✔
1927
                        if err != nil {
3✔
1928
                                return err
×
1929
                        }
×
1930

1931
                        (*v)[setID] = invpkg.InvoiceStateAMP{
3✔
1932
                                State:       invpkg.HtlcState(htlcState),
3✔
1933
                                SettleIndex: settleIndex,
3✔
1934
                                SettleDate:  settleDate,
3✔
1935
                                InvoiceKeys: invoiceKeys,
3✔
1936
                                AmtPaid:     lnwire.MilliSatoshi(amtPaid),
3✔
1937
                        }
3✔
1938
                }
1939

1940
                return nil
3✔
1941
        }
1942

1943
        return tlv.NewTypeForDecodingErr(
×
1944
                val, "channeldb.AMPInvoiceState", l, l,
×
1945
        )
×
1946
}
1947

1948
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
1949
// as a map.
1950
func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
1951
        error) {
3✔
1952

3✔
1953
        htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
3✔
1954
        for {
6✔
1955
                // Read the length of the tlv stream for this htlc.
3✔
1956
                var streamLen int64
3✔
1957
                if err := binary.Read(r, byteOrder, &streamLen); err != nil {
6✔
1958
                        if err == io.EOF {
6✔
1959
                                break
3✔
1960
                        }
1961

1962
                        return nil, err
×
1963
                }
1964

1965
                // Limit the reader so that it stops at the end of this htlc's
1966
                // stream.
1967
                htlcReader := io.LimitReader(r, streamLen)
3✔
1968

3✔
1969
                // Decode the contents into the htlc fields.
3✔
1970
                var (
3✔
1971
                        htlc                    invpkg.InvoiceHTLC
3✔
1972
                        key                     models.CircuitKey
3✔
1973
                        chanID                  uint64
3✔
1974
                        state                   uint8
3✔
1975
                        acceptTime, resolveTime uint64
3✔
1976
                        amt, mppTotalAmt        uint64
3✔
1977
                        amp                     = &record.AMP{}
3✔
1978
                        hash32                  = &[32]byte{}
3✔
1979
                        preimage32              = &[32]byte{}
3✔
1980
                )
3✔
1981
                tlvStream, err := tlv.NewStream(
3✔
1982
                        tlv.MakePrimitiveRecord(chanIDType, &chanID),
3✔
1983
                        tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
3✔
1984
                        tlv.MakePrimitiveRecord(amtType, &amt),
3✔
1985
                        tlv.MakePrimitiveRecord(
3✔
1986
                                acceptHeightType, &htlc.AcceptHeight,
3✔
1987
                        ),
3✔
1988
                        tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
3✔
1989
                        tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
3✔
1990
                        tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
3✔
1991
                        tlv.MakePrimitiveRecord(htlcStateType, &state),
3✔
1992
                        tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
3✔
1993
                        tlv.MakeDynamicRecord(
3✔
1994
                                htlcAMPType, amp, amp.PayloadSize,
3✔
1995
                                record.AMPEncoder, record.AMPDecoder,
3✔
1996
                        ),
3✔
1997
                        tlv.MakePrimitiveRecord(htlcHashType, hash32),
3✔
1998
                        tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
3✔
1999
                )
3✔
2000
                if err != nil {
3✔
2001
                        return nil, err
×
2002
                }
×
2003

2004
                parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
3✔
2005
                if err != nil {
3✔
2006
                        return nil, err
×
2007
                }
×
2008

2009
                if _, ok := parsedTypes[htlcAMPType]; !ok {
6✔
2010
                        amp = nil
3✔
2011
                }
3✔
2012

2013
                var preimage *lntypes.Preimage
3✔
2014
                if _, ok := parsedTypes[htlcPreimageType]; ok {
6✔
2015
                        pimg := lntypes.Preimage(*preimage32)
3✔
2016
                        preimage = &pimg
3✔
2017
                }
3✔
2018

2019
                var hash *lntypes.Hash
3✔
2020
                if _, ok := parsedTypes[htlcHashType]; ok {
6✔
2021
                        h := lntypes.Hash(*hash32)
3✔
2022
                        hash = &h
3✔
2023
                }
3✔
2024

2025
                key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
3✔
2026
                htlc.AcceptTime = getNanoTime(acceptTime)
3✔
2027
                htlc.ResolveTime = getNanoTime(resolveTime)
3✔
2028
                htlc.State = invpkg.HtlcState(state)
3✔
2029
                htlc.Amt = lnwire.MilliSatoshi(amt)
3✔
2030
                htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
3✔
2031
                if amp != nil && hash != nil {
6✔
2032
                        htlc.AMP = &invpkg.InvoiceHtlcAMPData{
3✔
2033
                                Record:   *amp,
3✔
2034
                                Hash:     *hash,
3✔
2035
                                Preimage: preimage,
3✔
2036
                        }
3✔
2037
                }
3✔
2038

2039
                // Reconstruct the custom records fields from the parsed types
2040
                // map return from the tlv parser.
2041
                htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
3✔
2042

3✔
2043
                htlcs[key] = &htlc
3✔
2044
        }
2045

2046
        return htlcs, nil
3✔
2047
}
2048

2049
// invoiceSetIDKeyLen is the length of the key under which the HTLCs of one
// AMP set ID are stored alongside the main invoice in the invoice bucket. It
// is 4 bytes for the invoice number followed by 32 bytes for the set ID.
const invoiceSetIDKeyLen = 4 + 32

// makeInvoiceSetIDKey builds the invoiceNum || setID key under which the
// HTLCs for the given set ID are stored. The set ID is written right after
// the invoice number bytes, and any unused trailing bytes are left zero.
func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte {
	var key [invoiceSetIDKeyLen]byte

	setIDPart := key[len(invoiceNum):]
	copy(key[:], invoiceNum)
	copy(setIDPart, setID)

	return key
}
3✔
2066

2067
// delAMPInvoices attempts to delete all the "sub" invoices associated with a
2068
// greater AMP invoices. We do this by deleting the set of keys that share the
2069
// invoice number as a prefix.
2070
func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error {
×
2071
        // Since it isn't safe to delete using an active cursor, we'll use the
×
2072
        // cursor simply to collect the set of keys we need to delete, _then_
×
2073
        // delete them in another pass.
×
2074
        var keysToDel [][]byte
×
2075
        err := forEachAMPInvoice(
×
2076
                invoiceBucket, invoiceNum,
×
2077
                func(cursorKey, v []byte) error {
×
2078
                        keysToDel = append(keysToDel, cursorKey)
×
2079
                        return nil
×
2080
                },
×
2081
        )
2082
        if err != nil {
×
2083
                return err
×
2084
        }
×
2085

2086
        // In this next phase, we'll then delete all the relevant invoices.
2087
        for _, keyToDel := range keysToDel {
×
2088
                if err := invoiceBucket.Delete(keyToDel); err != nil {
×
2089
                        return err
×
2090
                }
×
2091
        }
2092

2093
        return nil
×
2094
}
2095

2096
// delAMPSettleIndex removes all the entries in the settle index associated
2097
// with a given AMP invoice.
2098
func delAMPSettleIndex(invoiceNum []byte, invoices,
2099
        settleIndex kvdb.RwBucket) error {
×
2100

×
2101
        // First, we need to grab the AMP invoice state to see if there's
×
2102
        // anything that we even need to delete.
×
2103
        ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices)
×
2104
        if err != nil {
×
2105
                return err
×
2106
        }
×
2107

2108
        // If there's no AMP state at all (non-AMP invoice), then we can return
2109
        // early.
2110
        if len(ampState) == 0 {
×
2111
                return nil
×
2112
        }
×
2113

2114
        // Otherwise, we'll need to iterate and delete each settle index within
2115
        // the set of returned entries.
2116
        var settleIndexKey [8]byte
×
2117
        for _, subState := range ampState {
×
2118
                byteOrder.PutUint64(
×
2119
                        settleIndexKey[:], subState.SettleIndex,
×
2120
                )
×
2121

×
2122
                if err := settleIndex.Delete(settleIndexKey[:]); err != nil {
×
2123
                        return err
×
2124
                }
×
2125
        }
2126

2127
        return nil
×
2128
}
2129

2130
// DeleteCanceledInvoices deletes all canceled invoices from the database.
//
// For every canceled invoice this removes:
//   - the payment hash index entry,
//   - the payment address index entry, if it points at this invoice,
//   - the add index entry,
//   - any AMP sub-invoice entries,
//   - the serialized invoice itself.
//
// The settle index is not touched, as canceled invoices are never added to
// it.
//
// The payment hash index maps payHash => invoiceKey. All invoice-keyed data is
// therefore addressed by the index value, never by the payment hash.
func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
	// canceledRef holds everything needed to remove a canceled invoice
	// once iteration over the payment hash index has finished.
	type canceledRef struct {
		// payHash is the key of the entry in the payment hash index.
		payHash []byte

		// invoiceKey is the value the payment hash maps to. It keys the
		// serialized invoice and prefixes its AMP sub-invoices.
		invoiceKey []byte

		// payAddr is the invoice's payment address, or nil if the
		// invoice has none.
		payAddr []byte

		// addIndex is the invoice's position in the add index.
		addIndex uint64
	}

	return kvdb.Update(d, func(tx kvdb.RwTx) error {
		invoices := tx.ReadWriteBucket(invoiceBucket)
		if invoices == nil {
			return nil
		}

		invoiceIndex := invoices.NestedReadWriteBucket(
			invoiceIndexBucket,
		)
		if invoiceIndex == nil {
			return nil
		}

		invoiceAddIndex := invoices.NestedReadWriteBucket(
			addIndexBucket,
		)
		if invoiceAddIndex == nil {
			return nil
		}

		payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)

		// First pass: find all canceled invoices. A bucket must not be
		// modified from within its own ForEach, so deletions happen in
		// a second pass. Keys are copied defensively so they don't
		// alias database-owned memory once buckets are modified.
		var toDelete []canceledRef
		err := invoiceIndex.ForEach(func(k, v []byte) error {
			// Skip the special numInvoicesKey as that does not
			// point to a valid invoice.
			if bytes.Equal(k, numInvoicesKey) {
				return nil
			}

			// Skip sub-buckets.
			if v == nil {
				return nil
			}

			invoice, err := fetchInvoice(v, invoices)
			if err != nil {
				return err
			}

			if invoice.State != invpkg.ContractCanceled {
				return nil
			}

			ref := canceledRef{
				payHash:    append([]byte(nil), k...),
				invoiceKey: append([]byte(nil), v...),
				addIndex:   invoice.AddIndex,
			}
			if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
				ref.payAddr = append(
					[]byte(nil),
					invoice.Terms.PaymentAddr[:]...,
				)
			}
			toDelete = append(toDelete, ref)

			return nil
		})
		if err != nil {
			return err
		}

		// Second pass: remove every collected invoice together with
		// its index entries.
		for _, ref := range toDelete {
			// Delete the payment hash from the invoice index.
			err := invoiceIndex.Delete(ref.payHash)
			if err != nil {
				return err
			}

			// Delete the payment address index reference, but only
			// if it still points at this invoice. The index stores
			// invoice keys, not payment hashes.
			if ref.payAddr != nil && payAddrIndex != nil {
				key := payAddrIndex.Get(ref.payAddr)
				if bytes.Equal(key, ref.invoiceKey) {
					err := payAddrIndex.Delete(ref.payAddr)
					if err != nil {
						return err
					}
				}
			}

			// Remove from the add index.
			var addIndexKey [8]byte
			byteOrder.PutUint64(addIndexKey[:], ref.addIndex)
			err = invoiceAddIndex.Delete(addIndexKey[:])
			if err != nil {
				return err
			}

			// Now remove all sub invoices, which are stored under
			// keys prefixed by the invoice key.
			err = delAMPInvoices(ref.invoiceKey, invoices)
			if err != nil {
				return err
			}

			// Finally remove the serialized invoice from the
			// invoice bucket.
			err = invoices.Delete(ref.invoiceKey)
			if err != nil {
				return err
			}
		}

		return nil
	}, func() {})
}
2225

2226
// DeleteInvoice attempts to delete the passed invoices from the database in
// one transaction. The passed delete references hold all keys required to
// delete the invoices without also needing to deserialize them.
//
// For each reference, these entries are removed after checking that they point
// at the same invoice key:
//   - the payment hash index entry,
//   - the payment address index entry,
//   - the add index entry,
//   - the settle index entry.
//
// Any AMP settle index entries, the AMP sub-invoices and the serialized
// invoice itself are removed as well.
//
// ErrNoInvoicesCreated is returned if the invoice buckets are missing.
// ErrInvoiceNotFound is returned if a payment hash is not indexed. Because all
// work happens inside a single kvdb.Update, an error for any reference aborts
// the whole transaction.
func (d *DB) DeleteInvoice(_ context.Context,
	invoicesToDelete []invpkg.InvoiceDeleteRef) error {

	return kvdb.Update(d, func(tx kvdb.RwTx) error {
		invoices := tx.ReadWriteBucket(invoiceBucket)
		if invoices == nil {
			return invpkg.ErrNoInvoicesCreated
		}

		invoiceIndex := invoices.NestedReadWriteBucket(
			invoiceIndexBucket,
		)
		if invoiceIndex == nil {
			return invpkg.ErrNoInvoicesCreated
		}

		invoiceAddIndex := invoices.NestedReadWriteBucket(
			addIndexBucket,
		)
		if invoiceAddIndex == nil {
			return invpkg.ErrNoInvoicesCreated
		}

		// settleIndex can be nil, as the bucket is created lazily
		// when the first invoice is settled.
		settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket)

		payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)

		for _, ref := range invoicesToDelete {
			// Fetch the invoice key for using it to check for
			// consistency and also to delete from the invoice
			// index.
			invoiceKey := invoiceIndex.Get(ref.PayHash[:])
			if invoiceKey == nil {
				return invpkg.ErrInvoiceNotFound
			}

			err := invoiceIndex.Delete(ref.PayHash[:])
			if err != nil {
				return err
			}

			// Delete payment address index reference if there's a
			// valid payment address passed.
			if ref.PayAddr != nil {
				// To ensure consistency check that the already
				// fetched invoice key matches the one in the
				// payment address index.
				key := payAddrIndex.Get(ref.PayAddr[:])
				if bytes.Equal(key, invoiceKey) {
					// Delete from the payment address
					// index. Note that since the payment
					// address index has been introduced
					// with an empty migration it may be
					// possible that the index doesn't have
					// an entry for this invoice.
					// ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127
					if err := payAddrIndex.Delete(
						ref.PayAddr[:],
					); err != nil {
						return err
					}
				}
			}

			var addIndexKey [8]byte
			byteOrder.PutUint64(addIndexKey[:], ref.AddIndex)

			// To ensure consistency check that the key stored in
			// the add index also matches the previously fetched
			// invoice key.
			key := invoiceAddIndex.Get(addIndexKey[:])
			if !bytes.Equal(key, invoiceKey) {
				return fmt.Errorf("unknown invoice " +
					"in add index")
			}

			// Remove from the add index.
			err = invoiceAddIndex.Delete(addIndexKey[:])
			if err != nil {
				return err
			}

			// Remove from the settle index if available and
			// if the invoice is settled.
			if settleIndex != nil && ref.SettleIndex > 0 {
				var settleIndexKey [8]byte
				byteOrder.PutUint64(
					settleIndexKey[:], ref.SettleIndex,
				)

				// To ensure consistency check that the already
				// fetched invoice key matches the one in the
				// settle index
				key := settleIndex.Get(settleIndexKey[:])
				if !bytes.Equal(key, invoiceKey) {
					return fmt.Errorf("unknown invoice " +
						"in settle index")
				}

				err = settleIndex.Delete(settleIndexKey[:])
				if err != nil {
					return err
				}
			}

			// In addition to deleting the main invoice state, if
			// this is an AMP invoice, we'll also need to remove the
			// settle index entries of its HTLC sets. Without a
			// settle index bucket there can be no such entries, so
			// skip the lookup entirely. This also avoids handing a
			// nil bucket to delAMPSettleIndex.
			if settleIndex != nil {
				err = delAMPSettleIndex(
					invoiceKey, invoices, settleIndex,
				)
				if err != nil {
					return err
				}
			}

			// Remove the HTLC sets stored under keys prefixed by
			// the invoice key. For non-AMP invoices, this'll be a
			// noop.
			err = delAMPInvoices(invoiceKey, invoices)
			if err != nil {
				return err
			}

			// Finally remove the serialized invoice from the
			// invoice bucket.
			err = invoices.Delete(invoiceKey)
			if err != nil {
				return err
			}
		}

		return nil
	}, func() {})
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc