• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 13593508312

28 Feb 2025 05:41PM UTC coverage: 58.287% (-10.4%) from 68.65%
13593508312

Pull #9458

github

web-flow
Merge d40067c0c into f1182e433
Pull Request #9458: multi+server.go: add initial permissions for some peers

346 of 548 new or added lines in 10 files covered. (63.14%)

27412 existing lines in 442 files now uncovered.

94709 of 162488 relevant lines covered (58.29%)

1.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/invoices/sql_store.go
1
package invoices
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "database/sql"
7
        "errors"
8
        "fmt"
9
        "math"
10
        "strconv"
11
        "time"
12

13
        "github.com/davecgh/go-spew/spew"
14
        "github.com/lightningnetwork/lnd/clock"
15
        "github.com/lightningnetwork/lnd/graph/db/models"
16
        "github.com/lightningnetwork/lnd/lntypes"
17
        "github.com/lightningnetwork/lnd/lnwire"
18
        "github.com/lightningnetwork/lnd/record"
19
        "github.com/lightningnetwork/lnd/sqldb"
20
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
21
)
22

23
// defaultQueryPaginationLimit is the default page size for the store's
// paginated SQL queries, i.e. the value used in their LIMIT clause to bound how
// many rows a single query returns.
const defaultQueryPaginationLimit = 100
28

29
// SQLInvoiceQueries is an interface that defines the set of operations that can
30
// be executed against the invoice SQL database.
31
type SQLInvoiceQueries interface { //nolint:interfacebloat
32
        InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
33
                error)
34

35
        // TODO(bhandras): remove this once migrations have been separated out.
36
        InsertMigratedInvoice(ctx context.Context,
37
                arg sqlc.InsertMigratedInvoiceParams) (int64, error)
38

39
        InsertInvoiceFeature(ctx context.Context,
40
                arg sqlc.InsertInvoiceFeatureParams) error
41

42
        InsertInvoiceHTLC(ctx context.Context,
43
                arg sqlc.InsertInvoiceHTLCParams) (int64, error)
44

45
        InsertInvoiceHTLCCustomRecord(ctx context.Context,
46
                arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
47

48
        FilterInvoices(ctx context.Context,
49
                arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
50

51
        GetInvoice(ctx context.Context,
52
                arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
53

54
        GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
55
                error)
56

57
        GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
58
                error)
59

60
        GetInvoiceFeatures(ctx context.Context,
61
                invoiceID int64) ([]sqlc.InvoiceFeature, error)
62

63
        GetInvoiceHTLCCustomRecords(ctx context.Context,
64
                invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
65

66
        GetInvoiceHTLCs(ctx context.Context,
67
                invoiceID int64) ([]sqlc.InvoiceHtlc, error)
68

69
        UpdateInvoiceState(ctx context.Context,
70
                arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
71

72
        UpdateInvoiceAmountPaid(ctx context.Context,
73
                arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
74

75
        NextInvoiceSettleIndex(ctx context.Context) (int64, error)
76

77
        UpdateInvoiceHTLC(ctx context.Context,
78
                arg sqlc.UpdateInvoiceHTLCParams) error
79

80
        DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
81
                sql.Result, error)
82

83
        DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
84

85
        // AMP sub invoice specific methods.
86
        UpsertAMPSubInvoice(ctx context.Context,
87
                arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
88

89
        // TODO(bhandras): remove this once migrations have been separated out.
90
        InsertAMPSubInvoice(ctx context.Context,
91
                arg sqlc.InsertAMPSubInvoiceParams) error
92

93
        UpdateAMPSubInvoiceState(ctx context.Context,
94
                arg sqlc.UpdateAMPSubInvoiceStateParams) error
95

96
        InsertAMPSubInvoiceHTLC(ctx context.Context,
97
                arg sqlc.InsertAMPSubInvoiceHTLCParams) error
98

99
        FetchAMPSubInvoices(ctx context.Context,
100
                arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
101
                error)
102

103
        FetchAMPSubInvoiceHTLCs(ctx context.Context,
104
                arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
105
                []sqlc.FetchAMPSubInvoiceHTLCsRow, error)
106

107
        FetchSettledAMPSubInvoices(ctx context.Context,
108
                arg sqlc.FetchSettledAMPSubInvoicesParams) (
109
                []sqlc.FetchSettledAMPSubInvoicesRow, error)
110

111
        UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
112
                arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
113
                error)
114

115
        // Invoice events specific methods.
116
        OnInvoiceCreated(ctx context.Context,
117
                arg sqlc.OnInvoiceCreatedParams) error
118

119
        OnInvoiceCanceled(ctx context.Context,
120
                arg sqlc.OnInvoiceCanceledParams) error
121

122
        OnInvoiceSettled(ctx context.Context,
123
                arg sqlc.OnInvoiceSettledParams) error
124

125
        OnAMPSubInvoiceCreated(ctx context.Context,
126
                arg sqlc.OnAMPSubInvoiceCreatedParams) error
127

128
        OnAMPSubInvoiceCanceled(ctx context.Context,
129
                arg sqlc.OnAMPSubInvoiceCanceledParams) error
130

131
        OnAMPSubInvoiceSettled(ctx context.Context,
132
                arg sqlc.OnAMPSubInvoiceSettledParams) error
133

134
        // Migration specific methods.
135
        // TODO(bhandras): remove this once migrations have been separated out.
136
        InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
137
                arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
138

139
        SetKVInvoicePaymentHash(ctx context.Context,
140
                arg sqlc.SetKVInvoicePaymentHashParams) error
141

142
        GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
143
                []byte, error)
144

145
        ClearKVInvoiceHashIndex(ctx context.Context) error
146
}
147

148
// Compile-time assertion that *SQLStore implements the InvoiceDB interface.
var _ InvoiceDB = (*SQLStore)(nil)
149

150
// SQLInvoiceQueriesTxOptions carries the transaction options understood by
// SQLInvoiceQueries. Its zero value requests a read-write transaction.
type SQLInvoiceQueriesTxOptions struct {
	// readOnly reports whether the transaction only needs read access.
	readOnly bool
}

// ReadOnly reports whether the transaction should be opened as read only.
//
// NOTE: This implements the TxOptions.
func (o *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
	isReadOnly := o.readOnly

	return isReadOnly
}

// NewSQLInvoiceQueryReadTx returns a transaction option set that requests a
// read only transaction.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
	var opts SQLInvoiceQueriesTxOptions
	opts.readOnly = true

	return opts
}
×
170

171
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
172
// of batched database operations.
173
type BatchedSQLInvoiceQueries interface {
174
        SQLInvoiceQueries
175

176
        sqldb.BatchedTx[SQLInvoiceQueries]
177
}
178

179
// SQLStore represents a storage backend.
180
type SQLStore struct {
181
        db    BatchedSQLInvoiceQueries
182
        clock clock.Clock
183
        opts  SQLStoreOptions
184
}
185

186
// SQLStoreOptions holds the options for the SQL store.
187
type SQLStoreOptions struct {
188
        paginationLimit int
189
}
190

191
// defaultSQLStoreOptions returns the default options for the SQL store.
UNCOV
192
func defaultSQLStoreOptions() SQLStoreOptions {
×
UNCOV
193
        return SQLStoreOptions{
×
UNCOV
194
                paginationLimit: defaultQueryPaginationLimit,
×
UNCOV
195
        }
×
UNCOV
196
}
×
197

198
// SQLStoreOption is a functional option that can be used to optionally modify
199
// the behavior of the SQL store.
200
type SQLStoreOption func(*SQLStoreOptions)
201

202
// WithPaginationLimit sets the pagination limit for the SQL store queries that
203
// paginate results.
UNCOV
204
func WithPaginationLimit(limit int) SQLStoreOption {
×
UNCOV
205
        return func(o *SQLStoreOptions) {
×
UNCOV
206
                o.paginationLimit = limit
×
UNCOV
207
        }
×
208
}
209

210
// NewSQLStore creates a new SQLStore instance given a open
211
// BatchedSQLInvoiceQueries storage backend.
212
func NewSQLStore(db BatchedSQLInvoiceQueries,
UNCOV
213
        clock clock.Clock, options ...SQLStoreOption) *SQLStore {
×
UNCOV
214

×
UNCOV
215
        opts := defaultSQLStoreOptions()
×
UNCOV
216
        for _, applyOption := range options {
×
UNCOV
217
                applyOption(&opts)
×
UNCOV
218
        }
×
219

UNCOV
220
        return &SQLStore{
×
UNCOV
221
                db:    db,
×
UNCOV
222
                clock: clock,
×
UNCOV
223
                opts:  opts,
×
UNCOV
224
        }
×
225
}
226

227
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
UNCOV
228
        sqlc.InsertInvoiceParams, error) {
×
UNCOV
229

×
UNCOV
230
        // Precompute the payment request hash so we can use it in the query.
×
UNCOV
231
        var paymentRequestHash []byte
×
UNCOV
232
        if len(invoice.PaymentRequest) > 0 {
×
UNCOV
233
                h := sha256.New()
×
UNCOV
234
                h.Write(invoice.PaymentRequest)
×
UNCOV
235
                paymentRequestHash = h.Sum(nil)
×
UNCOV
236
        }
×
237

UNCOV
238
        params := sqlc.InsertInvoiceParams{
×
UNCOV
239
                Hash:       paymentHash[:],
×
UNCOV
240
                AmountMsat: int64(invoice.Terms.Value),
×
UNCOV
241
                CltvDelta: sqldb.SQLInt32(
×
UNCOV
242
                        invoice.Terms.FinalCltvDelta,
×
UNCOV
243
                ),
×
UNCOV
244
                Expiry: int32(invoice.Terms.Expiry.Seconds()),
×
UNCOV
245
                // Note: keysend invoices don't have a payment request.
×
UNCOV
246
                PaymentRequest: sqldb.SQLStr(string(
×
UNCOV
247
                        invoice.PaymentRequest),
×
UNCOV
248
                ),
×
UNCOV
249
                PaymentRequestHash: paymentRequestHash,
×
UNCOV
250
                State:              int16(invoice.State),
×
UNCOV
251
                AmountPaidMsat:     int64(invoice.AmtPaid),
×
UNCOV
252
                IsAmp:              invoice.IsAMP(),
×
UNCOV
253
                IsHodl:             invoice.HodlInvoice,
×
UNCOV
254
                IsKeysend:          invoice.IsKeysend(),
×
UNCOV
255
                CreatedAt:          invoice.CreationDate.UTC(),
×
UNCOV
256
        }
×
UNCOV
257

×
UNCOV
258
        if invoice.Memo != nil {
×
UNCOV
259
                // Store the memo as a nullable string in the database. Note
×
UNCOV
260
                // that for compatibility reasons, we store the value as a valid
×
UNCOV
261
                // string even if it's empty.
×
UNCOV
262
                params.Memo = sql.NullString{
×
UNCOV
263
                        String: string(invoice.Memo),
×
UNCOV
264
                        Valid:  true,
×
UNCOV
265
                }
×
UNCOV
266
        }
×
267

268
        // Some invoices may not have a preimage, like in the case of HODL
269
        // invoices.
UNCOV
270
        if invoice.Terms.PaymentPreimage != nil {
×
UNCOV
271
                preimage := *invoice.Terms.PaymentPreimage
×
UNCOV
272
                if preimage == UnknownPreimage {
×
273
                        return sqlc.InsertInvoiceParams{},
×
274
                                errors.New("cannot use all-zeroes preimage")
×
275
                }
×
UNCOV
276
                params.Preimage = preimage[:]
×
277
        }
278

279
        // Some non MPP payments may have the default (invalid) value.
UNCOV
280
        if invoice.Terms.PaymentAddr != BlankPayAddr {
×
UNCOV
281
                params.PaymentAddr = invoice.Terms.PaymentAddr[:]
×
UNCOV
282
        }
×
283

UNCOV
284
        return params, nil
×
285
}
286

287
// AddInvoice inserts the targeted invoice into the database. If the invoice has
288
// *any* payment hashes which already exists within the database, then the
289
// insertion will be aborted and rejected due to the strict policy banning any
290
// duplicate payment hashes.
291
//
292
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
293
func (i *SQLStore) AddInvoice(ctx context.Context,
UNCOV
294
        newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
×
UNCOV
295

×
UNCOV
296
        // Make sure this is a valid invoice before trying to store it in our
×
UNCOV
297
        // DB.
×
UNCOV
298
        if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
×
UNCOV
299
                return 0, err
×
UNCOV
300
        }
×
301

UNCOV
302
        var (
×
UNCOV
303
                writeTxOpts SQLInvoiceQueriesTxOptions
×
UNCOV
304
                invoiceID   int64
×
UNCOV
305
        )
×
UNCOV
306

×
UNCOV
307
        insertInvoiceParams, err := makeInsertInvoiceParams(
×
UNCOV
308
                newInvoice, paymentHash,
×
UNCOV
309
        )
×
UNCOV
310
        if err != nil {
×
311
                return 0, err
×
312
        }
×
313

UNCOV
314
        err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
×
UNCOV
315
                var err error
×
UNCOV
316
                invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
×
UNCOV
317
                if err != nil {
×
UNCOV
318
                        return fmt.Errorf("unable to insert invoice: %w", err)
×
UNCOV
319
                }
×
320

321
                // TODO(positiveblue): if invocies do not have custom features
322
                // maybe just store the "invoice type" and populate the features
323
                // based on that.
UNCOV
324
                for feature := range newInvoice.Terms.Features.Features() {
×
UNCOV
325
                        params := sqlc.InsertInvoiceFeatureParams{
×
UNCOV
326
                                InvoiceID: invoiceID,
×
UNCOV
327
                                Feature:   int32(feature),
×
UNCOV
328
                        }
×
UNCOV
329

×
UNCOV
330
                        err := db.InsertInvoiceFeature(ctx, params)
×
UNCOV
331
                        if err != nil {
×
332
                                return fmt.Errorf("unable to insert invoice "+
×
333
                                        "feature(%v): %w", feature, err)
×
334
                        }
×
335
                }
336

337
                // Finally add a new event for this invoice.
UNCOV
338
                return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
×
UNCOV
339
                        AddedAt:   newInvoice.CreationDate.UTC(),
×
UNCOV
340
                        InvoiceID: invoiceID,
×
UNCOV
341
                })
×
UNCOV
342
        }, func() {})
×
UNCOV
343
        if err != nil {
×
UNCOV
344
                mappedSQLErr := sqldb.MapSQLError(err)
×
UNCOV
345
                var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
×
UNCOV
346
                if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
UNCOV
347
                        // Add context to unique constraint errors.
×
UNCOV
348
                        return 0, ErrDuplicateInvoice
×
UNCOV
349
                }
×
350

351
                return 0, fmt.Errorf("unable to add invoice(%v): %w",
×
352
                        paymentHash, err)
×
353
        }
354

UNCOV
355
        newInvoice.AddIndex = uint64(invoiceID)
×
UNCOV
356

×
UNCOV
357
        return newInvoice.AddIndex, nil
×
358
}
359

360
// getInvoiceByRef fetches the invoice with the given reference. The reference
361
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
362
func getInvoiceByRef(ctx context.Context,
UNCOV
363
        db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
×
UNCOV
364

×
UNCOV
365
        // If the reference is empty, we can't look up the invoice.
×
UNCOV
366
        if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
×
UNCOV
367
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
UNCOV
368
        }
×
369

370
        // If the reference is a hash only, we can look up the invoice directly
371
        // by the payment hash which is faster.
UNCOV
372
        if ref.IsHashOnly() {
×
UNCOV
373
                invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
×
UNCOV
374
                if errors.Is(err, sql.ErrNoRows) {
×
UNCOV
375
                        return sqlc.Invoice{}, ErrInvoiceNotFound
×
UNCOV
376
                }
×
377

UNCOV
378
                return invoice, err
×
379
        }
380

381
        // Otherwise the reference may include more fields, so we'll need to
382
        // assemble the query parameters based on the fields that are set.
UNCOV
383
        var params sqlc.GetInvoiceParams
×
UNCOV
384

×
UNCOV
385
        if ref.PayHash() != nil {
×
UNCOV
386
                params.Hash = ref.PayHash()[:]
×
UNCOV
387
        }
×
388

389
        // Newer invoices (0.11 and up) are indexed by payment address in
390
        // addition to payment hash, but pre 0.8 invoices do not have one at
391
        // all. Only allow lookups for payment address if it is not a blank
392
        // payment address, which is a special-cased value for legacy keysend
393
        // invoices.
UNCOV
394
        if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
×
UNCOV
395
                params.PaymentAddr = ref.PayAddr()[:]
×
UNCOV
396
        }
×
397

398
        // If the reference has a set ID we'll fetch the invoice which has the
399
        // corresponding AMP sub invoice.
UNCOV
400
        if ref.SetID() != nil {
×
UNCOV
401
                params.SetID = ref.SetID()[:]
×
UNCOV
402
        }
×
403

UNCOV
404
        var (
×
UNCOV
405
                rows []sqlc.Invoice
×
UNCOV
406
                err  error
×
UNCOV
407
        )
×
UNCOV
408

×
UNCOV
409
        // We need to split the query based on how we intend to look up the
×
UNCOV
410
        // invoice. If only the set ID is given then we want to have an exact
×
UNCOV
411
        // match on the set ID. If other fields are given, we want to match on
×
UNCOV
412
        // those fields and the set ID but with a less strict join condition.
×
UNCOV
413
        if params.Hash == nil && params.PaymentAddr == nil &&
×
UNCOV
414
                params.SetID != nil {
×
UNCOV
415

×
UNCOV
416
                rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
×
UNCOV
417
        } else {
×
UNCOV
418
                rows, err = db.GetInvoice(ctx, params)
×
UNCOV
419
        }
×
420

UNCOV
421
        switch {
×
UNCOV
422
        case len(rows) == 0:
×
UNCOV
423
                return sqlc.Invoice{}, ErrInvoiceNotFound
×
424

425
        case len(rows) > 1:
×
426
                // In case the reference is ambiguous, meaning it matches more
×
427
                // than        one invoice, we'll return an error.
×
428
                return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
×
429
                        "%s: %s", ref.String(), spew.Sdump(rows))
×
430

431
        case err != nil:
×
432
                return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
×
433
                        err)
×
434
        }
435

UNCOV
436
        return rows[0], nil
×
437
}
438

439
// fetchInvoice fetches the common invoice data and the AMP state for the
440
// invoice with the given reference.
441
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
UNCOV
442
        *Invoice, error) {
×
UNCOV
443

×
UNCOV
444
        // Fetch the invoice from the database.
×
UNCOV
445
        sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
×
UNCOV
446
        if err != nil {
×
UNCOV
447
                return nil, err
×
UNCOV
448
        }
×
449

UNCOV
450
        var (
×
UNCOV
451
                setID         *[32]byte
×
UNCOV
452
                fetchAmpHtlcs bool
×
UNCOV
453
        )
×
UNCOV
454

×
UNCOV
455
        // Now that we got the invoice itself, fetch the HTLCs as requested by
×
UNCOV
456
        // the modifier.
×
UNCOV
457
        switch ref.Modifier() {
×
UNCOV
458
        case DefaultModifier:
×
UNCOV
459
                // By default we'll fetch all AMP HTLCs.
×
UNCOV
460
                setID = nil
×
UNCOV
461
                fetchAmpHtlcs = true
×
462

UNCOV
463
        case HtlcSetOnlyModifier:
×
UNCOV
464
                // In this case we'll fetch all AMP HTLCs for the specified set
×
UNCOV
465
                // id.
×
UNCOV
466
                if ref.SetID() == nil {
×
467
                        return nil, fmt.Errorf("set ID is required to use " +
×
468
                                "the HTLC set only modifier")
×
469
                }
×
470

UNCOV
471
                setID = ref.SetID()
×
UNCOV
472
                fetchAmpHtlcs = true
×
473

UNCOV
474
        case HtlcSetBlankModifier:
×
UNCOV
475
                // No need to fetch any HTLCs.
×
UNCOV
476
                setID = nil
×
UNCOV
477
                fetchAmpHtlcs = false
×
478

479
        default:
×
480
                return nil, fmt.Errorf("unknown invoice ref modifier: %v",
×
481
                        ref.Modifier())
×
482
        }
483

484
        // Fetch the rest of the invoice data and fill the invoice struct.
UNCOV
485
        _, invoice, err := fetchInvoiceData(
×
UNCOV
486
                ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
×
UNCOV
487
        )
×
UNCOV
488
        if err != nil {
×
489
                return nil, err
×
490
        }
×
491

UNCOV
492
        return invoice, nil
×
493
}
494

495
// fetchAmpState fetches the AMP state for the invoice with the given ID.
496
// Optional setID can be provided to fetch the state for a specific AMP HTLC
497
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
498
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
499
// well.
500
//
501
//nolint:funlen
502
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
503
        setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
UNCOV
504
        HTLCSet, error) {
×
UNCOV
505

×
UNCOV
506
        var paramSetID []byte
×
UNCOV
507
        if setID != nil {
×
UNCOV
508
                paramSetID = setID[:]
×
UNCOV
509
        }
×
510

511
        // First fetch all the AMP sub invoices for this invoice or the one
512
        // matching the provided set ID.
UNCOV
513
        ampInvoiceRows, err := db.FetchAMPSubInvoices(
×
UNCOV
514
                ctx, sqlc.FetchAMPSubInvoicesParams{
×
UNCOV
515
                        InvoiceID: invoiceID,
×
UNCOV
516
                        SetID:     paramSetID,
×
UNCOV
517
                },
×
UNCOV
518
        )
×
UNCOV
519
        if err != nil {
×
520
                return nil, nil, err
×
521
        }
×
522

UNCOV
523
        ampState := make(map[SetID]InvoiceStateAMP)
×
UNCOV
524
        for _, row := range ampInvoiceRows {
×
UNCOV
525
                var rowSetID [32]byte
×
UNCOV
526

×
UNCOV
527
                if len(row.SetID) != 32 {
×
528
                        return nil, nil, fmt.Errorf("invalid set id length: %d",
×
529
                                len(row.SetID))
×
530
                }
×
531

UNCOV
532
                var settleDate time.Time
×
UNCOV
533
                if row.SettledAt.Valid {
×
UNCOV
534
                        settleDate = row.SettledAt.Time.Local()
×
UNCOV
535
                }
×
536

UNCOV
537
                copy(rowSetID[:], row.SetID)
×
UNCOV
538
                ampState[rowSetID] = InvoiceStateAMP{
×
UNCOV
539
                        State:       HtlcState(row.State),
×
UNCOV
540
                        SettleIndex: uint64(row.SettleIndex.Int64),
×
UNCOV
541
                        SettleDate:  settleDate,
×
UNCOV
542
                        InvoiceKeys: make(map[models.CircuitKey]struct{}),
×
UNCOV
543
                }
×
544
        }
545

UNCOV
546
        if !fetchHtlcs {
×
UNCOV
547
                return ampState, nil, nil
×
UNCOV
548
        }
×
549

UNCOV
550
        customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
UNCOV
551
        if err != nil {
×
552
                return nil, nil, fmt.Errorf("unable to get custom records for "+
×
553
                        "invoice HTLCs: %w", err)
×
554
        }
×
555

UNCOV
556
        customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
×
UNCOV
557
        for _, row := range customRecordRows {
×
UNCOV
558
                if _, ok := customRecords[row.HtlcID]; !ok {
×
UNCOV
559
                        customRecords[row.HtlcID] = make(record.CustomSet)
×
UNCOV
560
                }
×
561

UNCOV
562
                value := row.Value
×
UNCOV
563
                if value == nil {
×
564
                        value = []byte{}
×
565
                }
×
566

UNCOV
567
                customRecords[row.HtlcID][uint64(row.Key)] = value
×
568
        }
569

570
        // Now fetch all the AMP HTLCs for this invoice or the one matching the
571
        // provided set ID.
UNCOV
572
        ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
×
UNCOV
573
                ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
×
UNCOV
574
                        InvoiceID: invoiceID,
×
UNCOV
575
                        SetID:     paramSetID,
×
UNCOV
576
                },
×
UNCOV
577
        )
×
UNCOV
578
        if err != nil {
×
579
                return nil, nil, err
×
580
        }
×
581

UNCOV
582
        ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
×
UNCOV
583
        for _, row := range ampHtlcRows {
×
UNCOV
584
                uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
UNCOV
585
                if err != nil {
×
586
                        return nil, nil, err
×
587
                }
×
588

UNCOV
589
                chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
UNCOV
590

×
UNCOV
591
                if row.HtlcID < 0 {
×
592
                        return nil, nil, fmt.Errorf("invalid HTLC ID "+
×
593
                                "value: %v", row.HtlcID)
×
594
                }
×
595

UNCOV
596
                htlcID := uint64(row.HtlcID)
×
UNCOV
597

×
UNCOV
598
                circuitKey := CircuitKey{
×
UNCOV
599
                        ChanID: chanID,
×
UNCOV
600
                        HtlcID: htlcID,
×
UNCOV
601
                }
×
UNCOV
602

×
UNCOV
603
                htlc := &InvoiceHTLC{
×
UNCOV
604
                        Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
605
                        AcceptHeight: uint32(row.AcceptHeight),
×
UNCOV
606
                        AcceptTime:   row.AcceptTime.Local(),
×
UNCOV
607
                        Expiry:       uint32(row.ExpiryHeight),
×
UNCOV
608
                        State:        HtlcState(row.State),
×
UNCOV
609
                }
×
UNCOV
610

×
UNCOV
611
                if row.TotalMppMsat.Valid {
×
UNCOV
612
                        htlc.MppTotalAmt = lnwire.MilliSatoshi(
×
UNCOV
613
                                row.TotalMppMsat.Int64,
×
UNCOV
614
                        )
×
UNCOV
615
                }
×
616

UNCOV
617
                if row.ResolveTime.Valid {
×
UNCOV
618
                        htlc.ResolveTime = row.ResolveTime.Time.Local()
×
UNCOV
619
                }
×
620

UNCOV
621
                var (
×
UNCOV
622
                        rootShare [32]byte
×
UNCOV
623
                        setID     [32]byte
×
UNCOV
624
                )
×
UNCOV
625

×
UNCOV
626
                if len(row.RootShare) != 32 {
×
627
                        return nil, nil, fmt.Errorf("invalid root share "+
×
628
                                "length: %d", len(row.RootShare))
×
629
                }
×
UNCOV
630
                copy(rootShare[:], row.RootShare)
×
UNCOV
631

×
UNCOV
632
                if len(row.SetID) != 32 {
×
633
                        return nil, nil, fmt.Errorf("invalid set ID length: %d",
×
634
                                len(row.SetID))
×
635
                }
×
UNCOV
636
                copy(setID[:], row.SetID)
×
UNCOV
637

×
UNCOV
638
                if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
×
639
                        return nil, nil, fmt.Errorf("invalid child index "+
×
640
                                "value: %v", row.ChildIndex)
×
641
                }
×
642

UNCOV
643
                ampRecord := record.NewAMP(
×
UNCOV
644
                        rootShare, setID, uint32(row.ChildIndex),
×
UNCOV
645
                )
×
UNCOV
646

×
UNCOV
647
                htlc.AMP = &InvoiceHtlcAMPData{
×
UNCOV
648
                        Record: *ampRecord,
×
UNCOV
649
                }
×
UNCOV
650

×
UNCOV
651
                if len(row.Hash) != 32 {
×
652
                        return nil, nil, fmt.Errorf("invalid hash length: %d",
×
653
                                len(row.Hash))
×
654
                }
×
UNCOV
655
                copy(htlc.AMP.Hash[:], row.Hash)
×
UNCOV
656

×
UNCOV
657
                if row.Preimage != nil {
×
UNCOV
658
                        preimage, err := lntypes.MakePreimage(row.Preimage)
×
UNCOV
659
                        if err != nil {
×
660
                                return nil, nil, err
×
661
                        }
×
662

UNCOV
663
                        htlc.AMP.Preimage = &preimage
×
664
                }
665

UNCOV
666
                if _, ok := customRecords[row.ID]; ok {
×
UNCOV
667
                        htlc.CustomRecords = customRecords[row.ID]
×
UNCOV
668
                } else {
×
UNCOV
669
                        htlc.CustomRecords = make(record.CustomSet)
×
UNCOV
670
                }
×
671

UNCOV
672
                ampHtlcs[circuitKey] = htlc
×
673
        }
674

UNCOV
675
        if len(ampHtlcs) > 0 {
×
UNCOV
676
                for setID := range ampState {
×
UNCOV
677
                        var amtPaid lnwire.MilliSatoshi
×
UNCOV
678
                        invoiceKeys := make(
×
UNCOV
679
                                map[models.CircuitKey]struct{},
×
UNCOV
680
                        )
×
UNCOV
681

×
UNCOV
682
                        for key, htlc := range ampHtlcs {
×
UNCOV
683
                                if htlc.AMP.Record.SetID() != setID {
×
UNCOV
684
                                        continue
×
685
                                }
686

UNCOV
687
                                invoiceKeys[key] = struct{}{}
×
UNCOV
688

×
UNCOV
689
                                if htlc.State != HtlcStateCanceled {
×
UNCOV
690
                                        amtPaid += htlc.Amt
×
UNCOV
691
                                }
×
692
                        }
693

UNCOV
694
                        setState := ampState[setID]
×
UNCOV
695
                        setState.InvoiceKeys = invoiceKeys
×
UNCOV
696
                        setState.AmtPaid = amtPaid
×
UNCOV
697
                        ampState[setID] = setState
×
698
                }
699
        }
700

UNCOV
701
        return ampState, ampHtlcs, nil
×
702
}
703

704
// LookupInvoice attempts to look up an invoice corresponding the passed in
705
// reference. The reference may be a payment hash, a payment address, or a set
706
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
707
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
708
// error.
709
func (i *SQLStore) LookupInvoice(ctx context.Context,
UNCOV
710
        ref InvoiceRef) (Invoice, error) {
×
UNCOV
711

×
UNCOV
712
        var (
×
UNCOV
713
                invoice *Invoice
×
UNCOV
714
                err     error
×
UNCOV
715
        )
×
UNCOV
716

×
UNCOV
717
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
718
        txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
719
                invoice, err = fetchInvoice(ctx, db, ref)
×
UNCOV
720

×
UNCOV
721
                return err
×
UNCOV
722
        }, func() {})
×
UNCOV
723
        if txErr != nil {
×
UNCOV
724
                return Invoice{}, txErr
×
UNCOV
725
        }
×
726

UNCOV
727
        return *invoice, nil
×
728
}
729

730
// FetchPendingInvoices returns all the invoices that are currently in a
731
// "pending" state. An invoice is pending if it has been created but not yet
732
// settled or canceled.
733
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
UNCOV
734
        map[lntypes.Hash]Invoice, error) {
×
UNCOV
735

×
UNCOV
736
        var invoices map[lntypes.Hash]Invoice
×
UNCOV
737

×
UNCOV
738
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
739
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
740
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
741
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
742
                                PendingOnly: true,
×
UNCOV
743
                                NumOffset:   int32(offset),
×
UNCOV
744
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
745
                                Reverse:     false,
×
UNCOV
746
                        }
×
UNCOV
747

×
UNCOV
748
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
749
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
750
                                return 0, fmt.Errorf("unable to get invoices "+
×
751
                                        "from db: %w", err)
×
752
                        }
×
753

754
                        // Load all the information for the invoices.
UNCOV
755
                        for _, row := range rows {
×
UNCOV
756
                                hash, invoice, err := fetchInvoiceData(
×
UNCOV
757
                                        ctx, db, row, nil, true,
×
UNCOV
758
                                )
×
UNCOV
759
                                if err != nil {
×
760
                                        return 0, err
×
761
                                }
×
762

UNCOV
763
                                invoices[*hash] = *invoice
×
764
                        }
765

UNCOV
766
                        return len(rows), nil
×
767
                }, i.opts.paginationLimit)
UNCOV
768
        }, func() {
×
UNCOV
769
                invoices = make(map[lntypes.Hash]Invoice)
×
UNCOV
770
        })
×
UNCOV
771
        if err != nil {
×
772
                return nil, fmt.Errorf("unable to fetch pending invoices: %w",
×
773
                        err)
×
774
        }
×
775

UNCOV
776
        return invoices, nil
×
777
}
778

779
// InvoicesSettledSince can be used by callers to catch up any settled invoices
780
// they missed within the settled invoice time series. We'll return all known
781
// settled invoice that have a settle index higher than the passed idx.
782
//
783
// NOTE: The index starts from 1. As a result we enforce that specifying a value
784
// below the starting index value is a noop.
785
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
UNCOV
786
        []Invoice, error) {
×
UNCOV
787

×
UNCOV
788
        var invoices []Invoice
×
UNCOV
789

×
UNCOV
790
        if idx == 0 {
×
UNCOV
791
                return invoices, nil
×
UNCOV
792
        }
×
793

UNCOV
794
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
795
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
796
                err := queryWithLimit(func(offset int) (int, error) {
×
UNCOV
797
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
798
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
799
                                NumOffset:      int32(offset),
×
UNCOV
800
                                NumLimit:       int32(i.opts.paginationLimit),
×
UNCOV
801
                                Reverse:        false,
×
UNCOV
802
                        }
×
UNCOV
803

×
UNCOV
804
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
805
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
806
                                return 0, fmt.Errorf("unable to get invoices "+
×
807
                                        "from db: %w", err)
×
808
                        }
×
809

810
                        // Load all the information for the invoices.
UNCOV
811
                        for _, row := range rows {
×
UNCOV
812
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
813
                                        ctx, db, row, nil, true,
×
UNCOV
814
                                )
×
UNCOV
815
                                if err != nil {
×
816
                                        return 0, fmt.Errorf("unable to fetch "+
×
817
                                                "invoice(id=%d) from db: %w",
×
818
                                                row.ID, err)
×
819
                                }
×
820

UNCOV
821
                                invoices = append(invoices, *invoice)
×
822
                        }
823

UNCOV
824
                        return len(rows), nil
×
825
                }, i.opts.paginationLimit)
UNCOV
826
                if err != nil {
×
827
                        return err
×
828
                }
×
829

830
                // Now fetch all the AMP sub invoices that were settled since
831
                // the provided index.
UNCOV
832
                ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
×
UNCOV
833
                        ctx, sqlc.FetchSettledAMPSubInvoicesParams{
×
UNCOV
834
                                SettleIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
835
                        },
×
UNCOV
836
                )
×
UNCOV
837
                if err != nil {
×
838
                        return err
×
839
                }
×
840

UNCOV
841
                for _, ampInvoice := range ampInvoices {
×
UNCOV
842
                        // Convert the row to a sqlc.Invoice so we can use the
×
UNCOV
843
                        // existing fetchInvoiceData function.
×
UNCOV
844
                        sqlInvoice := sqlc.Invoice{
×
UNCOV
845
                                ID:             ampInvoice.ID,
×
UNCOV
846
                                Hash:           ampInvoice.Hash,
×
UNCOV
847
                                Preimage:       ampInvoice.Preimage,
×
UNCOV
848
                                SettleIndex:    ampInvoice.AmpSettleIndex,
×
UNCOV
849
                                SettledAt:      ampInvoice.AmpSettledAt,
×
UNCOV
850
                                Memo:           ampInvoice.Memo,
×
UNCOV
851
                                AmountMsat:     ampInvoice.AmountMsat,
×
UNCOV
852
                                CltvDelta:      ampInvoice.CltvDelta,
×
UNCOV
853
                                Expiry:         ampInvoice.Expiry,
×
UNCOV
854
                                PaymentAddr:    ampInvoice.PaymentAddr,
×
UNCOV
855
                                PaymentRequest: ampInvoice.PaymentRequest,
×
UNCOV
856
                                State:          ampInvoice.State,
×
UNCOV
857
                                AmountPaidMsat: ampInvoice.AmountPaidMsat,
×
UNCOV
858
                                IsAmp:          ampInvoice.IsAmp,
×
UNCOV
859
                                IsHodl:         ampInvoice.IsHodl,
×
UNCOV
860
                                IsKeysend:      ampInvoice.IsKeysend,
×
UNCOV
861
                                CreatedAt:      ampInvoice.CreatedAt.UTC(),
×
UNCOV
862
                        }
×
UNCOV
863

×
UNCOV
864
                        // Fetch the state and HTLCs for this AMP sub invoice.
×
UNCOV
865
                        _, invoice, err := fetchInvoiceData(
×
UNCOV
866
                                ctx, db, sqlInvoice,
×
UNCOV
867
                                (*[32]byte)(ampInvoice.SetID), true,
×
UNCOV
868
                        )
×
UNCOV
869
                        if err != nil {
×
870
                                return fmt.Errorf("unable to fetch "+
×
871
                                        "AMP invoice(id=%d) from db: %w",
×
872
                                        ampInvoice.ID, err)
×
873
                        }
×
874

UNCOV
875
                        invoices = append(invoices, *invoice)
×
876
                }
877

UNCOV
878
                return nil
×
UNCOV
879
        }, func() {
×
UNCOV
880
                invoices = nil
×
UNCOV
881
        })
×
UNCOV
882
        if err != nil {
×
883
                return nil, fmt.Errorf("unable to get invoices settled since "+
×
884
                        "index (excluding) %d: %w", idx, err)
×
885
        }
×
886

UNCOV
887
        return invoices, nil
×
888
}
889

890
// InvoicesAddedSince can be used by callers to seek into the event time series
891
// of all the invoices added in the database. This method will return all
892
// invoices with an add index greater than the specified idx.
893
//
894
// NOTE: The index starts from 1. As a result we enforce that specifying a value
895
// below the starting index value is a noop.
896
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
UNCOV
897
        []Invoice, error) {
×
UNCOV
898

×
UNCOV
899
        var result []Invoice
×
UNCOV
900

×
UNCOV
901
        if idx == 0 {
×
UNCOV
902
                return result, nil
×
UNCOV
903
        }
×
904

UNCOV
905
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
906
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
907
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
908
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
909
                                AddIndexGet: sqldb.SQLInt64(idx + 1),
×
UNCOV
910
                                NumOffset:   int32(offset),
×
UNCOV
911
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
912
                                Reverse:     false,
×
UNCOV
913
                        }
×
UNCOV
914

×
UNCOV
915
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
916
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
917
                                return 0, fmt.Errorf("unable to get invoices "+
×
918
                                        "from db: %w", err)
×
919
                        }
×
920

921
                        // Load all the information for the invoices.
UNCOV
922
                        for _, row := range rows {
×
UNCOV
923
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
924
                                        ctx, db, row, nil, true,
×
UNCOV
925
                                )
×
UNCOV
926
                                if err != nil {
×
927
                                        return 0, err
×
928
                                }
×
929

UNCOV
930
                                result = append(result, *invoice)
×
931
                        }
932

UNCOV
933
                        return len(rows), nil
×
934
                }, i.opts.paginationLimit)
UNCOV
935
        }, func() {
×
UNCOV
936
                result = nil
×
UNCOV
937
        })
×
938

UNCOV
939
        if err != nil {
×
940
                return nil, fmt.Errorf("unable to get invoices added since "+
×
941
                        "index %d: %w", idx, err)
×
942
        }
×
943

UNCOV
944
        return result, nil
×
945
}
946

947
// QueryInvoices allows a caller to query the invoice database for invoices
948
// within the specified add index range.
949
func (i *SQLStore) QueryInvoices(ctx context.Context,
UNCOV
950
        q InvoiceQuery) (InvoiceSlice, error) {
×
UNCOV
951

×
UNCOV
952
        var invoices []Invoice
×
UNCOV
953

×
UNCOV
954
        if q.NumMaxInvoices == 0 {
×
955
                return InvoiceSlice{}, fmt.Errorf("max invoices must " +
×
956
                        "be non-zero")
×
957
        }
×
958

UNCOV
959
        readTxOpt := NewSQLInvoiceQueryReadTx()
×
UNCOV
960
        err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
961
                return queryWithLimit(func(offset int) (int, error) {
×
UNCOV
962
                        params := sqlc.FilterInvoicesParams{
×
UNCOV
963
                                NumOffset:   int32(offset),
×
UNCOV
964
                                NumLimit:    int32(i.opts.paginationLimit),
×
UNCOV
965
                                PendingOnly: q.PendingOnly,
×
UNCOV
966
                                Reverse:     q.Reversed,
×
UNCOV
967
                        }
×
UNCOV
968

×
UNCOV
969
                        if q.Reversed {
×
UNCOV
970
                                // If the index offset was not set, we want to
×
UNCOV
971
                                // fetch from the lastest invoice.
×
UNCOV
972
                                if q.IndexOffset == 0 {
×
UNCOV
973
                                        params.AddIndexLet = sqldb.SQLInt64(
×
UNCOV
974
                                                int64(math.MaxInt64),
×
UNCOV
975
                                        )
×
UNCOV
976
                                } else {
×
UNCOV
977
                                        // The invoice with index offset id must
×
UNCOV
978
                                        // not be included in the results.
×
UNCOV
979
                                        params.AddIndexLet = sqldb.SQLInt64(
×
UNCOV
980
                                                q.IndexOffset - 1,
×
UNCOV
981
                                        )
×
UNCOV
982
                                }
×
UNCOV
983
                        } else {
×
UNCOV
984
                                // The invoice with index offset id must not be
×
UNCOV
985
                                // included in the results.
×
UNCOV
986
                                params.AddIndexGet = sqldb.SQLInt64(
×
UNCOV
987
                                        q.IndexOffset + 1,
×
UNCOV
988
                                )
×
UNCOV
989
                        }
×
990

UNCOV
991
                        if q.CreationDateStart != 0 {
×
UNCOV
992
                                params.CreatedAfter = sqldb.SQLTime(
×
UNCOV
993
                                        time.Unix(q.CreationDateStart, 0).UTC(),
×
UNCOV
994
                                )
×
UNCOV
995
                        }
×
996

UNCOV
997
                        if q.CreationDateEnd != 0 {
×
UNCOV
998
                                // We need to add 1 to the end date as we're
×
UNCOV
999
                                // checking less than the end date in SQL.
×
UNCOV
1000
                                params.CreatedBefore = sqldb.SQLTime(
×
UNCOV
1001
                                        time.Unix(q.CreationDateEnd+1, 0).UTC(),
×
UNCOV
1002
                                )
×
UNCOV
1003
                        }
×
1004

UNCOV
1005
                        rows, err := db.FilterInvoices(ctx, params)
×
UNCOV
1006
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1007
                                return 0, fmt.Errorf("unable to get invoices "+
×
1008
                                        "from db: %w", err)
×
1009
                        }
×
1010

1011
                        // Load all the information for the invoices.
UNCOV
1012
                        for _, row := range rows {
×
UNCOV
1013
                                _, invoice, err := fetchInvoiceData(
×
UNCOV
1014
                                        ctx, db, row, nil, true,
×
UNCOV
1015
                                )
×
UNCOV
1016
                                if err != nil {
×
1017
                                        return 0, err
×
1018
                                }
×
1019

UNCOV
1020
                                invoices = append(invoices, *invoice)
×
UNCOV
1021

×
UNCOV
1022
                                if len(invoices) == int(q.NumMaxInvoices) {
×
UNCOV
1023
                                        return 0, nil
×
UNCOV
1024
                                }
×
1025
                        }
1026

UNCOV
1027
                        return len(rows), nil
×
1028
                }, i.opts.paginationLimit)
UNCOV
1029
        }, func() {
×
UNCOV
1030
                invoices = nil
×
UNCOV
1031
        })
×
UNCOV
1032
        if err != nil {
×
1033
                return InvoiceSlice{}, fmt.Errorf("unable to query "+
×
1034
                        "invoices: %w", err)
×
1035
        }
×
1036

UNCOV
1037
        if len(invoices) == 0 {
×
UNCOV
1038
                return InvoiceSlice{
×
UNCOV
1039
                        InvoiceQuery: q,
×
UNCOV
1040
                }, nil
×
UNCOV
1041
        }
×
1042

1043
        // If we iterated through the add index in reverse order, then
1044
        // we'll need to reverse the slice of invoices to return them in
1045
        // forward order.
UNCOV
1046
        if q.Reversed {
×
UNCOV
1047
                numInvoices := len(invoices)
×
UNCOV
1048
                for i := 0; i < numInvoices/2; i++ {
×
UNCOV
1049
                        reverse := numInvoices - i - 1
×
UNCOV
1050
                        invoices[i], invoices[reverse] =
×
UNCOV
1051
                                invoices[reverse], invoices[i]
×
UNCOV
1052
                }
×
1053
        }
1054

UNCOV
1055
        res := InvoiceSlice{
×
UNCOV
1056
                InvoiceQuery:     q,
×
UNCOV
1057
                Invoices:         invoices,
×
UNCOV
1058
                FirstIndexOffset: invoices[0].AddIndex,
×
UNCOV
1059
                LastIndexOffset:  invoices[len(invoices)-1].AddIndex,
×
UNCOV
1060
        }
×
UNCOV
1061

×
UNCOV
1062
        return res, nil
×
1063
}
1064

1065
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
1066
// a SQL database as the backend.
1067
type sqlInvoiceUpdater struct {
1068
        db         SQLInvoiceQueries
1069
        ctx        context.Context //nolint:containedctx
1070
        invoice    *Invoice
1071
        updateTime time.Time
1072
}
1073

1074
// AddHtlc adds a new htlc to the invoice.
1075
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
UNCOV
1076
        newHtlc *InvoiceHTLC) error {
×
UNCOV
1077

×
UNCOV
1078
        htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
×
UNCOV
1079
                s.ctx, sqlc.InsertInvoiceHTLCParams{
×
UNCOV
1080
                        HtlcID: int64(circuitKey.HtlcID),
×
UNCOV
1081
                        ChanID: strconv.FormatUint(
×
UNCOV
1082
                                circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1083
                        ),
×
UNCOV
1084
                        AmountMsat: int64(newHtlc.Amt),
×
UNCOV
1085
                        TotalMppMsat: sql.NullInt64{
×
UNCOV
1086
                                Int64: int64(newHtlc.MppTotalAmt),
×
UNCOV
1087
                                Valid: newHtlc.MppTotalAmt != 0,
×
UNCOV
1088
                        },
×
UNCOV
1089
                        AcceptHeight: int32(newHtlc.AcceptHeight),
×
UNCOV
1090
                        AcceptTime:   newHtlc.AcceptTime.UTC(),
×
UNCOV
1091
                        ExpiryHeight: int32(newHtlc.Expiry),
×
UNCOV
1092
                        State:        int16(newHtlc.State),
×
UNCOV
1093
                        InvoiceID:    int64(s.invoice.AddIndex),
×
UNCOV
1094
                },
×
UNCOV
1095
        )
×
UNCOV
1096
        if err != nil {
×
1097
                return err
×
1098
        }
×
1099

UNCOV
1100
        for key, value := range newHtlc.CustomRecords {
×
UNCOV
1101
                err = s.db.InsertInvoiceHTLCCustomRecord(
×
UNCOV
1102
                        s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
×
UNCOV
1103
                                // TODO(bhandras): schema might be wrong here
×
UNCOV
1104
                                // as the custom record key is an uint64.
×
UNCOV
1105
                                Key:    int64(key),
×
UNCOV
1106
                                Value:  value,
×
UNCOV
1107
                                HtlcID: htlcPrimaryKeyID,
×
UNCOV
1108
                        },
×
UNCOV
1109
                )
×
UNCOV
1110
                if err != nil {
×
1111
                        return err
×
1112
                }
×
1113
        }
1114

UNCOV
1115
        if newHtlc.AMP != nil {
×
UNCOV
1116
                setID := newHtlc.AMP.Record.SetID()
×
UNCOV
1117

×
UNCOV
1118
                upsertResult, err := s.db.UpsertAMPSubInvoice(
×
UNCOV
1119
                        s.ctx, sqlc.UpsertAMPSubInvoiceParams{
×
UNCOV
1120
                                SetID:     setID[:],
×
UNCOV
1121
                                CreatedAt: s.updateTime.UTC(),
×
UNCOV
1122
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1123
                        },
×
UNCOV
1124
                )
×
UNCOV
1125
                if err != nil {
×
UNCOV
1126
                        mappedSQLErr := sqldb.MapSQLError(err)
×
UNCOV
1127
                        var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
×
UNCOV
1128
                        if errors.As(mappedSQLErr, &uniqueConstraintErr) {
×
UNCOV
1129
                                return ErrDuplicateSetID{
×
UNCOV
1130
                                        SetID: setID,
×
UNCOV
1131
                                }
×
UNCOV
1132
                        }
×
1133

1134
                        return err
×
1135
                }
1136

1137
                // If we're just inserting the AMP invoice, we'll get a non
1138
                // zero rows affected count.
UNCOV
1139
                rowsAffected, err := upsertResult.RowsAffected()
×
UNCOV
1140
                if err != nil {
×
1141
                        return err
×
1142
                }
×
UNCOV
1143
                if rowsAffected != 0 {
×
UNCOV
1144
                        // If we're inserting a new AMP invoice, we'll also
×
UNCOV
1145
                        // insert a new invoice event.
×
UNCOV
1146
                        err = s.db.OnAMPSubInvoiceCreated(
×
UNCOV
1147
                                s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
×
UNCOV
1148
                                        AddedAt:   s.updateTime.UTC(),
×
UNCOV
1149
                                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1150
                                        SetID:     setID[:],
×
UNCOV
1151
                                },
×
UNCOV
1152
                        )
×
UNCOV
1153
                        if err != nil {
×
1154
                                return err
×
1155
                        }
×
1156
                }
1157

UNCOV
1158
                rootShare := newHtlc.AMP.Record.RootShare()
×
UNCOV
1159

×
UNCOV
1160
                ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
×
UNCOV
1161
                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1162
                        SetID:     setID[:],
×
UNCOV
1163
                        HtlcID:    htlcPrimaryKeyID,
×
UNCOV
1164
                        RootShare: rootShare[:],
×
UNCOV
1165
                        ChildIndex: int64(
×
UNCOV
1166
                                newHtlc.AMP.Record.ChildIndex(),
×
UNCOV
1167
                        ),
×
UNCOV
1168
                        Hash: newHtlc.AMP.Hash[:],
×
UNCOV
1169
                }
×
UNCOV
1170

×
UNCOV
1171
                if newHtlc.AMP.Preimage != nil {
×
UNCOV
1172
                        ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
×
UNCOV
1173
                }
×
1174

UNCOV
1175
                err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
×
UNCOV
1176
                if err != nil {
×
1177
                        return err
×
1178
                }
×
1179
        }
1180

UNCOV
1181
        return nil
×
1182
}
1183

1184
// ResolveHtlc marks an htlc as resolved with the given state.
1185
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
UNCOV
1186
        state HtlcState, resolveTime time.Time) error {
×
UNCOV
1187

×
UNCOV
1188
        return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
×
UNCOV
1189
                HtlcID: int64(circuitKey.HtlcID),
×
UNCOV
1190
                ChanID: strconv.FormatUint(
×
UNCOV
1191
                        circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1192
                ),
×
UNCOV
1193
                InvoiceID:   int64(s.invoice.AddIndex),
×
UNCOV
1194
                State:       int16(state),
×
UNCOV
1195
                ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
×
UNCOV
1196
        })
×
UNCOV
1197
}
×
1198

1199
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
1200
// identified by the setID.
1201
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
UNCOV
1202
        circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
×
UNCOV
1203

×
UNCOV
1204
        result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
×
UNCOV
1205
                s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
×
UNCOV
1206
                        InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1207
                        SetID:     setID[:],
×
UNCOV
1208
                        HtlcID:    int64(circuitKey.HtlcID),
×
UNCOV
1209
                        Preimage:  preimage[:],
×
UNCOV
1210
                        ChanID: strconv.FormatUint(
×
UNCOV
1211
                                circuitKey.ChanID.ToUint64(), 10,
×
UNCOV
1212
                        ),
×
UNCOV
1213
                },
×
UNCOV
1214
        )
×
UNCOV
1215
        if err != nil {
×
1216
                return err
×
1217
        }
×
1218

UNCOV
1219
        rowsAffected, err := result.RowsAffected()
×
UNCOV
1220
        if err != nil {
×
1221
                return err
×
1222
        }
×
UNCOV
1223
        if rowsAffected == 0 {
×
1224
                return ErrInvoiceNotFound
×
1225
        }
×
1226

UNCOV
1227
        return nil
×
1228
}
1229

1230
// UpdateInvoiceState updates the invoice state to the new state.
1231
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
UNCOV
1232
        newState ContractState, preimage *lntypes.Preimage) error {
×
UNCOV
1233

×
UNCOV
1234
        var (
×
UNCOV
1235
                settleIndex sql.NullInt64
×
UNCOV
1236
                settledAt   sql.NullTime
×
UNCOV
1237
        )
×
UNCOV
1238

×
UNCOV
1239
        switch newState {
×
UNCOV
1240
        case ContractSettled:
×
UNCOV
1241
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
UNCOV
1242
                if err != nil {
×
1243
                        return err
×
1244
                }
×
1245

UNCOV
1246
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
UNCOV
1247

×
UNCOV
1248
                // If the invoice is settled, we'll also update the settle time.
×
UNCOV
1249
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
UNCOV
1250

×
UNCOV
1251
                err = s.db.OnInvoiceSettled(
×
UNCOV
1252
                        s.ctx, sqlc.OnInvoiceSettledParams{
×
UNCOV
1253
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1254
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1255
                        },
×
UNCOV
1256
                )
×
UNCOV
1257
                if err != nil {
×
1258
                        return err
×
1259
                }
×
1260

UNCOV
1261
        case ContractCanceled:
×
UNCOV
1262
                err := s.db.OnInvoiceCanceled(
×
UNCOV
1263
                        s.ctx, sqlc.OnInvoiceCanceledParams{
×
UNCOV
1264
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1265
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1266
                        },
×
UNCOV
1267
                )
×
UNCOV
1268
                if err != nil {
×
1269
                        return err
×
1270
                }
×
1271
        }
1272

UNCOV
1273
        params := sqlc.UpdateInvoiceStateParams{
×
UNCOV
1274
                ID:          int64(s.invoice.AddIndex),
×
UNCOV
1275
                State:       int16(newState),
×
UNCOV
1276
                SettleIndex: settleIndex,
×
UNCOV
1277
                SettledAt:   settledAt,
×
UNCOV
1278
        }
×
UNCOV
1279

×
UNCOV
1280
        if preimage != nil {
×
UNCOV
1281
                params.Preimage = preimage[:]
×
UNCOV
1282
        }
×
1283

UNCOV
1284
        result, err := s.db.UpdateInvoiceState(s.ctx, params)
×
UNCOV
1285
        if err != nil {
×
UNCOV
1286
                return err
×
UNCOV
1287
        }
×
UNCOV
1288
        rowsAffected, err := result.RowsAffected()
×
UNCOV
1289
        if err != nil {
×
1290
                return err
×
1291
        }
×
1292

UNCOV
1293
        if rowsAffected == 0 {
×
1294
                return ErrInvoiceNotFound
×
1295
        }
×
1296

UNCOV
1297
        if settleIndex.Valid {
×
UNCOV
1298
                s.invoice.SettleIndex = uint64(settleIndex.Int64)
×
UNCOV
1299
                s.invoice.SettleDate = s.updateTime
×
UNCOV
1300
        }
×
1301

UNCOV
1302
        return nil
×
1303
}
1304

1305
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
1306
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
UNCOV
1307
        amtPaid lnwire.MilliSatoshi) error {
×
UNCOV
1308

×
UNCOV
1309
        _, err := s.db.UpdateInvoiceAmountPaid(
×
UNCOV
1310
                s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
×
UNCOV
1311
                        ID:             int64(s.invoice.AddIndex),
×
UNCOV
1312
                        AmountPaidMsat: int64(amtPaid),
×
UNCOV
1313
                },
×
UNCOV
1314
        )
×
UNCOV
1315

×
UNCOV
1316
        return err
×
UNCOV
1317
}
×
1318

1319
// UpdateAmpState updates the state of the AMP sub invoice identified by the
1320
// setID.
1321
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
UNCOV
1322
        newState InvoiceStateAMP, _ models.CircuitKey) error {
×
UNCOV
1323

×
UNCOV
1324
        var (
×
UNCOV
1325
                settleIndex sql.NullInt64
×
UNCOV
1326
                settledAt   sql.NullTime
×
UNCOV
1327
        )
×
UNCOV
1328

×
UNCOV
1329
        switch newState.State {
×
UNCOV
1330
        case HtlcStateSettled:
×
UNCOV
1331
                nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
×
UNCOV
1332
                if err != nil {
×
1333
                        return err
×
1334
                }
×
1335

UNCOV
1336
                settleIndex = sqldb.SQLInt64(nextSettleIndex)
×
UNCOV
1337

×
UNCOV
1338
                // If the invoice is settled, we'll also update the settle time.
×
UNCOV
1339
                settledAt = sqldb.SQLTime(s.updateTime.UTC())
×
UNCOV
1340

×
UNCOV
1341
                err = s.db.OnAMPSubInvoiceSettled(
×
UNCOV
1342
                        s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
×
UNCOV
1343
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1344
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1345
                                SetID:     setID[:],
×
UNCOV
1346
                        },
×
UNCOV
1347
                )
×
UNCOV
1348
                if err != nil {
×
1349
                        return err
×
1350
                }
×
1351

UNCOV
1352
        case HtlcStateCanceled:
×
UNCOV
1353
                err := s.db.OnAMPSubInvoiceCanceled(
×
UNCOV
1354
                        s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
×
UNCOV
1355
                                AddedAt:   s.updateTime.UTC(),
×
UNCOV
1356
                                InvoiceID: int64(s.invoice.AddIndex),
×
UNCOV
1357
                                SetID:     setID[:],
×
UNCOV
1358
                        },
×
UNCOV
1359
                )
×
UNCOV
1360
                if err != nil {
×
1361
                        return err
×
1362
                }
×
1363
        }
1364

UNCOV
1365
        err := s.db.UpdateAMPSubInvoiceState(
×
UNCOV
1366
                s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
×
UNCOV
1367
                        SetID:       setID[:],
×
UNCOV
1368
                        State:       int16(newState.State),
×
UNCOV
1369
                        SettleIndex: settleIndex,
×
UNCOV
1370
                        SettledAt:   settledAt,
×
UNCOV
1371
                },
×
UNCOV
1372
        )
×
UNCOV
1373
        if err != nil {
×
1374
                return err
×
1375
        }
×
1376

UNCOV
1377
        if settleIndex.Valid {
×
UNCOV
1378
                updatedState := s.invoice.AMPState[setID]
×
UNCOV
1379
                updatedState.SettleIndex = uint64(settleIndex.Int64)
×
UNCOV
1380
                updatedState.SettleDate = s.updateTime.UTC()
×
UNCOV
1381
                s.invoice.AMPState[setID] = updatedState
×
UNCOV
1382
        }
×
1383

UNCOV
1384
        return nil
×
1385
}
1386

1387
// Finalize finalizes the update before it is written to the database. Note that
1388
// we don't use this directly in the SQL implementation, so the function is just
1389
// a stub.
UNCOV
1390
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
×
UNCOV
1391
        return nil
×
UNCOV
1392
}
×
1393

1394
// UpdateInvoice attempts to update an invoice corresponding to the passed
1395
// reference. If an invoice matching the passed reference doesn't exist within
1396
// the database, then the action will fail with  ErrInvoiceNotFound error.
1397
//
1398
// The update is performed inside the same database transaction that fetches the
1399
// invoice and is therefore atomic. The fields to update are controlled by the
1400
// supplied callback.
1401
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
1402
        setID *SetID, callback InvoiceUpdateCallback) (
UNCOV
1403
        *Invoice, error) {
×
UNCOV
1404

×
UNCOV
1405
        var updatedInvoice *Invoice
×
UNCOV
1406

×
UNCOV
1407
        txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
×
UNCOV
1408
        txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1409
                switch {
×
1410
                // For the default case we fetch all HTLCs.
UNCOV
1411
                case setID == nil:
×
UNCOV
1412
                        ref.refModifier = DefaultModifier
×
1413

1414
                // If the setID is the blank but NOT nil, we set the
1415
                // refModifier to HtlcSetBlankModifier to fetch no HTLC for the
1416
                // AMP invoice.
1417
                case *setID == BlankPayAddr:
×
1418
                        ref.refModifier = HtlcSetBlankModifier
×
1419

1420
                // A setID is provided, we use the refModifier to fetch only
1421
                // the HTLCs for the given setID and also make sure we add the
1422
                // setID to the ref.
UNCOV
1423
                default:
×
UNCOV
1424
                        var setIDBytes [32]byte
×
UNCOV
1425
                        copy(setIDBytes[:], setID[:])
×
UNCOV
1426
                        ref.setID = &setIDBytes
×
UNCOV
1427

×
UNCOV
1428
                        // We only fetch the HTLCs for the given setID.
×
UNCOV
1429
                        ref.refModifier = HtlcSetOnlyModifier
×
1430
                }
1431

UNCOV
1432
                invoice, err := fetchInvoice(ctx, db, ref)
×
UNCOV
1433
                if err != nil {
×
UNCOV
1434
                        return err
×
UNCOV
1435
                }
×
1436

UNCOV
1437
                updateTime := i.clock.Now()
×
UNCOV
1438
                updater := &sqlInvoiceUpdater{
×
UNCOV
1439
                        db:         db,
×
UNCOV
1440
                        ctx:        ctx,
×
UNCOV
1441
                        invoice:    invoice,
×
UNCOV
1442
                        updateTime: updateTime,
×
UNCOV
1443
                }
×
UNCOV
1444

×
UNCOV
1445
                payHash := ref.PayHash()
×
UNCOV
1446
                updatedInvoice, err = UpdateInvoice(
×
UNCOV
1447
                        payHash, invoice, updateTime, callback, updater,
×
UNCOV
1448
                )
×
UNCOV
1449

×
UNCOV
1450
                return err
×
UNCOV
1451
        }, func() {})
×
UNCOV
1452
        if txErr != nil {
×
UNCOV
1453
                // If the invoice is already settled, we'll return the
×
UNCOV
1454
                // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
×
UNCOV
1455
                if errors.Is(txErr, ErrInvoiceAlreadySettled) {
×
UNCOV
1456
                        return updatedInvoice, txErr
×
UNCOV
1457
                }
×
1458

UNCOV
1459
                return nil, txErr
×
1460
        }
1461

UNCOV
1462
        return updatedInvoice, nil
×
1463
}
1464

1465
// DeleteInvoice attempts to delete the passed invoices and all their related
1466
// data from the database in one transaction.
1467
func (i *SQLStore) DeleteInvoice(ctx context.Context,
UNCOV
1468
        invoicesToDelete []InvoiceDeleteRef) error {
×
UNCOV
1469

×
UNCOV
1470
        // All the InvoiceDeleteRef instances include the add index of the
×
UNCOV
1471
        // invoice. The rest was added to ensure that the invoices were deleted
×
UNCOV
1472
        // properly in the kv database. When we have fully migrated we can
×
UNCOV
1473
        // remove the rest of the fields.
×
UNCOV
1474
        for _, ref := range invoicesToDelete {
×
UNCOV
1475
                if ref.AddIndex == 0 {
×
1476
                        return fmt.Errorf("unable to delete invoice using a "+
×
1477
                                "ref without AddIndex set: %v", ref)
×
1478
                }
×
1479
        }
1480

UNCOV
1481
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
UNCOV
1482
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1483
                for _, ref := range invoicesToDelete {
×
UNCOV
1484
                        params := sqlc.DeleteInvoiceParams{
×
UNCOV
1485
                                AddIndex: sqldb.SQLInt64(ref.AddIndex),
×
UNCOV
1486
                        }
×
UNCOV
1487

×
UNCOV
1488
                        if ref.SettleIndex != 0 {
×
UNCOV
1489
                                params.SettleIndex = sqldb.SQLInt64(
×
UNCOV
1490
                                        ref.SettleIndex,
×
UNCOV
1491
                                )
×
UNCOV
1492
                        }
×
1493

UNCOV
1494
                        if ref.PayHash != lntypes.ZeroHash {
×
UNCOV
1495
                                params.Hash = ref.PayHash[:]
×
UNCOV
1496
                        }
×
1497

UNCOV
1498
                        result, err := db.DeleteInvoice(ctx, params)
×
UNCOV
1499
                        if err != nil {
×
1500
                                return fmt.Errorf("unable to delete "+
×
1501
                                        "invoice(%v): %w", ref.AddIndex, err)
×
1502
                        }
×
UNCOV
1503
                        rowsAffected, err := result.RowsAffected()
×
UNCOV
1504
                        if err != nil {
×
1505
                                return fmt.Errorf("unable to get rows "+
×
1506
                                        "affected: %w", err)
×
1507
                        }
×
UNCOV
1508
                        if rowsAffected == 0 {
×
UNCOV
1509
                                return fmt.Errorf("%w: %v",
×
UNCOV
1510
                                        ErrInvoiceNotFound, ref.AddIndex)
×
UNCOV
1511
                        }
×
1512
                }
1513

UNCOV
1514
                return nil
×
UNCOV
1515
        }, func() {})
×
1516

UNCOV
1517
        if err != nil {
×
UNCOV
1518
                return fmt.Errorf("unable to delete invoices: %w", err)
×
UNCOV
1519
        }
×
1520

UNCOV
1521
        return nil
×
1522
}
1523

1524
// DeleteCanceledInvoices removes all canceled invoices from the database.
UNCOV
1525
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
×
UNCOV
1526
        var writeTxOpt SQLInvoiceQueriesTxOptions
×
UNCOV
1527
        err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
×
UNCOV
1528
                _, err := db.DeleteCanceledInvoices(ctx)
×
UNCOV
1529
                if err != nil {
×
1530
                        return fmt.Errorf("unable to delete canceled "+
×
1531
                                "invoices: %w", err)
×
1532
                }
×
1533

UNCOV
1534
                return nil
×
UNCOV
1535
        }, func() {})
×
UNCOV
1536
        if err != nil {
×
1537
                return fmt.Errorf("unable to delete invoices: %w", err)
×
1538
        }
×
1539

UNCOV
1540
        return nil
×
1541
}
1542

1543
// fetchInvoiceData fetches additional data for the given invoice. If the
1544
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
1545
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
1546
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
1547
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
1548
        row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
UNCOV
1549
        *Invoice, error) {
×
UNCOV
1550

×
UNCOV
1551
        // Unmarshal the common data.
×
UNCOV
1552
        hash, invoice, err := unmarshalInvoice(row)
×
UNCOV
1553
        if err != nil {
×
1554
                return nil, nil, fmt.Errorf("unable to unmarshal "+
×
1555
                        "invoice(id=%d) from db: %w", row.ID, err)
×
1556
        }
×
1557

1558
        // Fetch the invoice features.
UNCOV
1559
        features, err := getInvoiceFeatures(ctx, db, row.ID)
×
UNCOV
1560
        if err != nil {
×
1561
                return nil, nil, err
×
1562
        }
×
1563

UNCOV
1564
        invoice.Terms.Features = features
×
UNCOV
1565

×
UNCOV
1566
        // If this is an AMP invoice, we'll need fetch the AMP state along
×
UNCOV
1567
        // with the HTLCs (if requested).
×
UNCOV
1568
        if invoice.IsAMP() {
×
UNCOV
1569
                invoiceID := int64(invoice.AddIndex)
×
UNCOV
1570
                ampState, ampHtlcs, err := fetchAmpState(
×
UNCOV
1571
                        ctx, db, invoiceID, setID, fetchAmpHtlcs,
×
UNCOV
1572
                )
×
UNCOV
1573
                if err != nil {
×
1574
                        return nil, nil, err
×
1575
                }
×
1576

UNCOV
1577
                invoice.AMPState = ampState
×
UNCOV
1578
                invoice.Htlcs = ampHtlcs
×
UNCOV
1579

×
UNCOV
1580
                return hash, invoice, nil
×
1581
        }
1582

1583
        // Otherwise simply fetch the invoice HTLCs.
UNCOV
1584
        htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
×
UNCOV
1585
        if err != nil {
×
1586
                return nil, nil, err
×
1587
        }
×
1588

UNCOV
1589
        if len(htlcs) > 0 {
×
UNCOV
1590
                invoice.Htlcs = htlcs
×
UNCOV
1591
        }
×
1592

UNCOV
1593
        return hash, invoice, nil
×
1594
}
1595

1596
// getInvoiceFeatures fetches the invoice features for the given invoice id.
1597
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
UNCOV
1598
        invoiceID int64) (*lnwire.FeatureVector, error) {
×
UNCOV
1599

×
UNCOV
1600
        rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
×
UNCOV
1601
        if err != nil {
×
1602
                return nil, fmt.Errorf("unable to get invoice features: %w",
×
1603
                        err)
×
1604
        }
×
1605

UNCOV
1606
        features := lnwire.EmptyFeatureVector()
×
UNCOV
1607
        for _, feature := range rows {
×
UNCOV
1608
                features.Set(lnwire.FeatureBit(feature.Feature))
×
UNCOV
1609
        }
×
1610

UNCOV
1611
        return features, nil
×
1612
}
1613

1614
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
1615
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
UNCOV
1616
        invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
×
UNCOV
1617

×
UNCOV
1618
        htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
×
UNCOV
1619
        if err != nil {
×
1620
                return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
×
1621
        }
×
1622

1623
        // We have no htlcs to unmarshal.
UNCOV
1624
        if len(htlcRows) == 0 {
×
UNCOV
1625
                return nil, nil
×
UNCOV
1626
        }
×
1627

UNCOV
1628
        crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
×
UNCOV
1629
        if err != nil {
×
1630
                return nil, fmt.Errorf("unable to get custom records for "+
×
1631
                        "invoice htlcs: %w", err)
×
1632
        }
×
1633

UNCOV
1634
        cr := make(map[int64]record.CustomSet, len(crRows))
×
UNCOV
1635
        for _, row := range crRows {
×
UNCOV
1636
                if _, ok := cr[row.HtlcID]; !ok {
×
UNCOV
1637
                        cr[row.HtlcID] = make(record.CustomSet)
×
UNCOV
1638
                }
×
1639

UNCOV
1640
                value := row.Value
×
UNCOV
1641
                if value == nil {
×
UNCOV
1642
                        value = []byte{}
×
UNCOV
1643
                }
×
UNCOV
1644
                cr[row.HtlcID][uint64(row.Key)] = value
×
1645
        }
1646

UNCOV
1647
        htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
×
UNCOV
1648

×
UNCOV
1649
        for _, row := range htlcRows {
×
UNCOV
1650
                circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
×
UNCOV
1651
                if err != nil {
×
1652
                        return nil, fmt.Errorf("unable to unmarshal "+
×
1653
                                "htlc(%d): %w", row.ID, err)
×
1654
                }
×
1655

UNCOV
1656
                if customRecords, ok := cr[row.ID]; ok {
×
UNCOV
1657
                        htlc.CustomRecords = customRecords
×
UNCOV
1658
                } else {
×
UNCOV
1659
                        htlc.CustomRecords = make(record.CustomSet)
×
UNCOV
1660
                }
×
1661

UNCOV
1662
                htlcs[circuiteKey] = htlc
×
1663
        }
1664

UNCOV
1665
        return htlcs, nil
×
1666
}
1667

1668
// unmarshalInvoice converts an InvoiceRow to an Invoice.
1669
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
UNCOV
1670
        error) {
×
UNCOV
1671

×
UNCOV
1672
        var (
×
UNCOV
1673
                settleIndex    int64
×
UNCOV
1674
                settledAt      time.Time
×
UNCOV
1675
                memo           []byte
×
UNCOV
1676
                paymentRequest []byte
×
UNCOV
1677
                preimage       *lntypes.Preimage
×
UNCOV
1678
                paymentAddr    [32]byte
×
UNCOV
1679
        )
×
UNCOV
1680

×
UNCOV
1681
        hash, err := lntypes.MakeHash(row.Hash)
×
UNCOV
1682
        if err != nil {
×
1683
                return nil, nil, err
×
1684
        }
×
1685

UNCOV
1686
        if row.SettleIndex.Valid {
×
UNCOV
1687
                settleIndex = row.SettleIndex.Int64
×
UNCOV
1688
        }
×
1689

UNCOV
1690
        if row.SettledAt.Valid {
×
UNCOV
1691
                settledAt = row.SettledAt.Time.Local()
×
UNCOV
1692
        }
×
1693

UNCOV
1694
        if row.Memo.Valid {
×
UNCOV
1695
                memo = []byte(row.Memo.String)
×
UNCOV
1696
        }
×
1697

1698
        // Keysend payments will have this field empty.
UNCOV
1699
        if row.PaymentRequest.Valid {
×
UNCOV
1700
                paymentRequest = []byte(row.PaymentRequest.String)
×
UNCOV
1701
        } else {
×
UNCOV
1702
                paymentRequest = []byte{}
×
UNCOV
1703
        }
×
1704

1705
        // We may not have the preimage if this a hodl invoice.
UNCOV
1706
        if row.Preimage != nil {
×
UNCOV
1707
                preimage = &lntypes.Preimage{}
×
UNCOV
1708
                copy(preimage[:], row.Preimage)
×
UNCOV
1709
        }
×
1710

UNCOV
1711
        copy(paymentAddr[:], row.PaymentAddr)
×
UNCOV
1712

×
UNCOV
1713
        var cltvDelta int32
×
UNCOV
1714
        if row.CltvDelta.Valid {
×
UNCOV
1715
                cltvDelta = row.CltvDelta.Int32
×
UNCOV
1716
        }
×
1717

UNCOV
1718
        expiry := time.Duration(row.Expiry) * time.Second
×
UNCOV
1719

×
UNCOV
1720
        invoice := &Invoice{
×
UNCOV
1721
                SettleIndex:    uint64(settleIndex),
×
UNCOV
1722
                SettleDate:     settledAt,
×
UNCOV
1723
                Memo:           memo,
×
UNCOV
1724
                PaymentRequest: paymentRequest,
×
UNCOV
1725
                CreationDate:   row.CreatedAt.Local(),
×
UNCOV
1726
                Terms: ContractTerm{
×
UNCOV
1727
                        FinalCltvDelta:  cltvDelta,
×
UNCOV
1728
                        Expiry:          expiry,
×
UNCOV
1729
                        PaymentPreimage: preimage,
×
UNCOV
1730
                        Value:           lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
1731
                        PaymentAddr:     paymentAddr,
×
UNCOV
1732
                },
×
UNCOV
1733
                AddIndex:    uint64(row.ID),
×
UNCOV
1734
                State:       ContractState(row.State),
×
UNCOV
1735
                AmtPaid:     lnwire.MilliSatoshi(row.AmountPaidMsat),
×
UNCOV
1736
                Htlcs:       make(map[models.CircuitKey]*InvoiceHTLC),
×
UNCOV
1737
                AMPState:    AMPInvoiceState{},
×
UNCOV
1738
                HodlInvoice: row.IsHodl,
×
UNCOV
1739
        }
×
UNCOV
1740

×
UNCOV
1741
        return &hash, invoice, nil
×
1742
}
1743

1744
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
1745
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
UNCOV
1746
        *InvoiceHTLC, error) {
×
UNCOV
1747

×
UNCOV
1748
        uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
×
UNCOV
1749
        if err != nil {
×
1750
                return CircuitKey{}, nil, err
×
1751
        }
×
1752

UNCOV
1753
        chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
×
UNCOV
1754

×
UNCOV
1755
        if row.HtlcID < 0 {
×
1756
                return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
×
1757
                        "value: %v", row.HtlcID)
×
1758
        }
×
1759

UNCOV
1760
        htlcID := uint64(row.HtlcID)
×
UNCOV
1761

×
UNCOV
1762
        circuitKey := CircuitKey{
×
UNCOV
1763
                ChanID: chanID,
×
UNCOV
1764
                HtlcID: htlcID,
×
UNCOV
1765
        }
×
UNCOV
1766

×
UNCOV
1767
        htlc := &InvoiceHTLC{
×
UNCOV
1768
                Amt:          lnwire.MilliSatoshi(row.AmountMsat),
×
UNCOV
1769
                AcceptHeight: uint32(row.AcceptHeight),
×
UNCOV
1770
                AcceptTime:   row.AcceptTime.Local(),
×
UNCOV
1771
                Expiry:       uint32(row.ExpiryHeight),
×
UNCOV
1772
                State:        HtlcState(row.State),
×
UNCOV
1773
        }
×
UNCOV
1774

×
UNCOV
1775
        if row.TotalMppMsat.Valid {
×
UNCOV
1776
                htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
×
UNCOV
1777
        }
×
1778

UNCOV
1779
        if row.ResolveTime.Valid {
×
UNCOV
1780
                htlc.ResolveTime = row.ResolveTime.Time.Local()
×
UNCOV
1781
        }
×
1782

UNCOV
1783
        return circuitKey, htlc, nil
×
1784
}
1785

1786
// queryWithLimit is a helper method that can be used to query the database
1787
// using a limit and offset. The passed query function should return the number
1788
// of rows returned and an error if any.
UNCOV
1789
func queryWithLimit(query func(int) (int, error), limit int) error {
×
UNCOV
1790
        offset := 0
×
UNCOV
1791
        for {
×
UNCOV
1792
                rows, err := query(offset)
×
UNCOV
1793
                if err != nil {
×
1794
                        return err
×
1795
                }
×
1796

UNCOV
1797
                if rows < limit {
×
UNCOV
1798
                        return nil
×
UNCOV
1799
                }
×
1800

UNCOV
1801
                offset += limit
×
1802
        }
1803
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc