• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 11170835610

03 Oct 2024 10:41PM UTC coverage: 49.188% (-9.6%) from 58.738%
11170835610

push

github

web-flow
Merge pull request #9154 from ziggie1984/master

multi: bump btcd version.

3 of 6 new or added lines in 6 files covered. (50.0%)

26110 existing lines in 428 files now uncovered.

97359 of 197934 relevant lines covered (49.19%)

1.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/channeldb/models"
15
        "github.com/lightningnetwork/lnd/clock"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
// defaultQueryPaginationLimit is the default page size: the value placed in
// the LIMIT clause of the SQL queries that paginate their results, unless it
// is overridden through a store option.
const defaultQueryPaginationLimit = 100
28

29
// SQLInvoiceQueries is an interface that defines the set of operations that can
30
// be executed against the invoice SQL database.
31
type SQLInvoiceQueries interface { //nolint:interfacebloat
32
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
33
                error)
34

35
        InsertInvoiceFeature(ctx context.Context,
36
                arg sqlc.InsertInvoiceFeatureParams) error
37

38
        InsertInvoiceHTLC(ctx context.Context,
39
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
40

41
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
42
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
43

44
        FilterInvoices(ctx context.Context,
45
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
46

47
        GetInvoice(ctx context.Context,
48
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
49

50
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
51
                error)
52

53
        GetInvoiceFeatures(ctx context.Context,
54
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
55

56
        GetInvoiceHTLCCustomRecords(ctx context.Context,
57
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
58

59
        GetInvoiceHTLCs(ctx context.Context,
60
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
61

62
        UpdateInvoiceState(ctx context.Context,
63
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
64

65
        UpdateInvoiceAmountPaid(ctx context.Context,
66
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
67

68
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
69

70
        UpdateInvoiceHTLC(ctx context.Context,
71
                arg sqlc.UpdateInvoiceHTLCParams) error
72

73
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
74
                sql.Result, error)
75

76
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
77

78
        // AMP sub invoice specific methods.
79
        UpsertAMPSubInvoice(ctx context.Context,
80
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
81

82
        UpdateAMPSubInvoiceState(ctx context.Context,
83
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
84

85
        InsertAMPSubInvoiceHTLC(ctx context.Context,
86
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
87

88
        FetchAMPSubInvoices(ctx context.Context,
89
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
90
                error)
91

92
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
93
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
94
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
95

96
        FetchSettledAMPSubInvoices(ctx context.Context,
97
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
98
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
99

100
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
101
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
102
                error)
103

104
        // Invoice events specific methods.
105
        OnInvoiceCreated(ctx context.Context,
106
                arg sqlc.OnInvoiceCreatedParams) error
107

108
        OnInvoiceCanceled(ctx context.Context,
109
                arg sqlc.OnInvoiceCanceledParams) error
110

111
        OnInvoiceSettled(ctx context.Context,
112
                arg sqlc.OnInvoiceSettledParams) error
113

114
        OnAMPSubInvoiceCreated(ctx context.Context,
115
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
116

117
        OnAMPSubInvoiceCanceled(ctx context.Context,
118
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
119

120
        OnAMPSubInvoiceSettled(ctx context.Context,
121
                arg sqlc.OnAMPSubInvoiceSettledParams) error
122
}
123

124
var _ InvoiceDB = (*SQLStore)(nil)
125

126
// SQLInvoiceQueriesTxOptions is the set of db transaction options that
// SQLInvoiceQueries understands. For the zero value ReadOnly reports false,
// i.e. a read-write transaction.
type SQLInvoiceQueriesTxOptions struct {
	// readOnly is true when only a read transaction is needed.
	readOnly bool
}

// ReadOnly reports whether the transaction should be read only.
//
// NOTE: This implements the TxOptions.
func (o *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
	return o.readOnly
}

// NewSQLInvoiceQueryReadTx returns an option set that requests a read only
// transaction.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
	var opts SQLInvoiceQueriesTxOptions
	opts.readOnly = true

	return opts
}
×
146

147
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
148
// of batched database operations.
149
type BatchedSQLInvoiceQueries interface {
150
        SQLInvoiceQueries
151

152
        sqldb.BatchedTx[SQLInvoiceQueries]
153
}
154

155
// SQLStore represents a storage backend.
156
type SQLStore struct {
157
        db    BatchedSQLInvoiceQueries
158
        clock clock.Clock
159
        opts  SQLStoreOptions
160
}
161

162
// SQLStoreOptions holds the options for the SQL store.
163
type SQLStoreOptions struct {
164
        paginationLimit int
165
}
166

167
// defaultSQLStoreOptions returns the default options for the SQL store.
UNCOV
168
func defaultSQLStoreOptions() SQLStoreOptions {
×
UNCOV
169
        return SQLStoreOptions{
×
UNCOV
170
                paginationLimit: defaultQueryPaginationLimit,
×
UNCOV
171
        }
×
UNCOV
172
}
×
173

174
// SQLStoreOption is a functional option that can be used to optionally modify
175
// the behavior of the SQL store.
176
type SQLStoreOption func(*SQLStoreOptions)
177

178
// WithPaginationLimit sets the pagination limit for the SQL store queries that
179
// paginate results.
UNCOV
180
func WithPaginationLimit(limit int) SQLStoreOption {
×
UNCOV
181
        return func(o *SQLStoreOptions) {
×
UNCOV
182
                o.paginationLimit = limit
×
UNCOV
183
        }
×
184
}
185

186
// NewSQLStore creates a new SQLStore instance given a open
187
// BatchedSQLInvoiceQueries storage backend.
188
func NewSQLStore(db BatchedSQLInvoiceQueries,
UNCOV
189
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
×
UNCOV
190

×
UNCOV
191
        opts := defaultSQLStoreOptions()
×
UNCOV
192
        for _, applyOption := range options {
×
UNCOV
193
                applyOption(&opts)
×
UNCOV
194
        }
×
195

UNCOV
196
        return &SQLStore{
×
UNCOV
197
                db:    db,
×
UNCOV
198
                clock: clock,
×
UNCOV
199
                opts:  opts,
×
UNCOV
200
        }
×
201
}
202

203
// AddInvoice inserts the targeted invoice into the database. If the invoice has
204
// *any* payment hashes which already exists within the database, then the
205
// insertion will be aborted and rejected due to the strict policy banning any
206
// duplicate payment hashes.
207
//
208
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
209
func (i *SQLStore) AddInvoice(ctx context.Context,
UNCOV
210
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
×
UNCOV
211

×
UNCOV
212
        // Make sure this is a valid invoice before trying to store it in our
×
UNCOV
213
        // DB.
×
UNCOV
214
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
×
UNCOV
215
                return 0, err
×
UNCOV
216
        }
×
217

UNCOV
218
        var (
×
UNCOV
219
                writeTxOpts SQLInvoiceQueriesTxOptions
×
UNCOV
220
                invoiceID   int64
×
UNCOV
221
        )
×
UNCOV
222

×
UNCOV
223
        // Precompute the payment request hash so we can use it in the query.
×
UNCOV
224
        var paymentRequestHash []byte
×
UNCOV
225
        if len(newInvoice.PaymentRequest) > 0 {
×
UNCOV
226
                h := sha256.New()
×
UNCOV
227
                h.Write(newInvoice.PaymentRequest)
×
UNCOV
228
                paymentRequestHash = h.Sum(nil)
×
UNCOV
229
        }
×
230

UNCOV
231
        err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
×
UNCOV
232
                params := sqlc.InsertInvoiceParams{
×
UNCOV
233
                        Hash:       paymentHash[:],
×
UNCOV
234
                        Memo:       sqldb.SQLStr(string(newInvoice.Memo)),
×
UNCOV
235
                        AmountMsat: int64(newInvoice.Terms.Value),
×
UNCOV
236
                        // Note: BOLT12 invoices don't have a final cltv delta.
×
UNCOV
237
                        CltvDelta: sqldb.SQLInt32(
×
UNCOV
238
                                newInvoice.Terms.FinalCltvDelta,
×
UNCOV
239
                        ),
×
UNCOV
240
                        Expiry: int32(newInvoice.Terms.Expiry.Seconds()),
×
UNCOV
241
                        // Note: keysend invoices don't have a payment request.
×
UNCOV
242
                        PaymentRequest: sqldb.SQLStr(string(
×
UNCOV
243
                                newInvoice.PaymentRequest),
×
UNCOV
244
                        ),
×
UNCOV
245
                        PaymentRequestHash: paymentRequestHash,
×
UNCOV
246
                        State:              int16(newInvoice.State),
×
UNCOV
247
                        AmountPaidMsat:     int64(newInvoice.AmtPaid),
×
UNCOV
248
                        IsAmp:              newInvoice.IsAMP(),
×
UNCOV
249
                        IsHodl:             newInvoice.HodlInvoice,
×
UNCOV
250
                        IsKeysend:          newInvoice.IsKeysend(),
×
UNCOV
251
                        CreatedAt:          newInvoice.CreationDate.UTC(),
×
UNCOV
252
                }
×
UNCOV
253

×
UNCOV
254
                // Some invoices may not have a preimage, like in the case of
×
UNCOV
255
                // HODL invoices.
×
UNCOV
256
                if newInvoice.Terms.PaymentPreimage != nil {
×
UNCOV
257
                        preimage := *newInvoice.Terms.PaymentPreimage
×
UNCOV
258
                        if preimage == UnknownPreimage {
×
259
                                return errors.New("cannot use all-zeroes " +
×
260
                                        "preimage")
×
261
                        }
×
UNCOV
262
                        params.Preimage = preimage[:]
×
263
                }
264

265
                // Some non MPP payments may have the default (invalid) value.
UNCOV
266
                if newInvoice.Terms.PaymentAddr != BlankPayAddr {
×
UNCOV
267
                        params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
×
UNCOV
268
                }
×
269

UNCOV
270
                var err error
×
UNCOV
271
                invoiceID, err = db.InsertInvoice(ctx, params)
×
UNCOV
272
                if err != nil {
×
UNCOV
273
                        return fmt.Errorf("unable to insert invoice: %w", err)
×
UNCOV
274
                }
×
275

276
                // TODO(positiveblue): if invocies do not have custom features
277
                // maybe just store the "invoice type" and populate the features
278
                // based on that.
UNCOV
279
                for feature := range newInvoice.Terms.Features.Features() {
×
UNCOV
280
                        params := sqlc.InsertInvoiceFeatureParams{
×
UNCOV
281
                                InvoiceID: invoiceID,
×
UNCOV
282
                                Feature:   int32(feature),
×
UNCOV
283
                        }
×
UNCOV
284

×
UNCOV
285
                        err := db.InsertInvoiceFeature(ctx, params)
×
UNCOV
286
                        if err != nil {
×
287
                                return fmt.Errorf("unable to insert invoice "+
×
288
                                        "feature(%v): %w", feature, err)
×
289
                        }
×
290
                }
291

292
                // Finally add a new event for this invoice.
UNCOV
293
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
×
UNCOV
294
                        AddedAt:   newInvoice.CreationDate.UTC(),
×
UNCOV
295
                        InvoiceID: invoiceID,
×
UNCOV
296
                })
×
UNCOV
297
        }, func() {})
×
UNCOV
298
        if err != nil {
×
UNCOV
299
                mappedSQLErr := sqldb.MapSQLError(err)
×
UNCOV
300
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
×
UNCOV
301
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
UNCOV
302
                        // Add context to unique constraint errors.
×
UNCOV
303
                        return 0, ErrDuplicateInvoice
×
UNCOV
304
                }
×
305

306
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
×
307
                        paymentHash, err)
×
308
        }
309

UNCOV
310
        newInvoice.AddIndex = uint64(invoiceID)
×
UNCOV
311

×
UNCOV
312
        return newInvoice.AddIndex, nil
×
313
}
314

315
// fetchInvoice fetches the common invoice data and the AMP state for the
316
// invoice with the given reference.
317
func (i *SQLStore) fetchInvoice(ctx context.Context,
UNCOV
318
        db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) {
×
UNCOV
319

×
UNCOV
320
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
×
UNCOV
321
                return nil, ErrInvoiceNotFound
×
UNCOV
322
        }
×
323

UNCOV
324
        var (
×
UNCOV
325
                invoice *Invoice
×
UNCOV
326
                params  sqlc.GetInvoiceParams
×
UNCOV
327
        )
×
UNCOV
328

×
UNCOV
329
        // Given all invoices are uniquely identified by their payment hash,
×
UNCOV
330
        // we can use it to query a specific invoice.
×
UNCOV
331
        if ref.PayHash() != nil {
×
UNCOV
332
                params.Hash = ref.PayHash()[:]
×
UNCOV
333
        }
×
334

335
        // Newer invoices (0.11 and up) are indexed by payment address in
336
        // addition to payment hash, but pre 0.8 invoices do not have one at
337
        // all. Only allow lookups for payment address if it is not a blank
338
        // payment address, which is a special-cased value for legacy keysend
339
        // invoices.
UNCOV
340
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
×
UNCOV
341
                params.PaymentAddr = ref.PayAddr()[:]
×
UNCOV
342
        }
×
343

344
        // If the reference has a set ID we'll fetch the invoice which has the
345
        // corresponding AMP sub invoice.
UNCOV
346
        if ref.SetID() != nil {
×
UNCOV
347
                params.SetID = ref.SetID()[:]
×
UNCOV
348
        }
×
349

UNCOV
350
        var (
×
UNCOV
351
                rows []sqlc.Invoice
×
UNCOV
352
                err  error
×
UNCOV
353
        )
×
UNCOV
354

×
UNCOV
355
        // We need to split the query based on how we intend to look up the
×
UNCOV
356
        // invoice. If only the set ID is given then we want to have an exact
×
UNCOV
357
        // match on the set ID. If other fields are given, we want to match on
×
UNCOV
358
        // those fields and the set ID but with a less strict join condition.
×
UNCOV
359
        if params.Hash == nil && params.PaymentAddr == nil &&
×
UNCOV
360
                params.SetID != nil {
×
UNCOV
361

×
UNCOV
362
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
×
UNCOV
363
        } else {
×
UNCOV
364
                rows, err = db.GetInvoice(ctx, params)
×
UNCOV
365
        }
×
UNCOV
366
        switch {
×
UNCOV
367
        case len(rows) == 0:
×
UNCOV
368
                return nil, ErrInvoiceNotFound
×
369

370
        case len(rows) > 1:
×
371
                // In case the reference is ambiguous, meaning it matches more
×
372
                // than        one invoice, we'll return an error.
×
373
                return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
×
374
                        ref.String(), spew.Sdump(rows))
×
375

376
        case err != nil:
×
377
                return nil, fmt.Errorf("unable to fetch invoice: %w", err)
×
378
        }
379

UNCOV
380
        var (
×
UNCOV
381
                setID         *[32]byte
×
UNCOV
382
                fetchAmpHtlcs bool
×
UNCOV
383
        )
×
UNCOV
384

×
UNCOV
385
        // Now that we got the invoice itself, fetch the HTLCs as requested by
×
UNCOV
386
        // the modifier.
×
UNCOV
387
        switch ref.Modifier() {
×
UNCOV
388
        case DefaultModifier:
×
UNCOV
389
                // By default we'll fetch all AMP HTLCs.
×
UNCOV
390
                setID = nil
×
UNCOV
391
                fetchAmpHtlcs = true
×
392

UNCOV
393
        case HtlcSetOnlyModifier:
×
UNCOV
394
                // In this case we'll fetch all AMP HTLCs for the
×
UNCOV
395
                // specified set id.
×
UNCOV
396
                if ref.SetID() == nil {
×
397
                        return nil, fmt.Errorf("set ID is required to use " +
×
398
                                "the HTLC set only modifier")
×
399
                }
×
400

UNCOV
401
                setID = ref.SetID()
×
UNCOV
402
                fetchAmpHtlcs = true
×
403

UNCOV
404
        case HtlcSetBlankModifier:
×
UNCOV
405
                // No need to fetch any HTLCs.
×
UNCOV
406
                setID = nil
×
UNCOV
407
                fetchAmpHtlcs = false
×
408

409
        default:
×
410
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
×
411
                        ref.Modifier())
×
412
        }
413

414
        // Fetch the rest of the invoice data and fill the invoice struct.
UNCOV
415
        _, invoice, err = fetchInvoiceData(
×
UNCOV
416
                ctx, db, rows[0], setID, fetchAmpHtlcs,
×
UNCOV
417
        )
×
UNCOV
418
        if err != nil {
×
419
                return nil, err
×
420
        }
×
421

UNCOV
422
        return invoice, nil
×
423
}
424

425
// fetchAmpState fetches the AMP state for the invoice with the given ID.
426
// Optional setID can be provided to fetch the state for a specific AMP HTLC
427
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
428
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
429
// well.
430
//
431
//nolint:funlen
432
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
433
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
UNCOV
434
        HTLCSet, error) {
×
UNCOV
435

×
UNCOV
436
        var paramSetID []byte
×
UNCOV
437
        if setID != nil {
×
UNCOV
438
                paramSetID = setID[:]
×
UNCOV
439
        }
×
440

441
        // First fetch all the AMP sub invoices for this invoice or the one
442
        // matching the provided set ID.
UNCOV
443
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
×
UNCOV
444
                ctx, sqlc.FetchAMPSubInvoicesParams{
×
UNCOV
445
                        InvoiceID: invoiceID,
×
UNCOV
446
                        SetID:     paramSetID,
×
UNCOV
447
                },
×
UNCOV
448
        )
×
UNCOV
449
        if err != nil {
×
450
                return nil, nil, err
×
451
        }
×
452

UNCOV
453
        ampState := make(map[SetID]InvoiceStateAMP)
×
UNCOV
454
        for _, row := range ampInvoiceRows {
×
UNCOV
455
                var rowSetID [32]byte
×
UNCOV
456

×
UNCOV
457
                if len(row.SetID) != 32 {
×
458
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
×
459
                                len(row.SetID))
×
460
                }
×
461

UNCOV
462
                var settleDate time.Time
×
UNCOV
463
                if row.SettledAt.Valid {
×
UNCOV
464
                        settleDate = row.SettledAt.Time.Local()
×
UNCOV
465
                }
×
466

UNCOV
467
                copy(rowSetID[:], row.SetID)
×
UNCOV
468
                ampState[rowSetID] = InvoiceStateAMP{
×
UNCOV
469
                        State:       HtlcState(row.State),
×
UNCOV
470
                        SettleIndex: uint64(row.SettleIndex.Int64),
×
UNCOV
471
                        SettleDate:  settleDate,
×
UNCOV
472
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
×
UNCOV
473
                }
×
474
        }
475

UNCOV
476
        if !fetchHtlcs {
×
UNCOV
477
                return ampState, nil, nil
×
UNCOV
478
        }
×
479

UNCOV
480
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
UNCOV
481
        if err != nil {
×
482
                return nil, nil, fmt.Errorf("unable to get custom records for "+
×
483
                        "invoice HTLCs: %w", err)
×
484
        }
×
485

UNCOV
486
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
×
UNCOV
487
        for _, row := range customRecordRows {
×
488
                if _, ok := customRecords[row.HtlcID]; !ok {
×
489
                        customRecords[row.HtlcID] = make(record.CustomSet)
×
490
                }
×
491

492
                value := row.Value
×
493
                if value == nil {
×
494
                        value = []byte{}
×
495
                }
×
496

497
                customRecords[row.HtlcID][uint64(row.Key)] = value
×
498
        }
499

500
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
501
        // provided set ID.
UNCOV
502
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
×
UNCOV
503
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
×
UNCOV
504
                        InvoiceID: invoiceID,
×
UNCOV
505
                        SetID:     paramSetID,
×
UNCOV
506
                },
×
UNCOV
507
        )
×
UNCOV
508
        if err != nil {
×
509
                return nil, nil, err
×
510
        }
×
511

UNCOV
512
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
×
UNCOV
513
        for _, row := range ampHtlcRows {
×
UNCOV
514
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
UNCOV
515
                if err != nil {
×
516
                        return nil, nil, err
×
517
                }
×
518

UNCOV
519
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
UNCOV
520

×
UNCOV
521
                if row.HtlcID < 0 {
×
522
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
×
523
                                "value: %v", row.HtlcID)
×
524
                }
×
525

UNCOV
526
                htlcID := uint64(row.HtlcID)
×
UNCOV
527

×
UNCOV
528
                circuitKey := CircuitKey{
×
UNCOV
529
                        ChanID: chanID,
×
UNCOV
530
                        HtlcID: htlcID,
×
UNCOV
531
                }
×
UNCOV
532

×
UNCOV
533
                htlc := &InvoiceHTLC{
×
UNCOV
534
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
535
                        AcceptHeight: uint32(row.AcceptHeight),
×
UNCOV
536
                        AcceptTime:   row.AcceptTime.Local(),
×
UNCOV
537
                        Expiry:       uint32(row.ExpiryHeight),
×
UNCOV
538
                        State:        HtlcState(row.State),
×
UNCOV
539
                }
×
UNCOV
540

×
UNCOV
541
                if row.TotalMppMsat.Valid {
×
UNCOV
542
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
×
UNCOV
543
                                row.TotalMppMsat.Int64,
×
UNCOV
544
                        )
×
UNCOV
545
                }
×
546

UNCOV
547
                if row.ResolveTime.Valid {
×
UNCOV
548
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
×
UNCOV
549
                }
×
550

UNCOV
551
                var (
×
UNCOV
552
                        rootShare [32]byte
×
UNCOV
553
                        setID     [32]byte
×
UNCOV
554
                )
×
UNCOV
555

×
UNCOV
556
                if len(row.RootShare) != 32 {
×
557
                        return nil, nil, fmt.Errorf("invalid root share "+
×
558
                                "length: %d", len(row.RootShare))
×
559
                }
×
UNCOV
560
                copy(rootShare[:], row.RootShare)
×
UNCOV
561

×
UNCOV
562
                if len(row.SetID) != 32 {
×
563
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
564
                                len(row.SetID))
×
565
                }
×
UNCOV
566
                copy(setID[:], row.SetID)
×
UNCOV
567

×
UNCOV
568
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
×
569
                        return nil, nil, fmt.Errorf("invalid child index "+
×
570
                                "value: %v", row.ChildIndex)
×
571
                }
×
572

UNCOV
573
                ampRecord := record.NewAMP(
×
UNCOV
574
                        rootShare, setID, uint32(row.ChildIndex),
×
UNCOV
575
                )
×
UNCOV
576

×
UNCOV
577
                htlc.AMP = &InvoiceHtlcAMPData{
×
UNCOV
578
                        Record: *ampRecord,
×
UNCOV
579
                }
×
UNCOV
580

×
UNCOV
581
                if len(row.Hash) != 32 {
×
582
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
×
583
                                len(row.Hash))
×
584
                }
×
UNCOV
585
                copy(htlc.AMP.Hash[:], row.Hash)
×
UNCOV
586

×
UNCOV
587
                if row.Preimage != nil {
×
UNCOV
588
                        preimage, err := lntypes.MakePreimage(row.Preimage)
×
UNCOV
589
                        if err != nil {
×
590
                                return nil, nil, err
×
591
                        }
×
592

UNCOV
593
                        htlc.AMP.Preimage = &preimage
×
594
                }
595

UNCOV
596
                if _, ok := customRecords[row.ID]; ok {
×
597
                        htlc.CustomRecords = customRecords[row.ID]
×
UNCOV
598
                } else {
×
UNCOV
599
                        htlc.CustomRecords = make(record.CustomSet)
×
UNCOV
600
                }
×
601

UNCOV
602
                ampHtlcs[circuitKey] = htlc
×
603
        }
604

UNCOV
605
        if len(ampHtlcs) > 0 {
×
UNCOV
606
                for setID := range ampState {
×
UNCOV
607
                        var amtPaid lnwire.MilliSatoshi
×
UNCOV
608
                        invoiceKeys := make(
×
UNCOV
609
                                map[models.CircuitKey]struct{},
×
UNCOV
610
                        )
×
UNCOV
611

×
UNCOV
612
                        for key, htlc := range ampHtlcs {
×
UNCOV
613
                                if htlc.AMP.Record.SetID() != setID {
×
UNCOV
614
                                        continue
×
615
                                }
616

UNCOV
617
                                invoiceKeys[key] = struct{}{}
×
UNCOV
618

×
UNCOV
619
                                if htlc.State != HtlcStateCanceled { //nolint: lll
×
UNCOV
620
                                        amtPaid += htlc.Amt
×
UNCOV
621
                                }
×
622
                        }
623

UNCOV
624
                        setState := ampState[setID]
×
UNCOV
625
                        setState.InvoiceKeys = invoiceKeys
×
UNCOV
626
                        setState.AmtPaid = amtPaid
×
UNCOV
627
                        ampState[setID] = setState
×
628
                }
629
        }
630

UNCOV
631
        return ampState, ampHtlcs, nil
×
632
}
633

634
// LookupInvoice attempts to look up an invoice corresponding the passed in
635
// reference. The reference may be a payment hash, a payment address, or a set
636
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
637
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
638
// error.
639
func (i *SQLStore) LookupInvoice(ctx context.Context,
UNCOV
640
        ref InvoiceRef) (Invoice, error) {
×
UNCOV
641

×
UNCOV
642
        var (
×
UNCOV
643
                invoice *Invoice
×
UNCOV
644
                err     error
×
UNCOV
645
        )
×
UNCOV
646

×
UNCOV
647
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
648
        txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
649
                invoice, err = i.fetchInvoice(ctx, db, ref)
×
UNCOV
650

×
UNCOV
651
                return err
×
UNCOV
652
        }, func() {})
×
UNCOV
653
        if txErr != nil {
×
UNCOV
654
                return Invoice{}, txErr
×
UNCOV
655
        }
×
656

UNCOV
657
        return *invoice, nil
×
658
}
659

660
// FetchPendingInvoices returns all the invoices that are currently in a
661
// "pending" state. An invoice is pending if it has been created but not yet
662
// settled or canceled.
663
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
UNCOV
664
        map[lntypes.Hash]Invoice, error) {
×
UNCOV
665

×
UNCOV
666
        var invoices map[lntypes.Hash]Invoice
×
UNCOV
667

×
UNCOV
668
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
669
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
670
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
671
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
672
                                PendingOnly: true,
×
UNCOV
673
                                NumOffset:   int32(offset),
×
UNCOV
674
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
675
                                Reverse:     false,
×
UNCOV
676
                        }
×
UNCOV
677

×
UNCOV
678
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
679
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
680
                                return 0, fmt.Errorf("unable to get invoices "+
×
681
                                        "from db: %w", err)
×
682
                        }
×
683

684
                        // Load all the information for the invoices.
UNCOV
685
                        for _, row := range rows {
×
UNCOV
686
                                hash, invoice, err := fetchInvoiceData(
×
UNCOV
687
                                        ctx, db, row, nil, true,
×
UNCOV
688
                                )
×
UNCOV
689
                                if err != nil {
×
690
                                        return 0, err
×
691
                                }
×
692

UNCOV
693
                                invoices[*hash] = *invoice
×
694
                        }
695

UNCOV
696
                        return len(rows), nil
×
697
                }, i.opts.paginationLimit)
UNCOV
698
        }, func() {
×
UNCOV
699
                invoices = make(map[lntypes.Hash]Invoice)
×
UNCOV
700
        })
×
UNCOV
701
        if err != nil {
×
702
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
×
703
                        err)
×
704
        }
×
705

UNCOV
706
        return invoices, nil
×
707
}
708

709
// InvoicesSettledSince can be used by callers to catch up any settled invoices
710
// they missed within the settled invoice time series. We'll return all known
711
// settled invoice that have a settle index higher than the passed idx.
712
//
713
// NOTE: The index starts from 1. As a result we enforce that specifying a value
714
// below the starting index value is a noop.
715
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
UNCOV
716
        []Invoice, error) {
×
UNCOV
717

×
UNCOV
718
        var invoices []Invoice
×
UNCOV
719

×
UNCOV
720
        if idx == 0 {
×
UNCOV
721
                return invoices, nil
×
UNCOV
722
        }
×
723

UNCOV
724
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
725
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
726
                err := queryWithLimit(func(offset int) (int, error) {
×
UNCOV
727
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
728
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
729
                                NumOffset:      int32(offset),
×
UNCOV
730
                                NumLimit:       int32(i.opts.paginationLimit),
×
UNCOV
731
                                Reverse:        false,
×
UNCOV
732
                        }
×
UNCOV
733

×
UNCOV
734
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
735
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
736
                                return 0, fmt.Errorf("unable to get invoices "+
×
737
                                        "from db: %w", err)
×
738
                        }
×
739

740
                        // Load all the information for the invoices.
UNCOV
741
                        for _, row := range rows {
×
UNCOV
742
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
743
                                        ctx, db, row, nil, true,
×
UNCOV
744
                                )
×
UNCOV
745
                                if err != nil {
×
746
                                        return 0, fmt.Errorf("unable to fetch "+
×
747
                                                "invoice(id=%d) from db: %w",
×
748
                                                row.ID, err)
×
749
                                }
×
750

UNCOV
751
                                invoices = append(invoices, *invoice)
×
752
                        }
753

UNCOV
754
                        return len(rows), nil
×
755
                }, i.opts.paginationLimit)
UNCOV
756
                if err != nil {
×
757
                        return err
×
758
                }
×
759

760
                // Now fetch all the AMP sub invoices that were settled since
761
                // the provided index.
UNCOV
762
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
×
UNCOV
763
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
×
UNCOV
764
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
765
                        },
×
UNCOV
766
                )
×
UNCOV
767
                if err != nil {
×
768
                        return err
×
769
                }
×
770

UNCOV
771
                for _, ampInvoice := range ampInvoices {
×
UNCOV
772
                        // Convert the row to a sqlc.Invoice so we can use the
×
UNCOV
773
                        // existing fetchInvoiceData function.
×
UNCOV
774
                        sqlInvoice := sqlc.Invoice{
×
UNCOV
775
                                ID:             ampInvoice.ID,
×
UNCOV
776
                                Hash:           ampInvoice.Hash,
×
UNCOV
777
                                Preimage:       ampInvoice.Preimage,
×
UNCOV
778
                                SettleIndex:    ampInvoice.AmpSettleIndex,
×
UNCOV
779
                                SettledAt:      ampInvoice.AmpSettledAt,
×
UNCOV
780
                                Memo:           ampInvoice.Memo,
×
UNCOV
781
                                AmountMsat:     ampInvoice.AmountMsat,
×
UNCOV
782
                                CltvDelta:      ampInvoice.CltvDelta,
×
UNCOV
783
                                Expiry:         ampInvoice.Expiry,
×
UNCOV
784
                                PaymentAddr:    ampInvoice.PaymentAddr,
×
UNCOV
785
                                PaymentRequest: ampInvoice.PaymentRequest,
×
UNCOV
786
                                State:          ampInvoice.State,
×
UNCOV
787
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
×
UNCOV
788
                                IsAmp:          ampInvoice.IsAmp,
×
UNCOV
789
                                IsHodl:         ampInvoice.IsHodl,
×
UNCOV
790
                                IsKeysend:      ampInvoice.IsKeysend,
×
UNCOV
791
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
×
UNCOV
792
                        }
×
UNCOV
793

×
UNCOV
794
                        // Fetch the state and HTLCs for this AMP sub invoice.
×
UNCOV
795
                        _, invoice, err := fetchInvoiceData(
×
UNCOV
796
                                ctx, db, sqlInvoice,
×
UNCOV
797
                                (*[32]byte)(ampInvoice.SetID), true,
×
UNCOV
798
                        )
×
UNCOV
799
                        if err != nil {
×
800
                                return fmt.Errorf("unable to fetch "+
×
801
                                        "AMP invoice(id=%d) from db: %w",
×
802
                                        ampInvoice.ID, err)
×
803
                        }
×
804

UNCOV
805
                        invoices = append(invoices, *invoice)
×
806
                }
807

UNCOV
808
                return nil
×
UNCOV
809
        }, func() {
×
UNCOV
810
                invoices = nil
×
UNCOV
811
        })
×
UNCOV
812
        if err != nil {
×
813
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
814
                        "index (excluding) %d: %w", idx, err)
×
815
        }
×
816

UNCOV
817
        return invoices, nil
×
818
}
819

820
// InvoicesAddedSince can be used by callers to seek into the event time series
821
// of all the invoices added in the database. This method will return all
822
// invoices with an add index greater than the specified idx.
823
//
824
// NOTE: The index starts from 1. As a result we enforce that specifying a value
825
// below the starting index value is a noop.
826
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
UNCOV
827
        []Invoice, error) {
×
UNCOV
828

×
UNCOV
829
        var result []Invoice
×
UNCOV
830

×
UNCOV
831
        if idx == 0 {
×
UNCOV
832
                return result, nil
×
UNCOV
833
        }
×
834

UNCOV
835
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
836
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
837
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
838
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
839
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
840
                                NumOffset:   int32(offset),
×
UNCOV
841
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
842
                                Reverse:     false,
×
UNCOV
843
                        }
×
UNCOV
844

×
UNCOV
845
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
846
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
847
                                return 0, fmt.Errorf("unable to get invoices "+
×
848
                                        "from db: %w", err)
×
849
                        }
×
850

851
                        // Load all the information for the invoices.
UNCOV
852
                        for _, row := range rows {
×
UNCOV
853
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
854
                                        ctx, db, row, nil, true,
×
UNCOV
855
                                )
×
UNCOV
856
                                if err != nil {
×
857
                                        return 0, err
×
858
                                }
×
859

UNCOV
860
                                result = append(result, *invoice)
×
861
                        }
862

UNCOV
863
                        return len(rows), nil
×
864
                }, i.opts.paginationLimit)
UNCOV
865
        }, func() {
×
UNCOV
866
                result = nil
×
UNCOV
867
        })
×
868

UNCOV
869
        if err != nil {
×
870
                return nil, fmt.Errorf("unable to get invoices added since "+
×
871
                        "index %d: %w", idx, err)
×
872
        }
×
873

UNCOV
874
        return result, nil
×
875
}
876

877
// QueryInvoices allows a caller to query the invoice database for invoices
878
// within the specified add index range.
879
func (i *SQLStore) QueryInvoices(ctx context.Context,
UNCOV
880
        q InvoiceQuery) (InvoiceSlice, error) {
×
UNCOV
881

×
UNCOV
882
        var invoices []Invoice
×
UNCOV
883

×
UNCOV
884
        if q.NumMaxInvoices == 0 {
×
885
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
886
                        "be non-zero")
×
887
        }
×
888

UNCOV
889
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
890
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
891
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
892
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
893
                                NumOffset:   int32(offset),
×
UNCOV
894
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
895
                                PendingOnly: q.PendingOnly,
×
UNCOV
896
                                Reverse:     q.Reversed,
×
UNCOV
897
                        }
×
UNCOV
898

×
UNCOV
899
                        if q.Reversed {
×
UNCOV
900
                                // If the index offset was not set, we want to
×
UNCOV
901
                                // fetch from the lastest invoice.
×
UNCOV
902
                                if q.IndexOffset == 0 {
×
UNCOV
903
                                        params.AddIndexLet = sqldb.SQLInt64(
×
UNCOV
904
                                                int64(math.MaxInt64),
×
UNCOV
905
                                        )
×
UNCOV
906
                                } else {
×
UNCOV
907
                                        // The invoice with index offset id must
×
UNCOV
908
                                        // not be included in the results.
×
UNCOV
909
                                        params.AddIndexLet = sqldb.SQLInt64(
×
UNCOV
910
                                                q.IndexOffset - 1,
×
UNCOV
911
                                        )
×
UNCOV
912
                                }
×
UNCOV
913
                        } else {
×
UNCOV
914
                                // The invoice with index offset id must not be
×
UNCOV
915
                                // included in the results.
×
UNCOV
916
                                params.AddIndexGet = sqldb.SQLInt64(
×
UNCOV
917
                                        q.IndexOffset + 1,
×
UNCOV
918
                                )
×
UNCOV
919
                        }
×
920

UNCOV
921
                        if q.CreationDateStart != 0 {
×
UNCOV
922
                                params.CreatedAfter = sqldb.SQLTime(
×
UNCOV
923
                                        time.Unix(q.CreationDateStart, 0).UTC(),
×
UNCOV
924
                                )
×
UNCOV
925
                        }
×
926

UNCOV
927
                        if q.CreationDateEnd != 0 {
×
UNCOV
928
                                // We need to add 1 to the end date as we're
×
UNCOV
929
                                // checking less than the end date in SQL.
×
UNCOV
930
                                params.CreatedBefore = sqldb.SQLTime(
×
UNCOV
931
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
×
UNCOV
932
                                )
×
UNCOV
933
                        }
×
934

UNCOV
935
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
936
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
937
                                return 0, fmt.Errorf("unable to get invoices "+
×
938
                                        "from db: %w", err)
×
939
                        }
×
940

941
                        // Load all the information for the invoices.
UNCOV
942
                        for _, row := range rows {
×
UNCOV
943
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
944
                                        ctx, db, row, nil, true,
×
UNCOV
945
                                )
×
UNCOV
946
                                if err != nil {
×
947
                                        return 0, err
×
948
                                }
×
949

UNCOV
950
                                invoices = append(invoices, *invoice)
×
UNCOV
951

×
UNCOV
952
                                if len(invoices) == int(q.NumMaxInvoices) {
×
UNCOV
953
                                        return 0, nil
×
UNCOV
954
                                }
×
955
                        }
956

UNCOV
957
                        return len(rows), nil
×
958
                }, i.opts.paginationLimit)
UNCOV
959
        }, func() {
×
UNCOV
960
                invoices = nil
×
UNCOV
961
        })
×
UNCOV
962
        if err != nil {
×
963
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
964
                        "invoices: %w", err)
×
965
        }
×
966

UNCOV
967
        if len(invoices) == 0 {
×
UNCOV
968
                return InvoiceSlice{
×
UNCOV
969
                        InvoiceQuery: q,
×
UNCOV
970
                }, nil
×
UNCOV
971
        }
×
972

973
        // If we iterated through the add index in reverse order, then
974
        // we'll need to reverse the slice of invoices to return them in
975
        // forward order.
UNCOV
976
        if q.Reversed {
×
UNCOV
977
                numInvoices := len(invoices)
×
UNCOV
978
                for i := 0; i < numInvoices/2; i++ {
×
UNCOV
979
                        reverse := numInvoices - i - 1
×
UNCOV
980
                        invoices[i], invoices[reverse] =
×
UNCOV
981
                                invoices[reverse], invoices[i]
×
UNCOV
982
                }
×
983
        }
984

UNCOV
985
        res := InvoiceSlice{
×
UNCOV
986
                InvoiceQuery:     q,
×
UNCOV
987
                Invoices:         invoices,
×
UNCOV
988
                FirstIndexOffset: invoices[0].AddIndex,
×
UNCOV
989
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
×
UNCOV
990
        }
×
UNCOV
991

×
UNCOV
992
        return res, nil
×
993
}
994

995
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
// a SQL database as the backend.
//
// Each method issues its writes immediately through db. UpdateInvoice builds
// one of these inside its ExecTx callback, handing it that callback's query
// handle.
type sqlInvoiceUpdater struct {
	// db is the query handle every update is issued through.
	db         SQLInvoiceQueries
	// ctx is passed to every query; it is stored here because the updater
	// methods themselves take no context parameter.
	ctx        context.Context //nolint:containedctx
	// invoice is the invoice being updated. Its AddIndex is used as the
	// invoice row ID in queries, and its settle fields / AMPState are
	// refreshed in memory after a settle is written.
	invoice    *Invoice
	// updateTime is used for the settle times, sub invoice creation time and
	// invoice event timestamps written during this update.
	updateTime time.Time
}
1003

1004
// AddHtlc adds a new htlc to the invoice.
1005
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
UNCOV
1006
        newHtlc *InvoiceHTLC) error {
×
UNCOV
1007

×
UNCOV
1008
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
×
UNCOV
1009
                s.ctx, sqlc.InsertInvoiceHTLCParams{
×
UNCOV
1010
                        HtlcID: int64(circuitKey.HtlcID),
×
UNCOV
1011
                        ChanID: strconv.FormatUint(
×
UNCOV
1012
                                circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1013
                        ),
×
UNCOV
1014
                        AmountMsat: int64(newHtlc.Amt),
×
UNCOV
1015
                        TotalMppMsat: sql.NullInt64{
×
UNCOV
1016
                                Int64: int64(newHtlc.MppTotalAmt),
×
UNCOV
1017
                                Valid: newHtlc.MppTotalAmt != 0,
×
UNCOV
1018
                        },
×
UNCOV
1019
                        AcceptHeight: int32(newHtlc.AcceptHeight),
×
UNCOV
1020
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
×
UNCOV
1021
                        ExpiryHeight: int32(newHtlc.Expiry),
×
UNCOV
1022
                        State:        int16(newHtlc.State),
×
UNCOV
1023
                        InvoiceID:    int64(s.invoice.AddIndex),
×
UNCOV
1024
                },
×
UNCOV
1025
        )
×
UNCOV
1026
        if err != nil {
×
1027
                return err
×
1028
        }
×
1029

UNCOV
1030
        for key, value := range newHtlc.CustomRecords {
×
UNCOV
1031
                err = s.db.InsertInvoiceHTLCCustomRecord(
×
UNCOV
1032
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
×
UNCOV
1033
                                // TODO(bhandras): schema might be wrong here
×
UNCOV
1034
                                // as the custom record key is an uint64.
×
UNCOV
1035
                                Key:    int64(key),
×
UNCOV
1036
                                Value:  value,
×
UNCOV
1037
                                HtlcID: htlcPrimaryKeyID,
×
UNCOV
1038
                        },
×
UNCOV
1039
                )
×
UNCOV
1040
                if err != nil {
×
1041
                        return err
×
1042
                }
×
1043
        }
1044

UNCOV
1045
        if newHtlc.AMP != nil {
×
UNCOV
1046
                setID := newHtlc.AMP.Record.SetID()
×
UNCOV
1047

×
UNCOV
1048
                upsertResult, err := s.db.UpsertAMPSubInvoice(
×
UNCOV
1049
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
×
UNCOV
1050
                                SetID:     setID[:],
×
UNCOV
1051
                                CreatedAt: s.updateTime.UTC(),
×
UNCOV
1052
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1053
                        },
×
UNCOV
1054
                )
×
UNCOV
1055
                if err != nil {
×
UNCOV
1056
                        mappedSQLErr := sqldb.MapSQLError(err)
×
UNCOV
1057
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:lll
×
UNCOV
1058
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
UNCOV
1059
                                return ErrDuplicateSetID{
×
UNCOV
1060
                                        SetID: setID,
×
UNCOV
1061
                                }
×
UNCOV
1062
                        }
×
1063

1064
                        return err
×
1065
                }
1066

1067
                // If we're just inserting the AMP invoice, we'll get a non
1068
                // zero rows affected count.
UNCOV
1069
                rowsAffected, err := upsertResult.RowsAffected()
×
UNCOV
1070
                if err != nil {
×
1071
                        return err
×
1072
                }
×
UNCOV
1073
                if rowsAffected != 0 {
×
UNCOV
1074
                        // If we're inserting a new AMP invoice, we'll also
×
UNCOV
1075
                        // insert a new invoice event.
×
UNCOV
1076
                        err = s.db.OnAMPSubInvoiceCreated(
×
UNCOV
1077
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
×
UNCOV
1078
                                        AddedAt:   s.updateTime.UTC(),
×
UNCOV
1079
                                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1080
                                        SetID:     setID[:],
×
UNCOV
1081
                                },
×
UNCOV
1082
                        )
×
UNCOV
1083
                        if err != nil {
×
1084
                                return err
×
1085
                        }
×
1086
                }
1087

UNCOV
1088
                rootShare := newHtlc.AMP.Record.RootShare()
×
UNCOV
1089

×
UNCOV
1090
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
×
UNCOV
1091
                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1092
                        SetID:     setID[:],
×
UNCOV
1093
                        HtlcID:    htlcPrimaryKeyID,
×
UNCOV
1094
                        RootShare: rootShare[:],
×
UNCOV
1095
                        ChildIndex: int64(
×
UNCOV
1096
                                newHtlc.AMP.Record.ChildIndex(),
×
UNCOV
1097
                        ),
×
UNCOV
1098
                        Hash: newHtlc.AMP.Hash[:],
×
UNCOV
1099
                }
×
UNCOV
1100

×
UNCOV
1101
                if newHtlc.AMP.Preimage != nil {
×
UNCOV
1102
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
×
UNCOV
1103
                }
×
1104

UNCOV
1105
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
×
UNCOV
1106
                if err != nil {
×
1107
                        return err
×
1108
                }
×
1109
        }
1110

UNCOV
1111
        return nil
×
1112
}
1113

1114
// ResolveHtlc marks an htlc as resolved with the given state.
1115
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
UNCOV
1116
        state HtlcState, resolveTime time.Time) error {
×
UNCOV
1117

×
UNCOV
1118
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
×
UNCOV
1119
                HtlcID: int64(circuitKey.HtlcID),
×
UNCOV
1120
                ChanID: strconv.FormatUint(
×
UNCOV
1121
                        circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1122
                ),
×
UNCOV
1123
                InvoiceID:   int64(s.invoice.AddIndex),
×
UNCOV
1124
                State:       int16(state),
×
UNCOV
1125
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
×
UNCOV
1126
        })
×
UNCOV
1127
}
×
1128

1129
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1130
// identified by the setID.
1131
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
UNCOV
1132
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
×
UNCOV
1133

×
UNCOV
1134
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
×
UNCOV
1135
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
×
UNCOV
1136
                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1137
                        SetID:     setID[:],
×
UNCOV
1138
                        HtlcID:    int64(circuitKey.HtlcID),
×
UNCOV
1139
                        Preimage:  preimage[:],
×
UNCOV
1140
                        ChanID: strconv.FormatUint(
×
UNCOV
1141
                                circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1142
                        ),
×
UNCOV
1143
                },
×
UNCOV
1144
        )
×
UNCOV
1145
        if err != nil {
×
1146
                return err
×
1147
        }
×
1148

UNCOV
1149
        rowsAffected, err := result.RowsAffected()
×
UNCOV
1150
        if err != nil {
×
1151
                return err
×
1152
        }
×
UNCOV
1153
        if rowsAffected == 0 {
×
1154
                return ErrInvoiceNotFound
×
1155
        }
×
1156

UNCOV
1157
        return nil
×
1158
}
1159

1160
// UpdateInvoiceState updates the invoice state to the new state.
1161
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
UNCOV
1162
        newState ContractState, preimage *lntypes.Preimage) error {
×
UNCOV
1163

×
UNCOV
1164
        var (
×
UNCOV
1165
                settleIndex sql.NullInt64
×
UNCOV
1166
                settledAt   sql.NullTime
×
UNCOV
1167
        )
×
UNCOV
1168

×
UNCOV
1169
        switch newState {
×
UNCOV
1170
        case ContractSettled:
×
UNCOV
1171
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
UNCOV
1172
                if err != nil {
×
1173
                        return err
×
1174
                }
×
1175

UNCOV
1176
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
UNCOV
1177

×
UNCOV
1178
                // If the invoice is settled, we'll also update the settle time.
×
UNCOV
1179
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
UNCOV
1180

×
UNCOV
1181
                err = s.db.OnInvoiceSettled(
×
UNCOV
1182
                        s.ctx, sqlc.OnInvoiceSettledParams{
×
UNCOV
1183
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1184
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1185
                        },
×
UNCOV
1186
                )
×
UNCOV
1187
                if err != nil {
×
1188
                        return err
×
1189
                }
×
1190

UNCOV
1191
        case ContractCanceled:
×
UNCOV
1192
                err := s.db.OnInvoiceCanceled(
×
UNCOV
1193
                        s.ctx, sqlc.OnInvoiceCanceledParams{
×
UNCOV
1194
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1195
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1196
                        },
×
UNCOV
1197
                )
×
UNCOV
1198
                if err != nil {
×
1199
                        return err
×
1200
                }
×
1201
        }
1202

UNCOV
1203
        params := sqlc.UpdateInvoiceStateParams{
×
UNCOV
1204
                ID:          int64(s.invoice.AddIndex),
×
UNCOV
1205
                State:       int16(newState),
×
UNCOV
1206
                SettleIndex: settleIndex,
×
UNCOV
1207
                SettledAt:   settledAt,
×
UNCOV
1208
        }
×
UNCOV
1209

×
UNCOV
1210
        if preimage != nil {
×
UNCOV
1211
                params.Preimage = preimage[:]
×
UNCOV
1212
        }
×
1213

UNCOV
1214
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
×
UNCOV
1215
        if err != nil {
×
1216
                return err
×
1217
        }
×
UNCOV
1218
        rowsAffected, err := result.RowsAffected()
×
UNCOV
1219
        if err != nil {
×
1220
                return err
×
1221
        }
×
1222

UNCOV
1223
        if rowsAffected == 0 {
×
1224
                return ErrInvoiceNotFound
×
1225
        }
×
1226

UNCOV
1227
        if settleIndex.Valid {
×
UNCOV
1228
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
×
UNCOV
1229
                s.invoice.SettleDate = s.updateTime
×
UNCOV
1230
        }
×
1231

UNCOV
1232
        return nil
×
1233
}
1234

1235
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
1236
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
UNCOV
1237
        amtPaid lnwire.MilliSatoshi) error {
×
UNCOV
1238

×
UNCOV
1239
        _, err := s.db.UpdateInvoiceAmountPaid(
×
UNCOV
1240
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
×
UNCOV
1241
                        ID:             int64(s.invoice.AddIndex),
×
UNCOV
1242
                        AmountPaidMsat: int64(amtPaid),
×
UNCOV
1243
                },
×
UNCOV
1244
        )
×
UNCOV
1245

×
UNCOV
1246
        return err
×
UNCOV
1247
}
×
1248

1249
// UpdateAmpState updates the state of the AMP sub invoice identified by the
1250
// setID.
1251
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
UNCOV
1252
        newState InvoiceStateAMP, _ models.CircuitKey) error {
×
UNCOV
1253

×
UNCOV
1254
        var (
×
UNCOV
1255
                settleIndex sql.NullInt64
×
UNCOV
1256
                settledAt   sql.NullTime
×
UNCOV
1257
        )
×
UNCOV
1258

×
UNCOV
1259
        switch newState.State {
×
UNCOV
1260
        case HtlcStateSettled:
×
UNCOV
1261
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
UNCOV
1262
                if err != nil {
×
1263
                        return err
×
1264
                }
×
1265

UNCOV
1266
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
UNCOV
1267

×
UNCOV
1268
                // If the invoice is settled, we'll also update the settle time.
×
UNCOV
1269
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
UNCOV
1270

×
UNCOV
1271
                err = s.db.OnAMPSubInvoiceSettled(
×
UNCOV
1272
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
×
UNCOV
1273
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1274
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1275
                                SetID:     setID[:],
×
UNCOV
1276
                        },
×
UNCOV
1277
                )
×
UNCOV
1278
                if err != nil {
×
1279
                        return err
×
1280
                }
×
1281

UNCOV
1282
        case HtlcStateCanceled:
×
UNCOV
1283
                err := s.db.OnAMPSubInvoiceCanceled(
×
UNCOV
1284
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
×
UNCOV
1285
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1286
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1287
                                SetID:     setID[:],
×
UNCOV
1288
                        },
×
UNCOV
1289
                )
×
UNCOV
1290
                if err != nil {
×
1291
                        return err
×
1292
                }
×
1293
        }
1294

UNCOV
1295
        err := s.db.UpdateAMPSubInvoiceState(
×
UNCOV
1296
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
×
UNCOV
1297
                        SetID:       setID[:],
×
UNCOV
1298
                        State:       int16(newState.State),
×
UNCOV
1299
                        SettleIndex: settleIndex,
×
UNCOV
1300
                        SettledAt:   settledAt,
×
UNCOV
1301
                },
×
UNCOV
1302
        )
×
UNCOV
1303
        if err != nil {
×
1304
                return err
×
1305
        }
×
1306

UNCOV
1307
        if settleIndex.Valid {
×
UNCOV
1308
                updatedState := s.invoice.AMPState[setID]
×
UNCOV
1309
                updatedState.SettleIndex = uint64(settleIndex.Int64)
×
UNCOV
1310
                updatedState.SettleDate = s.updateTime.UTC()
×
UNCOV
1311
                s.invoice.AMPState[setID] = updatedState
×
UNCOV
1312
        }
×
1313

UNCOV
1314
        return nil
×
1315
}
1316

1317
// Finalize finalizes the update before it is written to the database. Note that
1318
// we don't use this directly in the SQL implementation, so the function is just
1319
// a stub.
UNCOV
1320
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
×
UNCOV
1321
        return nil
×
UNCOV
1322
}
×
1323

1324
// UpdateInvoice attempts to update an invoice corresponding to the passed
1325
// reference. If an invoice matching the passed reference doesn't exist within
1326
// the database, then the action will fail with an ErrInvoiceNotFound error.
1327
//
1328
// The update is performed inside the same database transaction that fetches the
1329
// invoice and is therefore atomic. The fields to update are controlled by the
1330
// supplied callback.
1331
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
1332
        setID *SetID, callback InvoiceUpdateCallback) (
UNCOV
1333
        *Invoice, error) {
×
UNCOV
1334

×
UNCOV
1335
        var updatedInvoice *Invoice
×
UNCOV
1336

×
UNCOV
1337
        txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
×
UNCOV
1338
        txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1339
                if setID != nil {
×
UNCOV
1340
                        // Make sure to use the set ID if this is an AMP update.
×
UNCOV
1341
                        var setIDBytes [32]byte
×
UNCOV
1342
                        copy(setIDBytes[:], setID[:])
×
UNCOV
1343
                        ref.setID = &setIDBytes
×
UNCOV
1344

×
UNCOV
1345
                        // If we're updating an AMP invoice, we'll also only
×
UNCOV
1346
                        // need to fetch the HTLCs for the given set ID.
×
UNCOV
1347
                        ref.refModifier = HtlcSetOnlyModifier
×
UNCOV
1348
                }
×
1349

UNCOV
1350
                invoice, err := i.fetchInvoice(ctx, db, ref)
×
UNCOV
1351
                if err != nil {
×
UNCOV
1352
                        return err
×
UNCOV
1353
                }
×
1354

UNCOV
1355
                updateTime := i.clock.Now()
×
UNCOV
1356
                updater := &sqlInvoiceUpdater{
×
UNCOV
1357
                        db:         db,
×
UNCOV
1358
                        ctx:        ctx,
×
UNCOV
1359
                        invoice:    invoice,
×
UNCOV
1360
                        updateTime: updateTime,
×
UNCOV
1361
                }
×
UNCOV
1362

×
UNCOV
1363
                payHash := ref.PayHash()
×
UNCOV
1364
                updatedInvoice, err = UpdateInvoice(
×
UNCOV
1365
                        payHash, invoice, updateTime, callback, updater,
×
UNCOV
1366
                )
×
UNCOV
1367

×
UNCOV
1368
                return err
×
UNCOV
1369
        }, func() {})
×
UNCOV
1370
        if txErr != nil {
×
UNCOV
1371
                // If the invoice is already settled, we'll return the
×
UNCOV
1372
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
×
UNCOV
1373
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
×
UNCOV
1374
                        return updatedInvoice, txErr
×
UNCOV
1375
                }
×
1376

UNCOV
1377
                return nil, txErr
×
1378
        }
1379

UNCOV
1380
        return updatedInvoice, nil
×
1381
}
1382

1383
// DeleteInvoice attempts to delete the passed invoices and all their related
1384
// data from the database in one transaction.
1385
func (i *SQLStore) DeleteInvoice(ctx context.Context,
UNCOV
1386
        invoicesToDelete []InvoiceDeleteRef) error {
×
UNCOV
1387

×
UNCOV
1388
        // All the InvoiceDeleteRef instances include the add index of the
×
UNCOV
1389
        // invoice. The rest was added to ensure that the invoices were deleted
×
UNCOV
1390
        // properly in the kv database. When we have fully migrated we can
×
UNCOV
1391
        // remove the rest of the fields.
×
UNCOV
1392
        for _, ref := range invoicesToDelete {
×
UNCOV
1393
                if ref.AddIndex == 0 {
×
1394
                        return fmt.Errorf("unable to delete invoice using a "+
×
1395
                                "ref without AddIndex set: %v", ref)
×
1396
                }
×
1397
        }
1398

UNCOV
1399
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
UNCOV
1400
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1401
                for _, ref := range invoicesToDelete {
×
UNCOV
1402
                        params := sqlc.DeleteInvoiceParams{
×
UNCOV
1403
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
×
UNCOV
1404
                        }
×
UNCOV
1405

×
UNCOV
1406
                        if ref.SettleIndex != 0 {
×
UNCOV
1407
                                params.SettleIndex = sqldb.SQLInt64(
×
UNCOV
1408
                                        ref.SettleIndex,
×
UNCOV
1409
                                )
×
UNCOV
1410
                        }
×
1411

UNCOV
1412
                        if ref.PayHash != lntypes.ZeroHash {
×
UNCOV
1413
                                params.Hash = ref.PayHash[:]
×
UNCOV
1414
                        }
×
1415

UNCOV
1416
                        result, err := db.DeleteInvoice(ctx, params)
×
UNCOV
1417
                        if err != nil {
×
1418
                                return fmt.Errorf("unable to delete "+
×
1419
                                        "invoice(%v): %w", ref.AddIndex, err)
×
1420
                        }
×
UNCOV
1421
                        rowsAffected, err := result.RowsAffected()
×
UNCOV
1422
                        if err != nil {
×
1423
                                return fmt.Errorf("unable to get rows "+
×
1424
                                        "affected: %w", err)
×
1425
                        }
×
UNCOV
1426
                        if rowsAffected == 0 {
×
UNCOV
1427
                                return fmt.Errorf("%w: %v",
×
UNCOV
1428
                                        ErrInvoiceNotFound, ref.AddIndex)
×
UNCOV
1429
                        }
×
1430
                }
1431

UNCOV
1432
                return nil
×
UNCOV
1433
        }, func() {})
×
1434

UNCOV
1435
        if err != nil {
×
UNCOV
1436
                return fmt.Errorf("unable to delete invoices: %w", err)
×
UNCOV
1437
        }
×
1438

UNCOV
1439
        return nil
×
1440
}
1441

1442
// DeleteCanceledInvoices removes all canceled invoices from the database.
UNCOV
1443
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
×
UNCOV
1444
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
UNCOV
1445
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1446
                _, err := db.DeleteCanceledInvoices(ctx)
×
UNCOV
1447
                if err != nil {
×
1448
                        return fmt.Errorf("unable to delete canceled "+
×
1449
                                "invoices: %w", err)
×
1450
                }
×
1451

UNCOV
1452
                return nil
×
UNCOV
1453
        }, func() {})
×
UNCOV
1454
        if err != nil {
×
1455
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1456
        }
×
1457

UNCOV
1458
        return nil
×
1459
}
1460

1461
// fetchInvoiceData fetches additional data for the given invoice. If the
1462
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
1463
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
1464
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
1465
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
1466
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
UNCOV
1467
        *Invoice, error) {
×
UNCOV
1468

×
UNCOV
1469
        // Unmarshal the common data.
×
UNCOV
1470
        hash, invoice, err := unmarshalInvoice(row)
×
UNCOV
1471
        if err != nil {
×
1472
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1473
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1474
        }
×
1475

1476
        // Fetch the invoice features.
UNCOV
1477
        features, err := getInvoiceFeatures(ctx, db, row.ID)
×
UNCOV
1478
        if err != nil {
×
1479
                return nil, nil, err
×
1480
        }
×
1481

UNCOV
1482
        invoice.Terms.Features = features
×
UNCOV
1483

×
UNCOV
1484
        // If this is an AMP invoice, we'll need fetch the AMP state along
×
UNCOV
1485
        // with the HTLCs (if requested).
×
UNCOV
1486
        if invoice.IsAMP() {
×
UNCOV
1487
                invoiceID := int64(invoice.AddIndex)
×
UNCOV
1488
                ampState, ampHtlcs, err := fetchAmpState(
×
UNCOV
1489
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
×
UNCOV
1490
                )
×
UNCOV
1491
                if err != nil {
×
1492
                        return nil, nil, err
×
1493
                }
×
1494

UNCOV
1495
                invoice.AMPState = ampState
×
UNCOV
1496
                invoice.Htlcs = ampHtlcs
×
UNCOV
1497

×
UNCOV
1498
                return hash, invoice, nil
×
1499
        }
1500

1501
        // Otherwise simply fetch the invoice HTLCs.
UNCOV
1502
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
×
UNCOV
1503
        if err != nil {
×
1504
                return nil, nil, err
×
1505
        }
×
1506

UNCOV
1507
        if len(htlcs) > 0 {
×
UNCOV
1508
                invoice.Htlcs = htlcs
×
UNCOV
1509
                var amountPaid lnwire.MilliSatoshi
×
UNCOV
1510
                for _, htlc := range htlcs {
×
UNCOV
1511
                        if htlc.State == HtlcStateSettled {
×
UNCOV
1512
                                amountPaid += htlc.Amt
×
UNCOV
1513
                        }
×
1514
                }
UNCOV
1515
                invoice.AmtPaid = amountPaid
×
1516
        }
1517

UNCOV
1518
        return hash, invoice, nil
×
1519
}
1520

1521
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1522
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
UNCOV
1523
        invoiceID int64) (*lnwire.FeatureVector, error) {
×
UNCOV
1524

×
UNCOV
1525
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
×
UNCOV
1526
        if err != nil {
×
1527
                return nil, fmt.Errorf("unable to get invoice features: %w",
×
1528
                        err)
×
1529
        }
×
1530

UNCOV
1531
        features := lnwire.EmptyFeatureVector()
×
UNCOV
1532
        for _, feature := range rows {
×
UNCOV
1533
                features.Set(lnwire.FeatureBit(feature.Feature))
×
UNCOV
1534
        }
×
1535

UNCOV
1536
        return features, nil
×
1537
}
1538

1539
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
1540
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
UNCOV
1541
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
×
UNCOV
1542

×
UNCOV
1543
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
×
UNCOV
1544
        if err != nil {
×
1545
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
×
1546
        }
×
1547

1548
        // We have no htlcs to unmarshal.
UNCOV
1549
        if len(htlcRows) == 0 {
×
UNCOV
1550
                return nil, nil
×
UNCOV
1551
        }
×
1552

UNCOV
1553
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
UNCOV
1554
        if err != nil {
×
1555
                return nil, fmt.Errorf("unable to get custom records for "+
×
1556
                        "invoice htlcs: %w", err)
×
1557
        }
×
1558

UNCOV
1559
        cr := make(map[int64]record.CustomSet, len(crRows))
×
UNCOV
1560
        for _, row := range crRows {
×
UNCOV
1561
                if _, ok := cr[row.HtlcID]; !ok {
×
UNCOV
1562
                        cr[row.HtlcID] = make(record.CustomSet)
×
UNCOV
1563
                }
×
1564

UNCOV
1565
                value := row.Value
×
UNCOV
1566
                if value == nil {
×
UNCOV
1567
                        value = []byte{}
×
UNCOV
1568
                }
×
UNCOV
1569
                cr[row.HtlcID][uint64(row.Key)] = value
×
1570
        }
1571

UNCOV
1572
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
×
UNCOV
1573

×
UNCOV
1574
        for _, row := range htlcRows {
×
UNCOV
1575
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
×
UNCOV
1576
                if err != nil {
×
1577
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1578
                                "htlc(%d): %w", row.ID, err)
×
1579
                }
×
1580

UNCOV
1581
                if customRecords, ok := cr[row.ID]; ok {
×
UNCOV
1582
                        htlc.CustomRecords = customRecords
×
UNCOV
1583
                } else {
×
UNCOV
1584
                        htlc.CustomRecords = make(record.CustomSet)
×
UNCOV
1585
                }
×
1586

UNCOV
1587
                htlcs[circuiteKey] = htlc
×
1588
        }
1589

UNCOV
1590
        return htlcs, nil
×
1591
}
1592

1593
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1594
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
UNCOV
1595
        error) {
×
UNCOV
1596

×
UNCOV
1597
        var (
×
UNCOV
1598
                settleIndex    int64
×
UNCOV
1599
                settledAt      time.Time
×
UNCOV
1600
                memo           []byte
×
UNCOV
1601
                paymentRequest []byte
×
UNCOV
1602
                preimage       *lntypes.Preimage
×
UNCOV
1603
                paymentAddr    [32]byte
×
UNCOV
1604
        )
×
UNCOV
1605

×
UNCOV
1606
        hash, err := lntypes.MakeHash(row.Hash)
×
UNCOV
1607
        if err != nil {
×
1608
                return nil, nil, err
×
1609
        }
×
1610

UNCOV
1611
        if row.SettleIndex.Valid {
×
UNCOV
1612
                settleIndex = row.SettleIndex.Int64
×
UNCOV
1613
        }
×
1614

UNCOV
1615
        if row.SettledAt.Valid {
×
UNCOV
1616
                settledAt = row.SettledAt.Time.Local()
×
UNCOV
1617
        }
×
1618

UNCOV
1619
        if row.Memo.Valid {
×
UNCOV
1620
                memo = []byte(row.Memo.String)
×
UNCOV
1621
        }
×
1622

1623
        // Keysend payments will have this field empty.
UNCOV
1624
        if row.PaymentRequest.Valid {
×
UNCOV
1625
                paymentRequest = []byte(row.PaymentRequest.String)
×
UNCOV
1626
        } else {
×
UNCOV
1627
                paymentRequest = []byte{}
×
UNCOV
1628
        }
×
1629

1630
        // We may not have the preimage if this a hodl invoice.
UNCOV
1631
        if row.Preimage != nil {
×
UNCOV
1632
                preimage = &lntypes.Preimage{}
×
UNCOV
1633
                copy(preimage[:], row.Preimage)
×
UNCOV
1634
        }
×
1635

UNCOV
1636
        copy(paymentAddr[:], row.PaymentAddr)
×
UNCOV
1637

×
UNCOV
1638
        var cltvDelta int32
×
UNCOV
1639
        if row.CltvDelta.Valid {
×
UNCOV
1640
                cltvDelta = row.CltvDelta.Int32
×
UNCOV
1641
        }
×
1642

UNCOV
1643
        expiry := time.Duration(row.Expiry) * time.Second
×
UNCOV
1644

×
UNCOV
1645
        invoice := &Invoice{
×
UNCOV
1646
                SettleIndex:    uint64(settleIndex),
×
UNCOV
1647
                SettleDate:     settledAt,
×
UNCOV
1648
                Memo:           memo,
×
UNCOV
1649
                PaymentRequest: paymentRequest,
×
UNCOV
1650
                CreationDate:   row.CreatedAt.Local(),
×
UNCOV
1651
                Terms: ContractTerm{
×
UNCOV
1652
                        FinalCltvDelta:  cltvDelta,
×
UNCOV
1653
                        Expiry:          expiry,
×
UNCOV
1654
                        PaymentPreimage: preimage,
×
UNCOV
1655
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
1656
                        PaymentAddr:     paymentAddr,
×
UNCOV
1657
                },
×
UNCOV
1658
                AddIndex:    uint64(row.ID),
×
UNCOV
1659
                State:       ContractState(row.State),
×
UNCOV
1660
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
×
UNCOV
1661
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
×
UNCOV
1662
                AMPState:    AMPInvoiceState{},
×
UNCOV
1663
                HodlInvoice: row.IsHodl,
×
UNCOV
1664
        }
×
UNCOV
1665

×
UNCOV
1666
        return &hash, invoice, nil
×
1667
}
1668

1669
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
1670
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
UNCOV
1671
        *InvoiceHTLC, error) {
×
UNCOV
1672

×
UNCOV
1673
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
UNCOV
1674
        if err != nil {
×
1675
                return CircuitKey{}, nil, err
×
1676
        }
×
1677

UNCOV
1678
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
UNCOV
1679

×
UNCOV
1680
        if row.HtlcID < 0 {
×
1681
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
×
1682
                        "value: %v", row.HtlcID)
×
1683
        }
×
1684

UNCOV
1685
        htlcID := uint64(row.HtlcID)
×
UNCOV
1686

×
UNCOV
1687
        circuitKey := CircuitKey{
×
UNCOV
1688
                ChanID: chanID,
×
UNCOV
1689
                HtlcID: htlcID,
×
UNCOV
1690
        }
×
UNCOV
1691

×
UNCOV
1692
        htlc := &InvoiceHTLC{
×
UNCOV
1693
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
1694
                AcceptHeight: uint32(row.AcceptHeight),
×
UNCOV
1695
                AcceptTime:   row.AcceptTime.Local(),
×
UNCOV
1696
                Expiry:       uint32(row.ExpiryHeight),
×
UNCOV
1697
                State:        HtlcState(row.State),
×
UNCOV
1698
        }
×
UNCOV
1699

×
UNCOV
1700
        if row.TotalMppMsat.Valid {
×
UNCOV
1701
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
×
UNCOV
1702
        }
×
1703

UNCOV
1704
        if row.ResolveTime.Valid {
×
UNCOV
1705
                htlc.ResolveTime = row.ResolveTime.Time.Local()
×
UNCOV
1706
        }
×
1707

UNCOV
1708
        return circuitKey, htlc, nil
×
1709
}
1710

1711
// queryWithLimit is a helper function that can be used to query the database
// page by page using a limit and offset. The passed query function receives
// the offset of the page to fetch and should return the number of rows that
// page contained and an error if any. Pages are requested at offsets 0, limit,
// 2*limit, ... until a page holds fewer than limit rows or the query fails, in
// which case that error is returned unchanged.
//
// The limit must be positive: with a zero or negative limit a short page could
// never be detected and the loop would not terminate, so an error is returned
// instead and the query function is never called.
func queryWithLimit(query func(int) (int, error), limit int) error {
	if limit <= 0 {
		return fmt.Errorf("invalid query limit: %d", limit)
	}

	for offset := 0; ; offset += limit {
		rows, err := query(offset)
		if err != nil {
			return err
		}

		// A short page means there are no more rows to fetch.
		if rows < limit {
			return nil
		}
	}
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc