• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15160358425

21 May 2025 10:56AM UTC coverage: 58.584% (-10.4%) from 68.996%
15160358425

Pull #9847

github

web-flow
Merge 2880b9a35 into c52a6ddeb
Pull Request #9847: Refactor Payment PR 4

634 of 942 new or added lines in 17 files covered. (67.3%)

28108 existing lines in 450 files now uncovered.

97449 of 166342 relevant lines covered (58.58%)

1.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.06
/channeldb/payments.go
1
package channeldb
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/binary"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "sort"
11
        "sync"
12
        "time"
13

14
        "github.com/btcsuite/btcd/btcec/v2"
15
        "github.com/btcsuite/btcd/wire"
16
        "github.com/lightningnetwork/lnd/kvdb"
17
        "github.com/lightningnetwork/lnd/lntypes"
18
        "github.com/lightningnetwork/lnd/lnwire"
19
        pymtpkg "github.com/lightningnetwork/lnd/payments"
20
        "github.com/lightningnetwork/lnd/record"
21
        "github.com/lightningnetwork/lnd/routing/route"
22
        "github.com/lightningnetwork/lnd/tlv"
23
)
24

25
var (
26
        // paymentsRootBucket is the name of the top-level bucket within the
27
        // database that stores all data related to payments. Within this
28
        // bucket, each payment has its own sub-bucket keyed by its payment
29
        // hash.
30
        //
31
        // Bucket hierarchy:
32
        //
33
        // root-bucket
34
        //      |
35
        //      |-- <paymenthash>
36
        //      |        |--sequence-key: <sequence number>
37
        //      |        |--creation-info-key: <creation info>
38
        //      |        |--fail-info-key: <(optional) fail info>
39
        //      |        |
40
        //      |        |--payment-htlcs-bucket (shard-bucket)
41
        //      |        |        |
42
        //      |        |        |-- ai<htlc attempt ID>: <htlc attempt info>
43
        //      |        |        |-- si<htlc attempt ID>: <(optional) settle info>
44
        //      |        |        |-- fi<htlc attempt ID>: <(optional) fail info>
45
        //      |        |        |
46
        //      |        |       ...
47
        //      |        |
48
        //      |        |
49
        //      |        |--duplicate-bucket (only for old, completed payments)
50
        //      |                 |
51
        //      |                 |-- <seq-num>
52
        //      |                 |       |--sequence-key: <sequence number>
53
        //      |                 |       |--creation-info-key: <creation info>
54
        //      |                 |       |--ai: <attempt info>
55
        //      |                 |       |--si: <settle info>
56
        //      |                 |       |--fi: <fail info>
57
        //      |                 |
58
        //      |                 |-- <seq-num>
59
        //      |                 |       |
60
        //      |                ...     ...
61
        //      |
62
        //      |-- <paymenthash>
63
        //      |        |
64
        //      |       ...
65
        //     ...
66
        //
67
        paymentsRootBucket = []byte("payments-root-bucket")
68

69
        // paymentSequenceKey is a key used in the payment's sub-bucket to
70
        // store the sequence number of the payment.
71
        paymentSequenceKey = []byte("payment-sequence-key")
72

73
        // paymentCreationInfoKey is a key used in the payment's sub-bucket to
74
        // store the creation info of the payment.
75
        paymentCreationInfoKey = []byte("payment-creation-info")
76

77
        // paymentHtlcsBucket is a bucket where we'll store the information
78
        // about the HTLCs that were attempted for a payment.
79
        paymentHtlcsBucket = []byte("payment-htlcs-bucket")
80

81
        // htlcAttemptInfoKey is the key used as the prefix of an HTLC attempt
82
        // to store the info about the attempt that was done for the HTLC in
83
        // question. The HTLC attempt ID is concatenated at the end.
84
        htlcAttemptInfoKey = []byte("ai")
85

86
        // htlcSettleInfoKey is the key used as the prefix of an HTLC attempt
87
        // settle info, if any. The HTLC attempt ID is concatenated at the end.
88
        htlcSettleInfoKey = []byte("si")
89

90
        // htlcFailInfoKey is the key used as the prefix of an HTLC attempt
91
        // failure information, if any. The HTLC attempt ID is concatenated at
92
        // the end.
93
        htlcFailInfoKey = []byte("fi")
94

95
        // paymentFailInfoKey is a key used in the payment's sub-bucket to
96
        // store information about the reason a payment failed.
97
        paymentFailInfoKey = []byte("payment-fail-info")
98

99
        // paymentsIndexBucket is the name of the top-level bucket within the
100
        // database that stores an index of payment sequence numbers to its
101
        // payment hash.
102
        // payments-sequence-index-bucket
103
        //         |--<sequence-number>: <payment hash>
104
        //         |--...
105
        //         |--<sequence-number>: <payment hash>
106
        paymentsIndexBucket = []byte("payments-index-bucket")
107
)
108

109
var (
110
        // ErrNoSequenceNumber is returned if we look up a payment which does
111
        // not have a sequence number.
112
        ErrNoSequenceNumber = errors.New("sequence number not found")
113

114
        // ErrDuplicateNotFound is returned when we lookup a payment by its
115
        // index and cannot find a payment with a matching sequence number.
116
        ErrDuplicateNotFound = errors.New("duplicate payment not found")
117

118
        // ErrNoDuplicateBucket is returned when we expect to find duplicates
119
        // when looking up a payment from its index, but the payment does not
120
        // have any.
121
        ErrNoDuplicateBucket = errors.New("expected duplicate bucket")
122

123
        // ErrNoDuplicateNestedBucket is returned if we do not find duplicate
124
        // payments in their own sub-bucket.
125
        ErrNoDuplicateNestedBucket = errors.New("nested duplicate bucket not " +
126
                "found")
127
)
128

129
// Payment operations related constants.
130
const (
131
        // paymentSeqBlockSize is the block size used when we batch allocate
132
        // payment sequences for future payments.
133
        paymentSeqBlockSize = 1000
134

135
        // paymentProgressLogInterval is the interval we use to limit the
136
        // logging output of payment processing.
137
        paymentProgressLogInterval = 30 * time.Second
138
)
139

140
// KVPaymentsDB implements persistence for payments and payment attempts.
141
type KVPaymentsDB struct {
142
        // paymentSeqMx is held by nextPaymentSequence while it inspects the
        // two sequence counters below.
        paymentSeqMx     sync.Mutex
143
        // currPaymentSeq is compared against storedPaymentSeq to decide when
        // a new upper bound has to be written to the db.
        // NOTE(review): presumably the last sequence number handed out —
        // confirm against nextPaymentSequence.
        currPaymentSeq   uint64
144
        // NOTE(review): presumably the upper bound of the sequence block that
        // is currently persisted in the db — confirm against
        // nextPaymentSequence.
        storedPaymentSeq uint64
145
        // db is the channel database; its Backend is what the payment
        // transactions in this file are run against.
        db               *DB
146
}
147

148
// NewKVPaymentsDB creates a new instance of the KVPaymentsDB.
149
func NewKVPaymentsDB(db *DB) *KVPaymentsDB {
3✔
150
        return &KVPaymentsDB{
3✔
151
                db: db,
3✔
152
        }
3✔
153
}
3✔
154

155
// InitPayment checks or records the given PaymentCreationInfo with the DB,
156
// making sure it does not already exist as an in-flight payment. When this
157
// method returns successfully, the payment is guaranteed to be in the InFlight
158
// state.
159
func (p *KVPaymentsDB) InitPayment(paymentHash lntypes.Hash,
160
        info *pymtpkg.PaymentCreationInfo) error {
3✔
161

3✔
162
        // Obtain a new sequence number for this payment. This is used
3✔
163
        // to sort the payments in order of creation, and also acts as
3✔
164
        // a unique identifier for each payment.
3✔
165
        sequenceNum, err := p.nextPaymentSequence()
3✔
166
        if err != nil {
3✔
NEW
167
                return err
×
NEW
168
        }
×
169

170
        var b bytes.Buffer
3✔
171
        if err := serializePaymentCreationInfo(&b, info); err != nil {
3✔
NEW
172
                return err
×
NEW
173
        }
×
174
        infoBytes := b.Bytes()
3✔
175

3✔
176
        var updateErr error
3✔
177
        err = kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
6✔
178
                // Reset the update error, to avoid carrying over an error
3✔
179
                // from a previous execution of the batched db transaction.
3✔
180
                updateErr = nil
3✔
181

3✔
182
                prefetchPayment(tx, paymentHash)
3✔
183
                bucket, err := createPaymentBucket(tx, paymentHash)
3✔
184
                if err != nil {
3✔
NEW
185
                        return err
×
NEW
186
                }
×
187

188
                // Get the existing status of this payment, if any.
189
                paymentStatus, err := fetchPaymentStatus(bucket)
3✔
190

3✔
191
                switch {
3✔
192
                // If no error is returned, it means we already have this
193
                // payment. We'll check the status to decide whether we allow
194
                // retrying the payment or return a specific error.
195
                case err == nil:
3✔
196
                        if err := paymentStatus.Initializable(); err != nil {
6✔
197
                                updateErr = err
3✔
198
                                return nil
3✔
199
                        }
3✔
200

201
                // Otherwise, if the error is not `ErrPaymentNotInitiated`,
202
                // we'll return the error.
NEW
203
                case !errors.Is(err, pymtpkg.ErrPaymentNotInitiated):
×
NEW
204
                        return err
×
205
                }
206

207
                // Before we set our new sequence number, we check whether this
208
                // payment has a previously set sequence number and remove its
209
                // index entry if it exists. This happens in the case where we
210
                // have a previously attempted payment which was left in a state
211
                // where we can retry.
212
                seqBytes := bucket.Get(paymentSequenceKey)
3✔
213
                if seqBytes != nil {
6✔
214
                        indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
3✔
215
                        if err := indexBucket.Delete(seqBytes); err != nil {
3✔
NEW
216
                                return err
×
NEW
217
                        }
×
218
                }
219

220
                // Once we have obtained a sequence number, we add an entry
221
                // to our index bucket which will map the sequence number to
222
                // our payment identifier.
223
                err = createPaymentIndexEntry(
3✔
224
                        tx, sequenceNum, info.PaymentIdentifier,
3✔
225
                )
3✔
226
                if err != nil {
3✔
NEW
227
                        return err
×
NEW
228
                }
×
229

230
                err = bucket.Put(paymentSequenceKey, sequenceNum)
3✔
231
                if err != nil {
3✔
NEW
232
                        return err
×
NEW
233
                }
×
234

235
                // Add the payment info to the bucket, which contains the
236
                // static information for this payment
237
                err = bucket.Put(paymentCreationInfoKey, infoBytes)
3✔
238
                if err != nil {
3✔
NEW
239
                        return err
×
NEW
240
                }
×
241

242
                // We'll delete any lingering HTLCs to start with, in case we
243
                // are initializing a payment that was attempted earlier, but
244
                // left in a state where we could retry.
245
                err = bucket.DeleteNestedBucket(paymentHtlcsBucket)
3✔
246
                if err != nil && !errors.Is(err, kvdb.ErrBucketNotFound) {
3✔
NEW
247
                        return err
×
NEW
248
                }
×
249

250
                // Also delete any lingering failure info now that we are
251
                // re-attempting.
252
                return bucket.Delete(paymentFailInfoKey)
3✔
253
        })
254
        if err != nil {
3✔
NEW
255
                return fmt.Errorf("unable to init payment: %w", err)
×
NEW
256
        }
×
257

258
        return updateErr
3✔
259
}
260

261
// DeleteFailedAttempts deletes all failed htlcs for a payment if configured
262
// by the KVPaymentsDB db.
263
func (p *KVPaymentsDB) DeleteFailedAttempts(hash lntypes.Hash) error {
3✔
264
        if !p.db.keepFailedPaymentAttempts {
3✔
NEW
265
                const failedHtlcsOnly = true
×
NEW
266
                err := p.DeletePayment(hash, failedHtlcsOnly)
×
NEW
267
                if err != nil {
×
NEW
268
                        return err
×
NEW
269
                }
×
270
        }
271

272
        return nil
3✔
273
}
274

275
// paymentIndexType identifies the kind of entry stored in the payments index
276
// bucket, which allows different index types to be signalled in the future.
277
type paymentIndexType uint8
278

279
// paymentIndexTypeHash is a payment index type which indicates that we have
280
// created an index of payment sequence number to payment hash.
281
const paymentIndexTypeHash paymentIndexType = 0
282

283
// createPaymentIndexEntry creates a payment hash typed index for a payment. The
284
// index produced contains a payment index type (which can be used in future to
285
// signal different payment index types) and the payment identifier.
286
func createPaymentIndexEntry(tx kvdb.RwTx, sequenceNumber []byte,
287
        id lntypes.Hash) error {
3✔
288

3✔
289
        var b bytes.Buffer
3✔
290
        if err := WriteElements(&b, paymentIndexTypeHash, id[:]); err != nil {
3✔
NEW
291
                return err
×
NEW
292
        }
×
293

294
        indexes := tx.ReadWriteBucket(paymentsIndexBucket)
3✔
295

3✔
296
        return indexes.Put(sequenceNumber, b.Bytes())
3✔
297
}
298

299
// deserializePaymentIndex deserializes a payment index entry. This function
300
// currently only supports deserialization of payment hash indexes, and will
301
// fail for other types.
302
func deserializePaymentIndex(r io.Reader) (lntypes.Hash, error) {
3✔
303
        var (
3✔
304
                indexType   paymentIndexType
3✔
305
                paymentHash []byte
3✔
306
        )
3✔
307

3✔
308
        if err := ReadElements(r, &indexType, &paymentHash); err != nil {
3✔
NEW
309
                return lntypes.Hash{}, err
×
NEW
310
        }
×
311

312
        // While we only have on payment index type, we do not need to use our
313
        // index type to deserialize the index. However, we sanity check that
314
        // this type is as expected, since we had to read it out anyway.
315
        if indexType != paymentIndexTypeHash {
3✔
NEW
316
                return lntypes.Hash{}, fmt.Errorf("unknown payment index "+
×
NEW
317
                        "type: %v", indexType)
×
NEW
318
        }
×
319

320
        hash, err := lntypes.MakeHash(paymentHash)
3✔
321
        if err != nil {
3✔
NEW
322
                return lntypes.Hash{}, err
×
NEW
323
        }
×
324

325
        return hash, nil
3✔
326
}
327

328
// RegisterAttempt atomically records the provided HTLCAttemptInfo to the
329
// DB.
330
func (p *KVPaymentsDB) RegisterAttempt(paymentHash lntypes.Hash,
331
        attempt *pymtpkg.HTLCAttemptInfo) (*pymtpkg.MPPayment, error) {
3✔
332

3✔
333
        // Serialize the information before opening the db transaction.
3✔
334
        var a bytes.Buffer
3✔
335
        err := serializeHTLCAttemptInfo(&a, attempt)
3✔
336
        if err != nil {
3✔
NEW
337
                return nil, err
×
NEW
338
        }
×
339
        htlcInfoBytes := a.Bytes()
3✔
340

3✔
341
        htlcIDBytes := make([]byte, 8)
3✔
342
        binary.BigEndian.PutUint64(htlcIDBytes, attempt.AttemptID)
3✔
343

3✔
344
        var payment *pymtpkg.MPPayment
3✔
345
        err = kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
6✔
346
                prefetchPayment(tx, paymentHash)
3✔
347
                bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
3✔
348
                if err != nil {
3✔
NEW
349
                        return err
×
NEW
350
                }
×
351

352
                payment, err = fetchPayment(bucket)
3✔
353
                if err != nil {
3✔
NEW
354
                        return err
×
NEW
355
                }
×
356

357
                // Check if registering a new attempt is allowed.
358
                if err := payment.Registrable(); err != nil {
3✔
NEW
359
                        return err
×
NEW
360
                }
×
361

362
                // If the final hop has encrypted data, then we know this is a
363
                // blinded payment. In blinded payments, MPP records are not set
364
                // for split payments and the recipient is responsible for using
365
                // a consistent PathID across the various encrypted data
366
                // payloads that we received from them for this payment. All we
367
                // need to check is that the total amount field for each HTLC
368
                // in the split payment is correct.
369
                isBlinded := len(attempt.Route.FinalHop().EncryptedData) != 0
3✔
370

3✔
371
                // Make sure any existing shards match the new one with regards
3✔
372
                // to MPP options.
3✔
373
                mpp := attempt.Route.FinalHop().MPP
3✔
374

3✔
375
                // MPP records should not be set for attempts to blinded paths.
3✔
376
                if isBlinded && mpp != nil {
3✔
NEW
377
                        return pymtpkg.ErrMPPRecordInBlindedPayment
×
NEW
378
                }
×
379

380
                for _, h := range payment.InFlightHTLCs() {
6✔
381
                        hMpp := h.Route.FinalHop().MPP
3✔
382

3✔
383
                        // If this is a blinded payment, then no existing HTLCs
3✔
384
                        // should have MPP records.
3✔
385
                        if isBlinded && hMpp != nil {
3✔
NEW
386
                                return pymtpkg.ErrMPPRecordInBlindedPayment
×
NEW
387
                        }
×
388

389
                        // If this is a blinded payment, then we just need to
390
                        // check that the TotalAmtMsat field for this shard
391
                        // is equal to that of any other shard in the same
392
                        // payment.
393
                        if isBlinded {
6✔
394
                                if attempt.Route.FinalHop().TotalAmtMsat !=
3✔
395
                                        h.Route.FinalHop().TotalAmtMsat {
3✔
NEW
396

×
NEW
397
                                        //nolint:ll
×
NEW
398
                                        return pymtpkg.ErrBlindedPaymentTotalAmountMismatch
×
NEW
399
                                }
×
400

401
                                continue
3✔
402
                        }
403

404
                        switch {
3✔
405
                        // We tried to register a non-MPP attempt for a MPP
406
                        // payment.
NEW
407
                        case mpp == nil && hMpp != nil:
×
NEW
408
                                return pymtpkg.ErrMPPayment
×
409

410
                        // We tried to register a MPP shard for a non-MPP
411
                        // payment.
NEW
412
                        case mpp != nil && hMpp == nil:
×
NEW
413
                                return pymtpkg.ErrNonMPPayment
×
414

415
                        // Non-MPP payment, nothing more to validate.
NEW
416
                        case mpp == nil:
×
NEW
417
                                continue
×
418
                        }
419

420
                        // Check that MPP options match.
421
                        if mpp.PaymentAddr() != hMpp.PaymentAddr() {
3✔
NEW
422
                                return pymtpkg.ErrMPPPaymentAddrMismatch
×
NEW
423
                        }
×
424

425
                        if mpp.TotalMsat() != hMpp.TotalMsat() {
3✔
NEW
426
                                return pymtpkg.ErrMPPTotalAmountMismatch
×
NEW
427
                        }
×
428
                }
429

430
                // If this is a non-MPP attempt, it must match the total amount
431
                // exactly. Note that a blinded payment is considered an MPP
432
                // attempt.
433
                amt := attempt.Route.ReceiverAmt()
3✔
434
                if !isBlinded && mpp == nil && amt != payment.Info.Value {
3✔
NEW
435
                        return pymtpkg.ErrValueMismatch
×
NEW
436
                }
×
437

438
                // Ensure we aren't sending more than the total payment amount.
439
                sentAmt, _ := payment.SentAmt()
3✔
440
                if sentAmt+amt > payment.Info.Value {
3✔
NEW
441
                        return fmt.Errorf("%w: attempted=%v, payment amount="+
×
NEW
442
                                "%v", pymtpkg.ErrValueExceedsAmt, sentAmt+amt,
×
NEW
443
                                payment.Info.Value)
×
NEW
444
                }
×
445

446
                htlcsBucket, err := bucket.CreateBucketIfNotExists(
3✔
447
                        paymentHtlcsBucket,
3✔
448
                )
3✔
449
                if err != nil {
3✔
NEW
450
                        return err
×
NEW
451
                }
×
452

453
                err = htlcsBucket.Put(
3✔
454
                        htlcBucketKey(htlcAttemptInfoKey, htlcIDBytes),
3✔
455
                        htlcInfoBytes,
3✔
456
                )
3✔
457
                if err != nil {
3✔
NEW
458
                        return err
×
NEW
459
                }
×
460

461
                // Retrieve attempt info for the notification.
462
                payment, err = fetchPayment(bucket)
3✔
463

3✔
464
                return err
3✔
465
        })
466
        if err != nil {
3✔
NEW
467
                return nil, err
×
NEW
468
        }
×
469

470
        return payment, err
3✔
471
}
472

473
// SettleAttempt marks the given attempt settled with the preimage. If this is
474
// a multi shard payment, this might implicitly mean that the full payment
475
// succeeded.
476
//
477
// After invoking this method, InitPayment should always return an error to
478
// prevent us from making duplicate payments to the same payment hash. The
479
// provided preimage is atomically saved to the DB for record keeping.
480
func (p *KVPaymentsDB) SettleAttempt(hash lntypes.Hash, attemptID uint64,
481
        settleInfo *pymtpkg.HTLCSettleInfo) (*pymtpkg.MPPayment, error) {
3✔
482

3✔
483
        var b bytes.Buffer
3✔
484
        if err := serializeHTLCSettleInfo(&b, settleInfo); err != nil {
3✔
NEW
485
                return nil, err
×
NEW
486
        }
×
487
        settleBytes := b.Bytes()
3✔
488

3✔
489
        return p.updateHtlcKey(hash, attemptID, htlcSettleInfoKey, settleBytes)
3✔
490
}
491

492
// FailAttempt marks the given payment attempt failed.
493
func (p *KVPaymentsDB) FailAttempt(hash lntypes.Hash, attemptID uint64,
494
        failInfo *pymtpkg.HTLCFailInfo) (*pymtpkg.MPPayment, error) {
3✔
495

3✔
496
        var b bytes.Buffer
3✔
497
        if err := serializeHTLCFailInfo(&b, failInfo); err != nil {
3✔
NEW
498
                return nil, err
×
NEW
499
        }
×
500
        failBytes := b.Bytes()
3✔
501

3✔
502
        return p.updateHtlcKey(hash, attemptID, htlcFailInfoKey, failBytes)
3✔
503
}
504

505
// updateHtlcKey updates a database key for the specified htlc.
506
func (p *KVPaymentsDB) updateHtlcKey(paymentHash lntypes.Hash,
507
        attemptID uint64, key, value []byte) (*pymtpkg.MPPayment, error) {
3✔
508

3✔
509
        aid := make([]byte, 8)
3✔
510
        binary.BigEndian.PutUint64(aid, attemptID)
3✔
511

3✔
512
        var payment *pymtpkg.MPPayment
3✔
513
        err := kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
6✔
514
                payment = nil
3✔
515

3✔
516
                prefetchPayment(tx, paymentHash)
3✔
517
                bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
3✔
518
                if err != nil {
3✔
NEW
519
                        return err
×
NEW
520
                }
×
521

522
                p, err := fetchPayment(bucket)
3✔
523
                if err != nil {
3✔
NEW
524
                        return err
×
NEW
525
                }
×
526

527
                // We can only update keys of in-flight payments. We allow
528
                // updating keys even if the payment has reached a terminal
529
                // condition, since the HTLC outcomes must still be updated.
530
                if err := p.Status.Updatable(); err != nil {
3✔
NEW
531
                        return err
×
NEW
532
                }
×
533

534
                htlcsBucket := bucket.NestedReadWriteBucket(paymentHtlcsBucket)
3✔
535
                if htlcsBucket == nil {
3✔
NEW
536
                        return fmt.Errorf("htlcs bucket not found")
×
NEW
537
                }
×
538

539
                attemptInfo := htlcsBucket.Get(
3✔
540
                        htlcBucketKey(htlcAttemptInfoKey, aid),
3✔
541
                )
3✔
542
                if attemptInfo == nil {
3✔
NEW
543
                        return fmt.Errorf("HTLC with ID %v not registered",
×
NEW
544
                                attemptID)
×
NEW
545
                }
×
546

547
                failInfo := htlcsBucket.Get(
3✔
548
                        htlcBucketKey(htlcFailInfoKey, aid),
3✔
549
                )
3✔
550
                if failInfo != nil {
3✔
NEW
551
                        return pymtpkg.ErrAttemptAlreadyFailed
×
NEW
552
                }
×
553

554
                settleInfo := htlcsBucket.Get(
3✔
555
                        htlcBucketKey(htlcSettleInfoKey, aid),
3✔
556
                )
3✔
557
                if settleInfo != nil {
3✔
NEW
558
                        return pymtpkg.ErrAttemptAlreadySettled
×
NEW
559
                }
×
560

561
                // Add or update the key for this htlc.
562
                err = htlcsBucket.Put(htlcBucketKey(key, aid), value)
3✔
563
                if err != nil {
3✔
NEW
564
                        return err
×
NEW
565
                }
×
566

567
                // Retrieve attempt info for the notification.
568
                payment, err = fetchPayment(bucket)
3✔
569

3✔
570
                return err
3✔
571
        })
572
        if err != nil {
3✔
NEW
573
                return nil, err
×
NEW
574
        }
×
575

576
        return payment, err
3✔
577
}
578

579
// Fail transitions a payment into the Failed state, and records the reason the
580
// payment failed. After invoking this method, InitPayment should return nil on
581
// its next call for this payment hash, allowing the switch to make a
582
// subsequent payment.
583
func (p *KVPaymentsDB) Fail(paymentHash lntypes.Hash,
584
        reason pymtpkg.FailureReason) (*pymtpkg.MPPayment, error) {
3✔
585

3✔
586
        var (
3✔
587
                updateErr error
3✔
588
                payment   *pymtpkg.MPPayment
3✔
589
        )
3✔
590
        err := kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
6✔
591
                // Reset the update error, to avoid carrying over an error
3✔
592
                // from a previous execution of the batched db transaction.
3✔
593
                updateErr = nil
3✔
594
                payment = nil
3✔
595

3✔
596
                prefetchPayment(tx, paymentHash)
3✔
597
                bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
3✔
598
                if errors.Is(err, pymtpkg.ErrPaymentNotInitiated) {
3✔
NEW
599
                        updateErr = pymtpkg.ErrPaymentNotInitiated
×
NEW
600
                        return nil
×
601
                } else if err != nil {
3✔
NEW
602
                        return err
×
NEW
603
                }
×
604

605
                // We mark the payment as failed as long as it is known. This
606
                // lets the last attempt to fail with a terminal write its
607
                // failure to the KVPaymentsDB without synchronizing with
608
                // other attempts.
609
                _, err = fetchPaymentStatus(bucket)
3✔
610
                if errors.Is(err, pymtpkg.ErrPaymentNotInitiated) {
3✔
NEW
611
                        updateErr = pymtpkg.ErrPaymentNotInitiated
×
NEW
612
                        return nil
×
613
                } else if err != nil {
3✔
NEW
614
                        return err
×
NEW
615
                }
×
616

617
                // Put the failure reason in the bucket for record keeping.
618
                v := []byte{byte(reason)}
3✔
619
                err = bucket.Put(paymentFailInfoKey, v)
3✔
620
                if err != nil {
3✔
NEW
621
                        return err
×
NEW
622
                }
×
623

624
                // Retrieve attempt info for the notification, if available.
625
                payment, err = fetchPayment(bucket)
3✔
626
                if err != nil {
3✔
NEW
627
                        return err
×
NEW
628
                }
×
629

630
                return nil
3✔
631
        })
632
        if err != nil {
3✔
NEW
633
                return nil, err
×
NEW
634
        }
×
635

636
        return payment, updateErr
3✔
637
}
638

639
// FetchPayment returns information about a payment from the database.
640
func (p *KVPaymentsDB) FetchPayment(paymentHash lntypes.Hash) (
641
        *pymtpkg.MPPayment, error) {
3✔
642

3✔
643
        var payment *pymtpkg.MPPayment
3✔
644
        err := kvdb.View(p.db, func(tx kvdb.RTx) error {
6✔
645
                prefetchPayment(tx, paymentHash)
3✔
646
                bucket, err := fetchPaymentBucket(tx, paymentHash)
3✔
647
                if err != nil {
3✔
NEW
648
                        return err
×
NEW
649
                }
×
650

651
                payment, err = fetchPayment(bucket)
3✔
652

3✔
653
                return err
3✔
654
        }, func() {
3✔
655
                payment = nil
3✔
656
        })
3✔
657
        if err != nil {
3✔
NEW
658
                return nil, err
×
NEW
659
        }
×
660

661
        return payment, nil
3✔
662
}
663

664
// prefetchPayment attempts to prefetch as much of the payment as possible to
665
// reduce DB roundtrips.
666
func prefetchPayment(tx kvdb.RTx, paymentHash lntypes.Hash) {
3✔
667
        rb := kvdb.RootBucket(tx)
3✔
668
        kvdb.Prefetch(
3✔
669
                rb,
3✔
670
                []string{
3✔
671
                        // Prefetch all keys in the payment's bucket.
3✔
672
                        string(paymentsRootBucket),
3✔
673
                        string(paymentHash[:]),
3✔
674
                },
3✔
675
                []string{
3✔
676
                        // Prefetch all keys in the payment's htlc bucket.
3✔
677
                        string(paymentsRootBucket),
3✔
678
                        string(paymentHash[:]),
3✔
679
                        string(paymentHtlcsBucket),
3✔
680
                },
3✔
681
        )
3✔
682
}
3✔
683

684
// createPaymentBucket creates or fetches the sub-bucket assigned to this
685
// payment hash.
686
func createPaymentBucket(tx kvdb.RwTx, paymentHash lntypes.Hash) (
687
        kvdb.RwBucket, error) {
3✔
688

3✔
689
        payments, err := tx.CreateTopLevelBucket(paymentsRootBucket)
3✔
690
        if err != nil {
3✔
NEW
691
                return nil, err
×
NEW
692
        }
×
693

694
        return payments.CreateBucketIfNotExists(paymentHash[:])
3✔
695
}
696

697
// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If
698
// the bucket does not exist, it returns ErrPaymentNotInitiated.
699
func fetchPaymentBucket(tx kvdb.RTx, paymentHash lntypes.Hash) (
700
        kvdb.RBucket, error) {
3✔
701

3✔
702
        payments := tx.ReadBucket(paymentsRootBucket)
3✔
703
        if payments == nil {
3✔
NEW
704
                return nil, pymtpkg.ErrPaymentNotInitiated
×
NEW
705
        }
×
706

707
        bucket := payments.NestedReadBucket(paymentHash[:])
3✔
708
        if bucket == nil {
3✔
NEW
709
                return nil, pymtpkg.ErrPaymentNotInitiated
×
NEW
710
        }
×
711

712
        return bucket, nil
3✔
713
}
714

715
// fetchPaymentBucketUpdate is identical to fetchPaymentBucket, but it returns a
716
// bucket that can be written to.
717
func fetchPaymentBucketUpdate(tx kvdb.RwTx, paymentHash lntypes.Hash) (
718
        kvdb.RwBucket, error) {
3✔
719

3✔
720
        payments := tx.ReadWriteBucket(paymentsRootBucket)
3✔
721
        if payments == nil {
3✔
NEW
722
                return nil, pymtpkg.ErrPaymentNotInitiated
×
NEW
723
        }
×
724

725
        bucket := payments.NestedReadWriteBucket(paymentHash[:])
3✔
726
        if bucket == nil {
3✔
NEW
727
                return nil, pymtpkg.ErrPaymentNotInitiated
×
NEW
728
        }
×
729

730
        return bucket, nil
3✔
731
}
732

733
// nextPaymentSequence returns the next sequence number to store for a new
734
// payment.
735
func (p *KVPaymentsDB) nextPaymentSequence() ([]byte, error) {
3✔
736
        p.paymentSeqMx.Lock()
3✔
737
        defer p.paymentSeqMx.Unlock()
3✔
738

3✔
739
        // Set a new upper bound in the DB every 1000 payments to avoid
3✔
740
        // conflicts on the sequence when using etcd.
3✔
741
        if p.currPaymentSeq == p.storedPaymentSeq {
6✔
742
                var currPaymentSeq, newUpperBound uint64
3✔
743
                if err := kvdb.Update(p.db.Backend, func(tx kvdb.RwTx) error {
6✔
744
                        paymentsBucket, err := tx.CreateTopLevelBucket(
3✔
745
                                paymentsRootBucket,
3✔
746
                        )
3✔
747
                        if err != nil {
3✔
NEW
748
                                return err
×
NEW
749
                        }
×
750

751
                        currPaymentSeq = paymentsBucket.Sequence()
3✔
752
                        newUpperBound = currPaymentSeq + paymentSeqBlockSize
3✔
753

3✔
754
                        return paymentsBucket.SetSequence(newUpperBound)
3✔
755
                }, func() {}); err != nil {
3✔
NEW
756
                        return nil, err
×
NEW
757
                }
×
758

759
                // We lazy initialize the cached currPaymentSeq here using the
760
                // first nextPaymentSequence() call. This if statement will auto
761
                // initialize our stored currPaymentSeq, since by default both
762
                // this variable and storedPaymentSeq are zero which in turn
763
                // will have us fetch the current values from the DB.
764
                if p.currPaymentSeq == 0 {
6✔
765
                        p.currPaymentSeq = currPaymentSeq
3✔
766
                }
3✔
767

768
                p.storedPaymentSeq = newUpperBound
3✔
769
        }
770

771
        p.currPaymentSeq++
3✔
772
        b := make([]byte, 8)
3✔
773
        binary.BigEndian.PutUint64(b, p.currPaymentSeq)
3✔
774

3✔
775
        return b, nil
3✔
776
}
777

778
// fetchPaymentStatus fetches the payment status of the payment. If the payment
779
// isn't found, it will return error `ErrPaymentNotInitiated`.
780
func fetchPaymentStatus(bucket kvdb.RBucket) (pymtpkg.PaymentStatus, error) {
3✔
781
        // Creation info should be set for all payments, regardless of state.
3✔
782
        // If not, it is unknown.
3✔
783
        if bucket.Get(paymentCreationInfoKey) == nil {
6✔
784
                return 0, pymtpkg.ErrPaymentNotInitiated
3✔
785
        }
3✔
786

787
        payment, err := fetchPayment(bucket)
3✔
788
        if err != nil {
3✔
NEW
789
                return 0, err
×
NEW
790
        }
×
791

792
        return payment.Status, nil
3✔
793
}
794

795
// FetchInFlightPayments returns all payments with status InFlight.
796
func (p *KVPaymentsDB) FetchInFlightPayments() ([]*pymtpkg.MPPayment, error) {
3✔
797
        var (
3✔
798
                inFlights      []*pymtpkg.MPPayment
3✔
799
                start          = time.Now()
3✔
800
                lastLogTime    = time.Now()
3✔
801
                processedCount int
3✔
802
        )
3✔
803

3✔
804
        err := kvdb.View(p.db, func(tx kvdb.RTx) error {
6✔
805
                payments := tx.ReadBucket(paymentsRootBucket)
3✔
806
                if payments == nil {
6✔
807
                        return nil
3✔
808
                }
3✔
809

810
                return payments.ForEach(func(k, _ []byte) error {
6✔
811
                        bucket := payments.NestedReadBucket(k)
3✔
812
                        if bucket == nil {
3✔
NEW
813
                                return fmt.Errorf("non bucket element")
×
NEW
814
                        }
×
815

816
                        p, err := fetchPayment(bucket)
3✔
817
                        if err != nil {
3✔
NEW
818
                                return err
×
NEW
819
                        }
×
820

821
                        processedCount++
3✔
822
                        if time.Since(lastLogTime) >=
3✔
823
                                paymentProgressLogInterval {
3✔
NEW
824

×
NEW
825
                                log.Debugf("Scanning inflight payments "+
×
NEW
826
                                        "(in progress), processed %d, last "+
×
NEW
827
                                        "processed payment: %v", processedCount,
×
NEW
828
                                        p.Info)
×
NEW
829

×
NEW
830
                                lastLogTime = time.Now()
×
NEW
831
                        }
×
832

833
                        // Skip the payment if it's terminated.
834
                        if p.Terminated() {
6✔
835
                                return nil
3✔
836
                        }
3✔
837

838
                        inFlights = append(inFlights, p)
3✔
839

3✔
840
                        return nil
3✔
841
                })
842
        }, func() {
3✔
843
                inFlights = nil
3✔
844
        })
3✔
845
        if err != nil {
3✔
NEW
846
                return nil, err
×
NEW
847
        }
×
848

849
        elapsed := time.Since(start)
3✔
850
        log.Debugf("Completed scanning for inflight payments: "+
3✔
851
                "total_processed=%d, found_inflight=%d, elapsed=%v",
3✔
852
                processedCount, len(inFlights),
3✔
853
                elapsed.Round(time.Millisecond))
3✔
854

3✔
855
        return inFlights, nil
3✔
856
}
857

858
// htlcBucketKey builds a composite key by concatenating prefix and id into a
// freshly allocated slice. Neither input is modified.
func htlcBucketKey(prefix, id []byte) []byte {
	key := make([]byte, 0, len(prefix)+len(id))
	key = append(key, prefix...)

	return append(key, id...)
}
3✔
866

867
// FetchPayments returns all sent payments found in the DB.
868
//
869
// nolint: dupl
NEW
870
func (p *KVPaymentsDB) FetchPayments() ([]*pymtpkg.MPPayment, error) {
×
NEW
871
        var payments []*pymtpkg.MPPayment
×
UNCOV
872

×
NEW
873
        err := kvdb.View(p.db, func(tx kvdb.RTx) error {
×
UNCOV
874
                paymentsBucket := tx.ReadBucket(paymentsRootBucket)
×
UNCOV
875
                if paymentsBucket == nil {
×
876
                        return nil
×
877
                }
×
878

UNCOV
879
                return paymentsBucket.ForEach(func(k, v []byte) error {
×
UNCOV
880
                        bucket := paymentsBucket.NestedReadBucket(k)
×
UNCOV
881
                        if bucket == nil {
×
882
                                // We only expect sub-buckets to be found in
×
883
                                // this top-level bucket.
×
884
                                return fmt.Errorf("non bucket element in " +
×
885
                                        "payments bucket")
×
886
                        }
×
887

UNCOV
888
                        p, err := fetchPayment(bucket)
×
UNCOV
889
                        if err != nil {
×
890
                                return err
×
891
                        }
×
892

UNCOV
893
                        payments = append(payments, p)
×
UNCOV
894

×
UNCOV
895
                        // For older versions of lnd, duplicate payments to a
×
UNCOV
896
                        // payment has was possible. These will be found in a
×
UNCOV
897
                        // sub-bucket indexed by their sequence number if
×
UNCOV
898
                        // available.
×
UNCOV
899
                        duplicatePayments, err := fetchDuplicatePayments(bucket)
×
UNCOV
900
                        if err != nil {
×
901
                                return err
×
902
                        }
×
903

UNCOV
904
                        payments = append(payments, duplicatePayments...)
×
UNCOV
905
                        return nil
×
906
                })
UNCOV
907
        }, func() {
×
UNCOV
908
                payments = nil
×
UNCOV
909
        })
×
UNCOV
910
        if err != nil {
×
911
                return nil, err
×
912
        }
×
913

914
        // Before returning, sort the payments by their sequence number.
UNCOV
915
        sort.Slice(payments, func(i, j int) bool {
×
UNCOV
916
                return payments[i].SequenceNum < payments[j].SequenceNum
×
UNCOV
917
        })
×
918

UNCOV
919
        return payments, nil
×
920
}
921

922
func fetchCreationInfo(bucket kvdb.RBucket) (*pymtpkg.PaymentCreationInfo,
923
        error) {
3✔
924

3✔
925
        b := bucket.Get(paymentCreationInfoKey)
3✔
926
        if b == nil {
3✔
927
                return nil, fmt.Errorf("creation info not found")
×
928
        }
×
929

930
        r := bytes.NewReader(b)
3✔
931
        return deserializePaymentCreationInfo(r)
3✔
932
}
933

934
func fetchPayment(bucket kvdb.RBucket) (*pymtpkg.MPPayment, error) {
3✔
935
        seqBytes := bucket.Get(paymentSequenceKey)
3✔
936
        if seqBytes == nil {
3✔
937
                return nil, fmt.Errorf("sequence number not found")
×
938
        }
×
939

940
        sequenceNum := binary.BigEndian.Uint64(seqBytes)
3✔
941

3✔
942
        // Get the PaymentCreationInfo.
3✔
943
        creationInfo, err := fetchCreationInfo(bucket)
3✔
944
        if err != nil {
3✔
945
                return nil, err
×
946
        }
×
947

948
        var htlcs []pymtpkg.HTLCAttempt
3✔
949
        htlcsBucket := bucket.NestedReadBucket(paymentHtlcsBucket)
3✔
950
        if htlcsBucket != nil {
6✔
951
                // Get the payment attempts. This can be empty.
3✔
952
                htlcs, err = fetchHtlcAttempts(htlcsBucket)
3✔
953
                if err != nil {
3✔
954
                        return nil, err
×
955
                }
×
956
        }
957

958
        // Get failure reason if available.
959
        var failureReason *pymtpkg.FailureReason
3✔
960
        b := bucket.Get(paymentFailInfoKey)
3✔
961
        if b != nil {
6✔
962
                reason := pymtpkg.FailureReason(b[0])
3✔
963
                failureReason = &reason
3✔
964
        }
3✔
965

966
        // Create a new payment.
967
        payment := &pymtpkg.MPPayment{
3✔
968
                SequenceNum:   sequenceNum,
3✔
969
                Info:          creationInfo,
3✔
970
                HTLCs:         htlcs,
3✔
971
                FailureReason: failureReason,
3✔
972
        }
3✔
973

3✔
974
        // Set its state and status.
3✔
975
        if err := payment.SetState(); err != nil {
3✔
976
                return nil, err
×
977
        }
×
978

979
        return payment, nil
3✔
980
}
981

982
// fetchHtlcAttempts retrieves all htlc attempts made for the payment found in
983
// the given bucket.
984
func fetchHtlcAttempts(bucket kvdb.RBucket) ([]pymtpkg.HTLCAttempt, error) {
3✔
985
        htlcsMap := make(map[uint64]*pymtpkg.HTLCAttempt)
3✔
986

3✔
987
        attemptInfoCount := 0
3✔
988
        err := bucket.ForEach(func(k, v []byte) error {
6✔
989
                aid := byteOrder.Uint64(k[len(k)-8:])
3✔
990

3✔
991
                if _, ok := htlcsMap[aid]; !ok {
6✔
992
                        htlcsMap[aid] = &pymtpkg.HTLCAttempt{}
3✔
993
                }
3✔
994

995
                var err error
3✔
996
                switch {
3✔
997
                case bytes.HasPrefix(k, htlcAttemptInfoKey):
3✔
998
                        attemptInfo, err := readHtlcAttemptInfo(v)
3✔
999
                        if err != nil {
3✔
1000
                                return err
×
1001
                        }
×
1002

1003
                        attemptInfo.AttemptID = aid
3✔
1004
                        htlcsMap[aid].HTLCAttemptInfo = *attemptInfo
3✔
1005
                        attemptInfoCount++
3✔
1006

1007
                case bytes.HasPrefix(k, htlcSettleInfoKey):
3✔
1008
                        htlcsMap[aid].Settle, err = readHtlcSettleInfo(v)
3✔
1009
                        if err != nil {
3✔
1010
                                return err
×
1011
                        }
×
1012

1013
                case bytes.HasPrefix(k, htlcFailInfoKey):
3✔
1014
                        htlcsMap[aid].Failure, err = readHtlcFailInfo(v)
3✔
1015
                        if err != nil {
3✔
1016
                                return err
×
1017
                        }
×
1018

1019
                default:
×
1020
                        return fmt.Errorf("unknown htlc attempt key")
×
1021
                }
1022

1023
                return nil
3✔
1024
        })
1025
        if err != nil {
3✔
1026
                return nil, err
×
1027
        }
×
1028

1029
        // Sanity check that all htlcs have an attempt info.
1030
        if attemptInfoCount != len(htlcsMap) {
3✔
NEW
1031
                return nil, pymtpkg.ErrNoAttemptInfo
×
1032
        }
×
1033

1034
        keys := make([]uint64, len(htlcsMap))
3✔
1035
        i := 0
3✔
1036
        for k := range htlcsMap {
6✔
1037
                keys[i] = k
3✔
1038
                i++
3✔
1039
        }
3✔
1040

1041
        // Sort HTLC attempts by their attempt ID. This is needed because in the
1042
        // DB we store the attempts with keys prefixed by their status which
1043
        // changes order (groups them together by status).
1044
        sort.Slice(keys, func(i, j int) bool {
6✔
1045
                return keys[i] < keys[j]
3✔
1046
        })
3✔
1047

1048
        htlcs := make([]pymtpkg.HTLCAttempt, len(htlcsMap))
3✔
1049
        for i, key := range keys {
6✔
1050
                htlcs[i] = *htlcsMap[key]
3✔
1051
        }
3✔
1052

1053
        return htlcs, nil
3✔
1054
}
1055

1056
// readHtlcAttemptInfo reads the payment attempt info for this htlc.
1057
func readHtlcAttemptInfo(b []byte) (*pymtpkg.HTLCAttemptInfo, error) {
3✔
1058
        r := bytes.NewReader(b)
3✔
1059
        return deserializeHTLCAttemptInfo(r)
3✔
1060
}
3✔
1061

1062
// readHtlcSettleInfo reads the settle info for the htlc. If the htlc isn't
1063
// settled, nil is returned.
1064
func readHtlcSettleInfo(b []byte) (*pymtpkg.HTLCSettleInfo, error) {
3✔
1065
        r := bytes.NewReader(b)
3✔
1066
        return deserializeHTLCSettleInfo(r)
3✔
1067
}
3✔
1068

1069
// readHtlcFailInfo reads the failure info for the htlc. If the htlc hasn't
1070
// failed, nil is returned.
1071
func readHtlcFailInfo(b []byte) (*pymtpkg.HTLCFailInfo, error) {
3✔
1072
        r := bytes.NewReader(b)
3✔
1073
        return deserializeHTLCFailInfo(r)
3✔
1074
}
3✔
1075

1076
// fetchFailedHtlcKeys retrieves the bucket keys of all failed HTLCs of a
1077
// payment bucket.
UNCOV
1078
func fetchFailedHtlcKeys(bucket kvdb.RBucket) ([][]byte, error) {
×
UNCOV
1079
        htlcsBucket := bucket.NestedReadBucket(paymentHtlcsBucket)
×
UNCOV
1080

×
NEW
1081
        var htlcs []pymtpkg.HTLCAttempt
×
UNCOV
1082
        var err error
×
UNCOV
1083
        if htlcsBucket != nil {
×
UNCOV
1084
                htlcs, err = fetchHtlcAttempts(htlcsBucket)
×
UNCOV
1085
                if err != nil {
×
1086
                        return nil, err
×
1087
                }
×
1088
        }
1089

1090
        // Now iterate though them and save the bucket keys for the failed
1091
        // HTLCs.
UNCOV
1092
        var htlcKeys [][]byte
×
UNCOV
1093
        for _, h := range htlcs {
×
UNCOV
1094
                if h.Failure == nil {
×
UNCOV
1095
                        continue
×
1096
                }
1097

UNCOV
1098
                htlcKeyBytes := make([]byte, 8)
×
UNCOV
1099
                binary.BigEndian.PutUint64(htlcKeyBytes, h.AttemptID)
×
UNCOV
1100

×
UNCOV
1101
                htlcKeys = append(htlcKeys, htlcKeyBytes)
×
1102
        }
1103

UNCOV
1104
        return htlcKeys, nil
×
1105
}
1106

1107
// QueryPayments is a query to the payments database which is restricted
1108
// to a subset of payments by the payments query, containing an offset
1109
// index and a maximum number of returned payments.
1110
func (p *KVPaymentsDB) QueryPayments(_ context.Context,
1111
        query pymtpkg.Query) (pymtpkg.Response, error) {
3✔
1112

3✔
1113
        var resp pymtpkg.Response
3✔
1114

3✔
1115
        if err := kvdb.View(p.db, func(tx kvdb.RTx) error {
6✔
1116
                // Get the root payments bucket.
3✔
1117
                paymentsBucket := tx.ReadBucket(paymentsRootBucket)
3✔
1118
                if paymentsBucket == nil {
6✔
1119
                        return nil
3✔
1120
                }
3✔
1121

1122
                // Get the index bucket which maps sequence number -> payment
1123
                // hash and duplicate bool. If we have a payments bucket, we
1124
                // should have an indexes bucket as well.
1125
                indexes := tx.ReadBucket(paymentsIndexBucket)
3✔
1126
                if indexes == nil {
3✔
1127
                        return fmt.Errorf("index bucket does not exist")
×
1128
                }
×
1129

1130
                // accumulatePayments gets payments with the sequence number
1131
                // and hash provided and adds them to our list of payments if
1132
                // they meet the criteria of our query. It returns the number
1133
                // of payments that were added.
1134
                accumulatePayments := func(sequenceKey, hash []byte) (bool,
3✔
1135
                        error) {
6✔
1136

3✔
1137
                        r := bytes.NewReader(hash)
3✔
1138
                        paymentHash, err := deserializePaymentIndex(r)
3✔
1139
                        if err != nil {
3✔
1140
                                return false, err
×
1141
                        }
×
1142

1143
                        payment, err := fetchPaymentWithSequenceNumber(
3✔
1144
                                tx, paymentHash, sequenceKey,
3✔
1145
                        )
3✔
1146
                        if err != nil {
3✔
1147
                                return false, err
×
1148
                        }
×
1149

1150
                        // To keep compatibility with the old API, we only
1151
                        // return non-succeeded payments if requested.
1152
                        if payment.Status != pymtpkg.StatusSucceeded &&
3✔
1153
                                !query.IncludeIncomplete {
3✔
UNCOV
1154

×
UNCOV
1155
                                return false, err
×
UNCOV
1156
                        }
×
1157

1158
                        // Get the creation time in Unix seconds, this always
1159
                        // rounds down the nanoseconds to full seconds.
1160
                        createTime := payment.Info.CreationTime.Unix()
3✔
1161

3✔
1162
                        // Skip any payments that were created before the
3✔
1163
                        // specified time.
3✔
1164
                        if createTime < query.CreationDateStart {
6✔
1165
                                return false, nil
3✔
1166
                        }
3✔
1167

1168
                        // Skip any payments that were created after the
1169
                        // specified time.
1170
                        if query.CreationDateEnd != 0 &&
3✔
1171
                                createTime > query.CreationDateEnd {
6✔
1172

3✔
1173
                                return false, nil
3✔
1174
                        }
3✔
1175

1176
                        // At this point, we've exhausted the offset, so we'll
1177
                        // begin collecting invoices found within the range.
1178
                        resp.Payments = append(resp.Payments, payment)
3✔
1179
                        return true, nil
3✔
1180
                }
1181

1182
                // Create a paginator which reads from our sequence index bucket
1183
                // with the parameters provided by the payments query.
1184
                paginator := newPaginator(
3✔
1185
                        indexes.ReadCursor(), query.Reversed, query.IndexOffset,
3✔
1186
                        query.MaxPayments,
3✔
1187
                )
3✔
1188

3✔
1189
                // Run a paginated query, adding payments to our response.
3✔
1190
                if err := paginator.query(accumulatePayments); err != nil {
3✔
1191
                        return err
×
1192
                }
×
1193

1194
                // Counting the total number of payments is expensive, since we
1195
                // literally have to traverse the cursor linearly, which can
1196
                // take quite a while. So it's an optional query parameter.
1197
                if query.CountTotal {
3✔
1198
                        var (
×
1199
                                totalPayments uint64
×
1200
                                err           error
×
1201
                        )
×
1202
                        countFn := func(_, _ []byte) error {
×
1203
                                totalPayments++
×
1204

×
1205
                                return nil
×
1206
                        }
×
1207

1208
                        // In non-boltdb database backends, there's a faster
1209
                        // ForAll query that allows for batch fetching items.
1210
                        if fastBucket, ok := indexes.(kvdb.ExtendedRBucket); ok {
×
1211
                                err = fastBucket.ForAll(countFn)
×
1212
                        } else {
×
1213
                                err = indexes.ForEach(countFn)
×
1214
                        }
×
1215
                        if err != nil {
×
1216
                                return fmt.Errorf("error counting payments: %w",
×
1217
                                        err)
×
1218
                        }
×
1219

1220
                        resp.TotalCount = totalPayments
×
1221
                }
1222

1223
                return nil
3✔
1224
        }, func() {
3✔
1225
                resp = pymtpkg.Response{}
3✔
1226
        }); err != nil {
3✔
1227
                return resp, err
×
1228
        }
×
1229

1230
        // Need to swap the payments slice order if reversed order.
1231
        if query.Reversed {
3✔
UNCOV
1232
                for l, r := 0, len(resp.Payments)-1; l < r; l, r = l+1, r-1 {
×
UNCOV
1233
                        resp.Payments[l], resp.Payments[r] =
×
UNCOV
1234
                                resp.Payments[r], resp.Payments[l]
×
UNCOV
1235
                }
×
1236
        }
1237

1238
        // Set the first and last index of the returned payments so that the
1239
        // caller can resume from this point later on.
1240
        if len(resp.Payments) > 0 {
6✔
1241
                resp.FirstIndexOffset = resp.Payments[0].SequenceNum
3✔
1242
                resp.LastIndexOffset =
3✔
1243
                        resp.Payments[len(resp.Payments)-1].SequenceNum
3✔
1244
        }
3✔
1245

1246
        return resp, nil
3✔
1247
}
1248

1249
// fetchPaymentWithSequenceNumber get the payment which matches the payment hash
1250
// *and* sequence number provided from the database. This is required because
1251
// we previously had more than one payment per hash, so we have multiple indexes
1252
// pointing to a single payment; we want to retrieve the correct one.
1253
func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash,
1254
        sequenceNumber []byte) (*pymtpkg.MPPayment, error) {
3✔
1255

3✔
1256
        // We can now lookup the payment keyed by its hash in
3✔
1257
        // the payments root bucket.
3✔
1258
        bucket, err := fetchPaymentBucket(tx, paymentHash)
3✔
1259
        if err != nil {
3✔
1260
                return nil, err
×
1261
        }
×
1262

1263
        // A single payment hash can have multiple payments associated with it.
1264
        // We lookup our sequence number first, to determine whether this is
1265
        // the payment we are actually looking for.
1266
        seqBytes := bucket.Get(paymentSequenceKey)
3✔
1267
        if seqBytes == nil {
3✔
1268
                return nil, ErrNoSequenceNumber
×
1269
        }
×
1270

1271
        // If this top level payment has the sequence number we are looking for,
1272
        // return it.
1273
        if bytes.Equal(seqBytes, sequenceNumber) {
6✔
1274
                return fetchPayment(bucket)
3✔
1275
        }
3✔
1276

1277
        // If we were not looking for the top level payment, we are looking for
1278
        // one of our duplicate payments. We need to iterate through the seq
1279
        // numbers in this bucket to find the correct payments. If we do not
1280
        // find a duplicate payments bucket here, something is wrong.
UNCOV
1281
        dup := bucket.NestedReadBucket(duplicatePaymentsBucket)
×
UNCOV
1282
        if dup == nil {
×
UNCOV
1283
                return nil, ErrNoDuplicateBucket
×
UNCOV
1284
        }
×
1285

NEW
1286
        var duplicatePayment *pymtpkg.MPPayment
×
UNCOV
1287
        err = dup.ForEach(func(k, v []byte) error {
×
UNCOV
1288
                subBucket := dup.NestedReadBucket(k)
×
UNCOV
1289
                if subBucket == nil {
×
1290
                        // We one bucket for each duplicate to be found.
×
1291
                        return ErrNoDuplicateNestedBucket
×
1292
                }
×
1293

UNCOV
1294
                seqBytes := subBucket.Get(duplicatePaymentSequenceKey)
×
UNCOV
1295
                if seqBytes == nil {
×
1296
                        return err
×
1297
                }
×
1298

1299
                // If this duplicate payment is not the sequence number we are
1300
                // looking for, we can continue.
UNCOV
1301
                if !bytes.Equal(seqBytes, sequenceNumber) {
×
UNCOV
1302
                        return nil
×
UNCOV
1303
                }
×
1304

UNCOV
1305
                duplicatePayment, err = fetchDuplicatePayment(subBucket)
×
UNCOV
1306
                if err != nil {
×
1307
                        return err
×
1308
                }
×
1309

UNCOV
1310
                return nil
×
1311
        })
UNCOV
1312
        if err != nil {
×
1313
                return nil, err
×
1314
        }
×
1315

1316
        // If none of the duplicate payments matched our sequence number, we
1317
        // failed to find the payment with this sequence number; something is
1318
        // wrong.
UNCOV
1319
        if duplicatePayment == nil {
×
UNCOV
1320
                return nil, ErrDuplicateNotFound
×
UNCOV
1321
        }
×
1322

UNCOV
1323
        return duplicatePayment, nil
×
1324
}
1325

1326
// DeletePayment deletes a payment from the DB given its payment hash. If
1327
// failedHtlcsOnly is set, only failed HTLC attempts of the payment will be
1328
// deleted.
1329
func (p *KVPaymentsDB) DeletePayment(paymentHash lntypes.Hash,
UNCOV
1330
        failedHtlcsOnly bool) error {
×
UNCOV
1331

×
NEW
1332
        return kvdb.Update(p.db, func(tx kvdb.RwTx) error {
×
UNCOV
1333
                payments := tx.ReadWriteBucket(paymentsRootBucket)
×
UNCOV
1334
                if payments == nil {
×
1335
                        return nil
×
1336
                }
×
1337

UNCOV
1338
                bucket := payments.NestedReadWriteBucket(paymentHash[:])
×
UNCOV
1339
                if bucket == nil {
×
UNCOV
1340
                        return fmt.Errorf("non bucket element in payments " +
×
UNCOV
1341
                                "bucket")
×
UNCOV
1342
                }
×
1343

1344
                // If the status is InFlight, we cannot safely delete
1345
                // the payment information, so we return early.
UNCOV
1346
                paymentStatus, err := fetchPaymentStatus(bucket)
×
UNCOV
1347
                if err != nil {
×
1348
                        return err
×
1349
                }
×
1350

1351
                // If the payment has inflight HTLCs, we cannot safely delete
1352
                // the payment information, so we return an error.
NEW
1353
                if err := paymentStatus.Removable(); err != nil {
×
UNCOV
1354
                        return fmt.Errorf("payment '%v' has inflight HTLCs"+
×
UNCOV
1355
                                "and therefore cannot be deleted: %w",
×
UNCOV
1356
                                paymentHash.String(), err)
×
UNCOV
1357
                }
×
1358

1359
                // Delete the failed HTLC attempts we found.
UNCOV
1360
                if failedHtlcsOnly {
×
UNCOV
1361
                        toDelete, err := fetchFailedHtlcKeys(bucket)
×
UNCOV
1362
                        if err != nil {
×
1363
                                return err
×
1364
                        }
×
1365

UNCOV
1366
                        htlcsBucket := bucket.NestedReadWriteBucket(
×
UNCOV
1367
                                paymentHtlcsBucket,
×
UNCOV
1368
                        )
×
UNCOV
1369

×
UNCOV
1370
                        for _, htlcID := range toDelete {
×
UNCOV
1371
                                err = htlcsBucket.Delete(
×
UNCOV
1372
                                        htlcBucketKey(htlcAttemptInfoKey, htlcID),
×
UNCOV
1373
                                )
×
UNCOV
1374
                                if err != nil {
×
1375
                                        return err
×
1376
                                }
×
1377

UNCOV
1378
                                err = htlcsBucket.Delete(
×
UNCOV
1379
                                        htlcBucketKey(htlcFailInfoKey, htlcID),
×
UNCOV
1380
                                )
×
UNCOV
1381
                                if err != nil {
×
1382
                                        return err
×
1383
                                }
×
1384

UNCOV
1385
                                err = htlcsBucket.Delete(
×
UNCOV
1386
                                        htlcBucketKey(htlcSettleInfoKey, htlcID),
×
UNCOV
1387
                                )
×
UNCOV
1388
                                if err != nil {
×
1389
                                        return err
×
1390
                                }
×
1391
                        }
1392

UNCOV
1393
                        return nil
×
1394
                }
1395

UNCOV
1396
                seqNrs, err := fetchSequenceNumbers(bucket)
×
UNCOV
1397
                if err != nil {
×
1398
                        return err
×
1399
                }
×
1400

UNCOV
1401
                if err := payments.DeleteNestedBucket(paymentHash[:]); err != nil {
×
1402
                        return err
×
1403
                }
×
1404

UNCOV
1405
                indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
×
UNCOV
1406
                for _, k := range seqNrs {
×
UNCOV
1407
                        if err := indexBucket.Delete(k); err != nil {
×
1408
                                return err
×
1409
                        }
×
1410
                }
1411

UNCOV
1412
                return nil
×
UNCOV
1413
        }, func() {})
×
1414
}
1415

1416
// DeletePayments deletes all completed and failed payments from the DB. If
1417
// failedOnly is set, only failed payments will be considered for deletion. If
1418
// failedHtlcsOnly is set, the payment itself won't be deleted, only failed HTLC
1419
// attempts. The method returns the number of deleted payments, which is always
1420
// 0 if failedHtlcsOnly is set.
1421
func (p *KVPaymentsDB) DeletePayments(failedOnly, failedHtlcsOnly bool) (int,
1422
        error) {
3✔
1423

3✔
1424
        var numPayments int
3✔
1425
        err := kvdb.Update(p.db, func(tx kvdb.RwTx) error {
6✔
1426
                payments := tx.ReadWriteBucket(paymentsRootBucket)
3✔
1427
                if payments == nil {
3✔
1428
                        return nil
×
1429
                }
×
1430

1431
                var (
3✔
1432
                        // deleteBuckets is the set of payment buckets we need
3✔
1433
                        // to delete.
3✔
1434
                        deleteBuckets [][]byte
3✔
1435

3✔
1436
                        // deleteIndexes is the set of indexes pointing to these
3✔
1437
                        // payments that need to be deleted.
3✔
1438
                        deleteIndexes [][]byte
3✔
1439

3✔
1440
                        // deleteHtlcs maps a payment hash to the HTLC IDs we
3✔
1441
                        // want to delete for that payment.
3✔
1442
                        deleteHtlcs = make(map[lntypes.Hash][][]byte)
3✔
1443
                )
3✔
1444
                err := payments.ForEach(func(k, _ []byte) error {
6✔
1445
                        bucket := payments.NestedReadBucket(k)
3✔
1446
                        if bucket == nil {
3✔
1447
                                // We only expect sub-buckets to be found in
×
1448
                                // this top-level bucket.
×
1449
                                return fmt.Errorf("non bucket element in " +
×
1450
                                        "payments bucket")
×
1451
                        }
×
1452

1453
                        // If the status is InFlight, we cannot safely delete
1454
                        // the payment information, so we return early.
1455
                        paymentStatus, err := fetchPaymentStatus(bucket)
3✔
1456
                        if err != nil {
3✔
1457
                                return err
×
1458
                        }
×
1459

1460
                        // If the payment has inflight HTLCs, we cannot safely
1461
                        // delete the payment information, so we return an nil
1462
                        // to skip it.
1463
                        if err := paymentStatus.Removable(); err != nil {
3✔
UNCOV
1464
                                return nil
×
UNCOV
1465
                        }
×
1466

1467
                        // If we requested to only delete failed payments, we
1468
                        // can return if this one is not.
1469
                        if failedOnly && paymentStatus != pymtpkg.StatusFailed {
3✔
UNCOV
1470
                                return nil
×
UNCOV
1471
                        }
×
1472

1473
                        // If we are only deleting failed HTLCs, fetch them.
1474
                        if failedHtlcsOnly {
3✔
UNCOV
1475
                                toDelete, err := fetchFailedHtlcKeys(bucket)
×
UNCOV
1476
                                if err != nil {
×
1477
                                        return err
×
1478
                                }
×
1479

UNCOV
1480
                                hash, err := lntypes.MakeHash(k)
×
UNCOV
1481
                                if err != nil {
×
1482
                                        return err
×
1483
                                }
×
1484

UNCOV
1485
                                deleteHtlcs[hash] = toDelete
×
UNCOV
1486

×
UNCOV
1487
                                // We return, we are only deleting attempts.
×
UNCOV
1488
                                return nil
×
1489
                        }
1490

1491
                        // Add the bucket to the set of buckets we can delete.
1492
                        deleteBuckets = append(deleteBuckets, k)
3✔
1493

3✔
1494
                        // Get all the sequence number associated with the
3✔
1495
                        // payment, including duplicates.
3✔
1496
                        seqNrs, err := fetchSequenceNumbers(bucket)
3✔
1497
                        if err != nil {
3✔
1498
                                return err
×
1499
                        }
×
1500

1501
                        deleteIndexes = append(deleteIndexes, seqNrs...)
3✔
1502
                        numPayments++
3✔
1503
                        return nil
3✔
1504
                })
1505
                if err != nil {
3✔
1506
                        return err
×
1507
                }
×
1508

1509
                // Delete the failed HTLC attempts we found.
1510
                for hash, htlcIDs := range deleteHtlcs {
3✔
UNCOV
1511
                        bucket := payments.NestedReadWriteBucket(hash[:])
×
UNCOV
1512
                        htlcsBucket := bucket.NestedReadWriteBucket(
×
UNCOV
1513
                                paymentHtlcsBucket,
×
UNCOV
1514
                        )
×
UNCOV
1515

×
UNCOV
1516
                        for _, aid := range htlcIDs {
×
UNCOV
1517
                                if err := htlcsBucket.Delete(
×
UNCOV
1518
                                        htlcBucketKey(htlcAttemptInfoKey, aid),
×
UNCOV
1519
                                ); err != nil {
×
1520
                                        return err
×
1521
                                }
×
1522

UNCOV
1523
                                if err := htlcsBucket.Delete(
×
UNCOV
1524
                                        htlcBucketKey(htlcFailInfoKey, aid),
×
UNCOV
1525
                                ); err != nil {
×
1526
                                        return err
×
1527
                                }
×
1528

UNCOV
1529
                                if err := htlcsBucket.Delete(
×
UNCOV
1530
                                        htlcBucketKey(htlcSettleInfoKey, aid),
×
UNCOV
1531
                                ); err != nil {
×
1532
                                        return err
×
1533
                                }
×
1534
                        }
1535
                }
1536

1537
                for _, k := range deleteBuckets {
6✔
1538
                        if err := payments.DeleteNestedBucket(k); err != nil {
3✔
1539
                                return err
×
1540
                        }
×
1541
                }
1542

1543
                // Get our index bucket and delete all indexes pointing to the
1544
                // payments we are deleting.
1545
                indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
3✔
1546
                for _, k := range deleteIndexes {
6✔
1547
                        if err := indexBucket.Delete(k); err != nil {
3✔
1548
                                return err
×
1549
                        }
×
1550
                }
1551

1552
                return nil
3✔
1553
        }, func() {
3✔
1554
                numPayments = 0
3✔
1555
        })
3✔
1556
        if err != nil {
3✔
1557
                return 0, err
×
1558
        }
×
1559

1560
        return numPayments, nil
3✔
1561
}
1562

1563
// fetchSequenceNumbers fetches all the sequence numbers associated with a
1564
// payment, including those belonging to any duplicate payments.
1565
func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) {
3✔
1566
        seqNum := paymentBucket.Get(paymentSequenceKey)
3✔
1567
        if seqNum == nil {
3✔
1568
                return nil, errors.New("expected sequence number")
×
1569
        }
×
1570

1571
        sequenceNumbers := [][]byte{seqNum}
3✔
1572

3✔
1573
        // Get the duplicate payments bucket, if it has no duplicates, just
3✔
1574
        // return early with the payment sequence number.
3✔
1575
        duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket)
3✔
1576
        if duplicates == nil {
6✔
1577
                return sequenceNumbers, nil
3✔
1578
        }
3✔
1579

1580
        // If we do have duplicated, they are keyed by sequence number, so we
1581
        // iterate through the duplicates bucket and add them to our set of
1582
        // sequence numbers.
UNCOV
1583
        if err := duplicates.ForEach(func(k, v []byte) error {
×
UNCOV
1584
                sequenceNumbers = append(sequenceNumbers, k)
×
UNCOV
1585
                return nil
×
UNCOV
1586
        }); err != nil {
×
1587
                return nil, err
×
1588
        }
×
1589

UNCOV
1590
        return sequenceNumbers, nil
×
1591
}
1592

1593
// nolint: dupl
1594
func serializePaymentCreationInfo(w io.Writer,
1595
        c *pymtpkg.PaymentCreationInfo) error {
3✔
1596

3✔
1597
        var scratch [8]byte
3✔
1598

3✔
1599
        if _, err := w.Write(c.PaymentIdentifier[:]); err != nil {
3✔
1600
                return err
×
1601
        }
×
1602

1603
        byteOrder.PutUint64(scratch[:], uint64(c.Value))
3✔
1604
        if _, err := w.Write(scratch[:]); err != nil {
3✔
1605
                return err
×
1606
        }
×
1607

1608
        if err := serializeTime(w, c.CreationTime); err != nil {
3✔
1609
                return err
×
1610
        }
×
1611

1612
        byteOrder.PutUint32(scratch[:4], uint32(len(c.PaymentRequest)))
3✔
1613
        if _, err := w.Write(scratch[:4]); err != nil {
3✔
1614
                return err
×
1615
        }
×
1616

1617
        if _, err := w.Write(c.PaymentRequest[:]); err != nil {
3✔
1618
                return err
×
1619
        }
×
1620

1621
        // Any remaining bytes are TLV encoded records. Currently, these are
1622
        // only the custom records provided by the user to be sent to the first
1623
        // hop. But this can easily be extended with further records by merging
1624
        // the records into a single TLV stream.
1625
        err := c.FirstHopCustomRecords.SerializeTo(w)
3✔
1626
        if err != nil {
3✔
1627
                return err
×
1628
        }
×
1629

1630
        return nil
3✔
1631
}
1632

1633
func deserializePaymentCreationInfo(r io.Reader) (*pymtpkg.PaymentCreationInfo,
1634
        error) {
3✔
1635

3✔
1636
        var scratch [8]byte
3✔
1637

3✔
1638
        c := &pymtpkg.PaymentCreationInfo{}
3✔
1639

3✔
1640
        if _, err := io.ReadFull(r, c.PaymentIdentifier[:]); err != nil {
3✔
1641
                return nil, err
×
1642
        }
×
1643

1644
        if _, err := io.ReadFull(r, scratch[:]); err != nil {
3✔
1645
                return nil, err
×
1646
        }
×
1647
        c.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
3✔
1648

3✔
1649
        creationTime, err := deserializeTime(r)
3✔
1650
        if err != nil {
3✔
1651
                return nil, err
×
1652
        }
×
1653
        c.CreationTime = creationTime
3✔
1654

3✔
1655
        if _, err := io.ReadFull(r, scratch[:4]); err != nil {
3✔
1656
                return nil, err
×
1657
        }
×
1658

1659
        reqLen := uint32(byteOrder.Uint32(scratch[:4]))
3✔
1660
        payReq := make([]byte, reqLen)
3✔
1661
        if reqLen > 0 {
6✔
1662
                if _, err := io.ReadFull(r, payReq); err != nil {
3✔
1663
                        return nil, err
×
1664
                }
×
1665
        }
1666
        c.PaymentRequest = payReq
3✔
1667

3✔
1668
        // Any remaining bytes are TLV encoded records. Currently, these are
3✔
1669
        // only the custom records provided by the user to be sent to the first
3✔
1670
        // hop. But this can easily be extended with further records by merging
3✔
1671
        // the records into a single TLV stream.
3✔
1672
        c.FirstHopCustomRecords, err = lnwire.ParseCustomRecordsFrom(r)
3✔
1673
        if err != nil {
3✔
1674
                return nil, err
×
1675
        }
×
1676

1677
        return c, nil
3✔
1678
}
1679

1680
func serializeHTLCAttemptInfo(w io.Writer, a *pymtpkg.HTLCAttemptInfo) error {
3✔
1681
        // We nned to make sure the session key is 32 bytes, so we copy the
3✔
1682
        // session key into a 32 byte array.
3✔
1683
        sessionKeySlice := a.SessionKey().Serialize()
3✔
1684
        var sessionKey [btcec.PrivKeyBytesLen]byte
3✔
1685
        copy(sessionKey[:], sessionKeySlice)
3✔
1686
        if err := WriteElements(w, sessionKey); err != nil {
3✔
1687
                return err
×
1688
        }
×
1689

1690
        if err := SerializeRoute(w, a.Route); err != nil {
3✔
1691
                return err
×
1692
        }
×
1693

1694
        if err := serializeTime(w, a.AttemptTime); err != nil {
3✔
1695
                return err
×
1696
        }
×
1697

1698
        // If the hash is nil we can just return.
1699
        if a.Hash == nil {
3✔
1700
                return nil
×
1701
        }
×
1702

1703
        if _, err := w.Write(a.Hash[:]); err != nil {
3✔
1704
                return err
×
1705
        }
×
1706

1707
        // Merge the fixed/known records together with the custom records to
1708
        // serialize them as a single blob. We can't do this in SerializeRoute
1709
        // because we're in the middle of the byte stream there. We can only do
1710
        // TLV serialization at the end of the stream, since EOF is allowed for
1711
        // a stream if no more data is expected.
1712
        producers := []tlv.RecordProducer{
3✔
1713
                &a.Route.FirstHopAmount,
3✔
1714
        }
3✔
1715
        tlvData, err := lnwire.MergeAndEncode(
3✔
1716
                producers, nil, a.Route.FirstHopWireCustomRecords,
3✔
1717
        )
3✔
1718
        if err != nil {
3✔
1719
                return err
×
1720
        }
×
1721

1722
        if _, err := w.Write(tlvData); err != nil {
3✔
1723
                return err
×
1724
        }
×
1725

1726
        return nil
3✔
1727
}
1728

1729
func deserializeHTLCAttemptInfo(r io.Reader) (*pymtpkg.HTLCAttemptInfo, error) {
3✔
1730
        a := &pymtpkg.HTLCAttemptInfo{}
3✔
1731
        var sessionKey [btcec.PrivKeyBytesLen]byte
3✔
1732
        err := ReadElements(r, &sessionKey)
3✔
1733
        if err != nil {
3✔
1734
                return nil, err
×
1735
        }
×
1736
        a.SetSessionKey(sessionKey)
3✔
1737

3✔
1738
        a.Route, err = DeserializeRoute(r)
3✔
1739
        if err != nil {
3✔
1740
                return nil, err
×
1741
        }
×
1742

1743
        a.AttemptTime, err = deserializeTime(r)
3✔
1744
        if err != nil {
3✔
1745
                return nil, err
×
1746
        }
×
1747

1748
        hash := lntypes.Hash{}
3✔
1749
        _, err = io.ReadFull(r, hash[:])
3✔
1750

3✔
1751
        switch {
3✔
1752
        // Older payment attempts wouldn't have the hash set, in which case we
1753
        // can just return.
1754
        case err == io.EOF, err == io.ErrUnexpectedEOF:
×
1755
                return a, nil
×
1756

1757
        case err != nil:
×
1758
                return nil, err
×
1759

1760
        default:
3✔
1761
        }
1762

1763
        a.Hash = &hash
3✔
1764

3✔
1765
        // Read any remaining data (if any) and parse it into the known records
3✔
1766
        // and custom records.
3✔
1767
        extraData, err := io.ReadAll(r)
3✔
1768
        if err != nil {
3✔
1769
                return nil, err
×
1770
        }
×
1771

1772
        customRecords, _, _, err := lnwire.ParseAndExtractCustomRecords(
3✔
1773
                extraData, &a.Route.FirstHopAmount,
3✔
1774
        )
3✔
1775
        if err != nil {
3✔
1776
                return nil, err
×
1777
        }
×
1778

1779
        a.Route.FirstHopWireCustomRecords = customRecords
3✔
1780

3✔
1781
        return a, nil
3✔
1782
}
1783

1784
func serializeHop(w io.Writer, h *route.Hop) error {
3✔
1785
        if err := WriteElements(w,
3✔
1786
                h.PubKeyBytes[:],
3✔
1787
                h.ChannelID,
3✔
1788
                h.OutgoingTimeLock,
3✔
1789
                h.AmtToForward,
3✔
1790
        ); err != nil {
3✔
1791
                return err
×
1792
        }
×
1793

1794
        if err := binary.Write(w, byteOrder, h.LegacyPayload); err != nil {
3✔
1795
                return err
×
1796
        }
×
1797

1798
        // For legacy payloads, we don't need to write any TLV records, so
1799
        // we'll write a zero indicating the our serialized TLV map has no
1800
        // records.
1801
        if h.LegacyPayload {
3✔
UNCOV
1802
                return WriteElements(w, uint32(0))
×
UNCOV
1803
        }
×
1804

1805
        // Gather all non-primitive TLV records so that they can be serialized
1806
        // as a single blob.
1807
        //
1808
        // TODO(conner): add migration to unify all fields in a single TLV
1809
        // blobs. The split approach will cause headaches down the road as more
1810
        // fields are added, which we can avoid by having a single TLV stream
1811
        // for all payload fields.
1812
        var records []tlv.Record
3✔
1813
        if h.MPP != nil {
6✔
1814
                records = append(records, h.MPP.Record())
3✔
1815
        }
3✔
1816

1817
        // Add blinding point and encrypted data if present.
1818
        if h.EncryptedData != nil {
6✔
1819
                records = append(records, record.NewEncryptedDataRecord(
3✔
1820
                        &h.EncryptedData,
3✔
1821
                ))
3✔
1822
        }
3✔
1823

1824
        if h.BlindingPoint != nil {
6✔
1825
                records = append(records, record.NewBlindingPointRecord(
3✔
1826
                        &h.BlindingPoint,
3✔
1827
                ))
3✔
1828
        }
3✔
1829

1830
        if h.AMP != nil {
6✔
1831
                records = append(records, h.AMP.Record())
3✔
1832
        }
3✔
1833

1834
        if h.Metadata != nil {
3✔
UNCOV
1835
                records = append(records, record.NewMetadataRecord(&h.Metadata))
×
UNCOV
1836
        }
×
1837

1838
        if h.TotalAmtMsat != 0 {
6✔
1839
                totalMsatInt := uint64(h.TotalAmtMsat)
3✔
1840
                records = append(
3✔
1841
                        records, record.NewTotalAmtMsatBlinded(&totalMsatInt),
3✔
1842
                )
3✔
1843
        }
3✔
1844

1845
        // Final sanity check to absolutely rule out custom records that are not
1846
        // custom and write into the standard range.
1847
        if err := h.CustomRecords.Validate(); err != nil {
3✔
1848
                return err
×
1849
        }
×
1850

1851
        // Convert custom records to tlv and add to the record list.
1852
        // MapToRecords sorts the list, so adding it here will keep the list
1853
        // canonical.
1854
        tlvRecords := tlv.MapToRecords(h.CustomRecords)
3✔
1855
        records = append(records, tlvRecords...)
3✔
1856

3✔
1857
        // Otherwise, we'll transform our slice of records into a map of the
3✔
1858
        // raw bytes, then serialize them in-line with a length (number of
3✔
1859
        // elements) prefix.
3✔
1860
        mapRecords, err := tlv.RecordsToMap(records)
3✔
1861
        if err != nil {
3✔
1862
                return err
×
1863
        }
×
1864

1865
        numRecords := uint32(len(mapRecords))
3✔
1866
        if err := WriteElements(w, numRecords); err != nil {
3✔
1867
                return err
×
1868
        }
×
1869

1870
        for recordType, rawBytes := range mapRecords {
6✔
1871
                if err := WriteElements(w, recordType); err != nil {
3✔
1872
                        return err
×
1873
                }
×
1874

1875
                if err := wire.WriteVarBytes(w, 0, rawBytes); err != nil {
3✔
1876
                        return err
×
1877
                }
×
1878
        }
1879

1880
        return nil
3✔
1881
}
1882

1883
// maxOnionPayloadSize is the largest Sphinx payload possible, so we don't need
1884
// to read/write a TLV stream larger than this.
1885
const maxOnionPayloadSize = 1300
1886

1887
func deserializeHop(r io.Reader) (*route.Hop, error) {
3✔
1888
        h := &route.Hop{}
3✔
1889

3✔
1890
        var pub []byte
3✔
1891
        if err := ReadElements(r, &pub); err != nil {
3✔
1892
                return nil, err
×
1893
        }
×
1894
        copy(h.PubKeyBytes[:], pub)
3✔
1895

3✔
1896
        if err := ReadElements(r,
3✔
1897
                &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward,
3✔
1898
        ); err != nil {
3✔
1899
                return nil, err
×
1900
        }
×
1901

1902
        // TODO(roasbeef): change field to allow LegacyPayload false to be the
1903
        // legacy default?
1904
        err := binary.Read(r, byteOrder, &h.LegacyPayload)
3✔
1905
        if err != nil {
3✔
1906
                return nil, err
×
1907
        }
×
1908

1909
        var numElements uint32
3✔
1910
        if err := ReadElements(r, &numElements); err != nil {
3✔
1911
                return nil, err
×
1912
        }
×
1913

1914
        // If there're no elements, then we can return early.
1915
        if numElements == 0 {
6✔
1916
                return h, nil
3✔
1917
        }
3✔
1918

1919
        tlvMap := make(map[uint64][]byte)
3✔
1920
        for i := uint32(0); i < numElements; i++ {
6✔
1921
                var tlvType uint64
3✔
1922
                if err := ReadElements(r, &tlvType); err != nil {
3✔
1923
                        return nil, err
×
1924
                }
×
1925

1926
                rawRecordBytes, err := wire.ReadVarBytes(
3✔
1927
                        r, 0, maxOnionPayloadSize, "tlv",
3✔
1928
                )
3✔
1929
                if err != nil {
3✔
1930
                        return nil, err
×
1931
                }
×
1932

1933
                tlvMap[tlvType] = rawRecordBytes
3✔
1934
        }
1935

1936
        // If the MPP type is present, remove it from the generic TLV map and
1937
        // parse it back into a proper MPP struct.
1938
        //
1939
        // TODO(conner): add migration to unify all fields in a single TLV
1940
        // blobs. The split approach will cause headaches down the road as more
1941
        // fields are added, which we can avoid by having a single TLV stream
1942
        // for all payload fields.
1943
        mppType := uint64(record.MPPOnionType)
3✔
1944
        if mppBytes, ok := tlvMap[mppType]; ok {
6✔
1945
                delete(tlvMap, mppType)
3✔
1946

3✔
1947
                var (
3✔
1948
                        mpp    = &record.MPP{}
3✔
1949
                        mppRec = mpp.Record()
3✔
1950
                        r      = bytes.NewReader(mppBytes)
3✔
1951
                )
3✔
1952
                err := mppRec.Decode(r, uint64(len(mppBytes)))
3✔
1953
                if err != nil {
3✔
1954
                        return nil, err
×
1955
                }
×
1956
                h.MPP = mpp
3✔
1957
        }
1958

1959
        // If encrypted data or blinding key are present, remove them from
1960
        // the TLV map and parse into proper types.
1961
        encryptedDataType := uint64(record.EncryptedDataOnionType)
3✔
1962
        if data, ok := tlvMap[encryptedDataType]; ok {
6✔
1963
                delete(tlvMap, encryptedDataType)
3✔
1964
                h.EncryptedData = data
3✔
1965
        }
3✔
1966

1967
        blindingType := uint64(record.BlindingPointOnionType)
3✔
1968
        if blindingPoint, ok := tlvMap[blindingType]; ok {
6✔
1969
                delete(tlvMap, blindingType)
3✔
1970

3✔
1971
                h.BlindingPoint, err = btcec.ParsePubKey(blindingPoint)
3✔
1972
                if err != nil {
3✔
1973
                        return nil, fmt.Errorf("invalid blinding point: %w",
×
1974
                                err)
×
1975
                }
×
1976
        }
1977

1978
        ampType := uint64(record.AMPOnionType)
3✔
1979
        if ampBytes, ok := tlvMap[ampType]; ok {
6✔
1980
                delete(tlvMap, ampType)
3✔
1981

3✔
1982
                var (
3✔
1983
                        amp    = &record.AMP{}
3✔
1984
                        ampRec = amp.Record()
3✔
1985
                        r      = bytes.NewReader(ampBytes)
3✔
1986
                )
3✔
1987
                err := ampRec.Decode(r, uint64(len(ampBytes)))
3✔
1988
                if err != nil {
3✔
1989
                        return nil, err
×
1990
                }
×
1991
                h.AMP = amp
3✔
1992
        }
1993

1994
        // If the metadata type is present, remove it from the tlv map and
1995
        // populate directly on the hop.
1996
        metadataType := uint64(record.MetadataOnionType)
3✔
1997
        if metadata, ok := tlvMap[metadataType]; ok {
3✔
UNCOV
1998
                delete(tlvMap, metadataType)
×
UNCOV
1999

×
UNCOV
2000
                h.Metadata = metadata
×
UNCOV
2001
        }
×
2002

2003
        totalAmtMsatType := uint64(record.TotalAmtMsatBlindedType)
3✔
2004
        if totalAmtMsat, ok := tlvMap[totalAmtMsatType]; ok {
6✔
2005
                delete(tlvMap, totalAmtMsatType)
3✔
2006

3✔
2007
                var (
3✔
2008
                        totalAmtMsatInt uint64
3✔
2009
                        buf             [8]byte
3✔
2010
                )
3✔
2011
                if err := tlv.DTUint64(
3✔
2012
                        bytes.NewReader(totalAmtMsat),
3✔
2013
                        &totalAmtMsatInt,
3✔
2014
                        &buf,
3✔
2015
                        uint64(len(totalAmtMsat)),
3✔
2016
                ); err != nil {
3✔
2017
                        return nil, err
×
2018
                }
×
2019

2020
                h.TotalAmtMsat = lnwire.MilliSatoshi(totalAmtMsatInt)
3✔
2021
        }
2022

2023
        h.CustomRecords = tlvMap
3✔
2024

3✔
2025
        return h, nil
3✔
2026
}
2027

2028
// SerializeRoute serializes a route.
2029
func SerializeRoute(w io.Writer, r route.Route) error {
3✔
2030
        if err := WriteElements(w,
3✔
2031
                r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:],
3✔
2032
        ); err != nil {
3✔
2033
                return err
×
2034
        }
×
2035

2036
        if err := WriteElements(w, uint32(len(r.Hops))); err != nil {
3✔
2037
                return err
×
2038
        }
×
2039

2040
        for _, h := range r.Hops {
6✔
2041
                if err := serializeHop(w, h); err != nil {
3✔
2042
                        return err
×
2043
                }
×
2044
        }
2045

2046
        // Any new/extra TLV data is encoded in serializeHTLCAttemptInfo!
2047

2048
        return nil
3✔
2049
}
2050

2051
// DeserializeRoute deserializes a route.
2052
func DeserializeRoute(r io.Reader) (route.Route, error) {
3✔
2053
        rt := route.Route{}
3✔
2054
        if err := ReadElements(r,
3✔
2055
                &rt.TotalTimeLock, &rt.TotalAmount,
3✔
2056
        ); err != nil {
3✔
2057
                return rt, err
×
2058
        }
×
2059

2060
        var pub []byte
3✔
2061
        if err := ReadElements(r, &pub); err != nil {
3✔
2062
                return rt, err
×
2063
        }
×
2064
        copy(rt.SourcePubKey[:], pub)
3✔
2065

3✔
2066
        var numHops uint32
3✔
2067
        if err := ReadElements(r, &numHops); err != nil {
3✔
2068
                return rt, err
×
2069
        }
×
2070

2071
        var hops []*route.Hop
3✔
2072
        for i := uint32(0); i < numHops; i++ {
6✔
2073
                hop, err := deserializeHop(r)
3✔
2074
                if err != nil {
3✔
2075
                        return rt, err
×
2076
                }
×
2077
                hops = append(hops, hop)
3✔
2078
        }
2079
        rt.Hops = hops
3✔
2080

3✔
2081
        // Any new/extra TLV data is decoded in deserializeHTLCAttemptInfo!
3✔
2082

3✔
2083
        return rt, nil
3✔
2084
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc