• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15732818454

18 Jun 2025 12:28PM UTC coverage: 68.406%. First build
15732818454

Pull #9959

github

web-flow
Merge bdff25e38 into a27bd69b9
Pull Request #9959: multi: add context.Context param to more graphdb.V1Store methods

49 of 72 new or added lines in 13 files covered. (68.06%)

134458 of 196559 relevant lines covered (68.41%)

22216.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "math"
11
        "net"
12
        "strconv"
13
        "sync"
14
        "time"
15

16
        "github.com/btcsuite/btcd/btcec/v2"
17
        "github.com/btcsuite/btcd/chaincfg/chainhash"
18
        "github.com/lightningnetwork/lnd/batch"
19
        "github.com/lightningnetwork/lnd/graph/db/models"
20
        "github.com/lightningnetwork/lnd/lnwire"
21
        "github.com/lightningnetwork/lnd/routing/route"
22
        "github.com/lightningnetwork/lnd/sqldb"
23
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
24
        "github.com/lightningnetwork/lnd/tlv"
25
        "github.com/lightningnetwork/lnd/tor"
26
)
27

28
// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String returns the human-readable form of the protocol version, for
// example "V1" for ProtocolV1.
func (v ProtocolVersion) String() string {
	// Cast to the underlying integer so %d formats the raw value.
	return fmt.Sprintf("V%d", uint8(v))
}
41

42
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
// NOTE: every method takes a context.Context so that individual queries can
// be cancelled and can run inside a surrounding database transaction.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
	/*
		Node queries.
	*/
	UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
	GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
	GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
	DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)

	// Extra (TLV) fields attached to a node announcement.
	GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
	UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
	DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

	// Node network addresses.
	InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
	GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
	DeleteNodeAddresses(ctx context.Context, nodeID int64) error

	// Node feature bits.
	InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
	GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
	GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
	DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

	/*
		Source node queries.
	*/
	AddSourceNode(ctx context.Context, nodeID int64) error
	GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

	/*
		Channel queries.
	*/
	CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
	GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
	GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
	HighestSCID(ctx context.Context, version int16) ([]byte, error)

	CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
	InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error

	/*
		Channel Policy table queries.
	*/
	UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)

	InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
	DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
}
93

94
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations. It combines the plain query set with sqldb's
// transactional executor so callers can run multiple queries atomically.
type BatchedSQLQueries interface {
	SQLQueries
	sqldb.BatchedTx[SQLQueries]
}
100

101
// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
//
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
// implement the V1Store interface incrementally. For any method not
// implemented,  things will fall back to the KVStore. This is ONLY the case
// for the time being while this struct is purely used in unit tests only.
type SQLStore struct {
	// cfg holds static store configuration such as the chain hash.
	cfg *SQLStoreConfig

	// db is the transactional query backend all reads/writes go through.
	db  BatchedSQLQueries

	// cacheMu guards all caches (rejectCache and chanCache). If
	// this mutex will be acquired at the same time as the DB mutex then
	// the cacheMu MUST be acquired first to prevent deadlock.
	cacheMu     sync.RWMutex
	rejectCache *rejectCache
	chanCache   *channelCache

	// chanScheduler and nodeScheduler batch channel and node writes
	// respectively so multiple callers can share one DB transaction.
	chanScheduler batch.Scheduler[SQLQueries]
	nodeScheduler batch.Scheduler[SQLQueries]

	// Temporary fall-back to the KVStore so that we can implement the
	// interface incrementally.
	*KVStore
}
126

127
// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)
130

131
// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
	// ChainHash is the genesis hash for the chain that all the gossip
	// messages in this store are aimed at.
	ChainHash chainhash.Hash
}
137

138
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
139
// storage backend.
140
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, kvStore *KVStore,
141
        options ...StoreOptionModifier) (*SQLStore, error) {
×
142

×
143
        opts := DefaultOptions()
×
144
        for _, o := range options {
×
145
                o(opts)
×
146
        }
×
147

148
        if opts.NoMigration {
×
149
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
150
                        "supported for SQL stores")
×
151
        }
×
152

153
        s := &SQLStore{
×
154
                cfg:         cfg,
×
155
                db:          db,
×
156
                KVStore:     kvStore,
×
157
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
158
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
159
        }
×
160

×
161
        s.chanScheduler = batch.NewTimeScheduler(
×
162
                db, &s.cacheMu, opts.BatchCommitInterval,
×
163
        )
×
164
        s.nodeScheduler = batch.NewTimeScheduler(
×
165
                db, nil, opts.BatchCommitInterval,
×
166
        )
×
167

×
168
        return s, nil
×
169
}
170

171
// AddLightningNode adds a vertex/node to the graph database. If the node is not
172
// in the database from before, this will add a new, unconnected one to the
173
// graph. If it is present from before, this will update that node's
174
// information.
175
//
176
// NOTE: part of the V1Store interface.
177
func (s *SQLStore) AddLightningNode(ctx context.Context,
178
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
179

×
180
        r := &batch.Request[SQLQueries]{
×
181
                Opts: batch.NewSchedulerOptions(opts...),
×
182
                Do: func(queries SQLQueries) error {
×
183
                        _, err := upsertNode(ctx, queries, node)
×
184
                        return err
×
185
                },
×
186
        }
187

188
        return s.nodeScheduler.Execute(ctx, r)
×
189
}
190

191
// FetchLightningNode attempts to look up a target node by its identity public
192
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
193
// returned.
194
//
195
// NOTE: part of the V1Store interface.
196
func (s *SQLStore) FetchLightningNode(ctx context.Context,
197
        pubKey route.Vertex) (*models.LightningNode, error) {
×
198

×
199
        var node *models.LightningNode
×
200
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
201
                var err error
×
202
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
203

×
204
                return err
×
205
        }, sqldb.NoOpReset)
×
206
        if err != nil {
×
207
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
208
        }
×
209

210
        return node, nil
×
211
}
212

213
// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was lasted updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
	pubKey [33]byte) (time.Time, bool, error) {

	var (
		exists     bool
		lastUpdate time.Time
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(
			ctx, sqlc.GetNodeByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  pubKey[:],
			},
		)
		// A missing row is not an error for this method: it simply
		// means the node is unknown, so report exists=false.
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		} else if err != nil {
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		exists = true

		// LastUpdate is nullable in the DB; only convert it when a
		// value is actually set, otherwise the zero time is returned.
		if dbNode.LastUpdate.Valid {
			lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, false,
			fmt.Errorf("unable to fetch node: %w", err)
	}

	return lastUpdate, exists, nil
}
255

256
// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
	nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

	var (
		// addresses collects the parsed net.Addr values for the node.
		addresses []net.Addr
		// known reports whether the node exists in the DB at all.
		known     bool
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		var err error
		// Look up by the 33-byte compressed public key, which is how
		// nodes are keyed in the store.
		known, addresses, err = getNodeAddresses(
			ctx, db, nodePub.SerializeCompressed(),
		)
		if err != nil {
			return fmt.Errorf("unable to fetch node addresses: %w",
				err)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return false, nil, fmt.Errorf("unable to get addresses for "+
			"node(%x): %w", nodePub.SerializeCompressed(), err)
	}

	return known, addresses, nil
}
287

288
// DeleteLightningNode starts a new database transaction to remove a vertex/node
289
// from the database according to the node's public key.
290
//
291
// NOTE: part of the V1Store interface.
292
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
293
        pubKey route.Vertex) error {
×
294

×
295
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
296
                res, err := db.DeleteNodeByPubKey(
×
297
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
298
                                Version: int16(ProtocolV1),
×
299
                                PubKey:  pubKey[:],
×
300
                        },
×
301
                )
×
302
                if err != nil {
×
303
                        return err
×
304
                }
×
305

306
                rows, err := res.RowsAffected()
×
307
                if err != nil {
×
308
                        return err
×
309
                }
×
310

311
                if rows == 0 {
×
312
                        return ErrGraphNodeNotFound
×
313
                } else if rows > 1 {
×
314
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
315
                }
×
316

317
                return err
×
318
        }, sqldb.NoOpReset)
319
        if err != nil {
×
320
                return fmt.Errorf("unable to delete node: %w", err)
×
321
        }
×
322

323
        return nil
×
324
}
325

326
// FetchNodeFeatures returns the features of the given node. If no features are
327
// known for the node, an empty feature vector is returned.
328
//
329
// NOTE: this is part of the graphdb.NodeTraverser interface.
330
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
331
        *lnwire.FeatureVector, error) {
×
332

×
333
        ctx := context.TODO()
×
334

×
335
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
336
}
×
337

338
// LookupAlias attempts to return the alias as advertised by the target node.
339
//
340
// NOTE: part of the V1Store interface.
341
func (s *SQLStore) LookupAlias(ctx context.Context,
NEW
342
        pub *btcec.PublicKey) (string, error) {
×
NEW
343

×
NEW
344
        var alias string
×
345
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
346
                dbNode, err := db.GetNodeByPubKey(
×
347
                        ctx, sqlc.GetNodeByPubKeyParams{
×
348
                                Version: int16(ProtocolV1),
×
349
                                PubKey:  pub.SerializeCompressed(),
×
350
                        },
×
351
                )
×
352
                if errors.Is(err, sql.ErrNoRows) {
×
353
                        return ErrNodeAliasNotFound
×
354
                } else if err != nil {
×
355
                        return fmt.Errorf("unable to fetch node: %w", err)
×
356
                }
×
357

358
                if !dbNode.Alias.Valid {
×
359
                        return ErrNodeAliasNotFound
×
360
                }
×
361

362
                alias = dbNode.Alias.String
×
363

×
364
                return nil
×
365
        }, sqldb.NoOpReset)
366
        if err != nil {
×
367
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
368
        }
×
369

370
        return alias, nil
×
371
}
372

373
// SourceNode returns the source node of the graph. The source node is treated
374
// as the center node within a star-graph. This method may be used to kick off
375
// a path finding algorithm in order to explore the reachability of another
376
// node based off the source node.
377
//
378
// NOTE: part of the V1Store interface.
379
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
NEW
380
        error) {
×
381

×
382
        var node *models.LightningNode
×
383
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
384
                _, nodePub, err := getSourceNode(ctx, db, ProtocolV1)
×
385
                if err != nil {
×
386
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
387
                                err)
×
388
                }
×
389

390
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
391

×
392
                return err
×
393
        }, sqldb.NoOpReset)
394
        if err != nil {
×
395
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
396
        }
×
397

398
        return node, nil
×
399
}
400

401
// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(ctx context.Context,
	node *models.LightningNode) error {

	return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
		// Upsert the node itself first so we have its DB ID to link
		// the source-node record to.
		id, err := upsertNode(ctx, db, node)
		if err != nil {
			return fmt.Errorf("unable to upsert source node: %w",
				err)
		}

		// Make sure that if a source node for this version is already
		// set, then the ID is the same as the one we are about to set.
		dbSourceNodeID, _, err := getSourceNode(ctx, db, ProtocolV1)
		if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
			return fmt.Errorf("unable to fetch source node: %w",
				err)
		} else if err == nil {
			// A source node already exists; it must match the node
			// just upserted, otherwise the caller is attempting to
			// replace it, which is not allowed.
			if dbSourceNodeID != id {
				return fmt.Errorf("v1 source node already "+
					"set to a different node: %d vs %d",
					dbSourceNodeID, id)
			}

			return nil
		}

		// No source node was set yet (ErrSourceNodeNotSet), so record
		// this node as the source.
		return db.AddSourceNode(ctx, id)
	}, sqldb.NoOpReset)
}
435

436
// NodeUpdatesInHorizon returns all the known lightning node which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
	endTime time.Time) ([]models.LightningNode, error) {

	// This interface method does not yet take a context, so a placeholder
	// is used for the DB calls below.
	ctx := context.TODO()

	var nodes []models.LightningNode
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// Query by unix timestamps; both bounds are taken from the
		// caller-supplied time range.
		dbNodes, err := db.GetNodesByLastUpdateRange(
			ctx, sqlc.GetNodesByLastUpdateRangeParams{
				StartTime: sqldb.SQLInt64(startTime.Unix()),
				EndTime:   sqldb.SQLInt64(endTime.Unix()),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch nodes: %w", err)
		}

		// Hydrate each raw DB row into a full LightningNode (features,
		// addresses, extra fields) within the same transaction.
		for _, dbNode := range dbNodes {
			node, err := buildNode(ctx, db, &dbNode)
			if err != nil {
				return fmt.Errorf("unable to build node: %w",
					err)
			}

			nodes = append(nodes, *node)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch nodes: %w", err)
	}

	return nodes, nil
}
477

478
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes are created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
	edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

	// alreadyExists carries the ErrEdgeAlreadyExist condition from the Do
	// closure to OnCommit, since returning it from Do would fail the
	// whole batch.
	var alreadyExists bool
	r := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		// Reset runs before each (re)attempt of the batch, so all
		// request-local state must be cleared here.
		Reset: func() {
			alreadyExists = false
		},
		Do: func(tx SQLQueries) error {
			err := insertChannel(ctx, tx, edge)

			// Silence ErrEdgeAlreadyExist so that the batch can
			// succeed, but propagate the error via local state.
			if errors.Is(err, ErrEdgeAlreadyExist) {
				alreadyExists = true
				return nil
			}

			return err
		},
		OnCommit: func(err error) error {
			switch {
			case err != nil:
				return err
			case alreadyExists:
				return ErrEdgeAlreadyExist
			default:
				// The channel was written: drop any stale
				// cache entries for it so subsequent reads
				// reload fresh data from the DB.
				s.rejectCache.remove(edge.ChannelID)
				s.chanCache.remove(edge.ChannelID)
				return nil
			}
		},
	}

	return s.chanScheduler.Execute(ctx, r)
}
523

524
// HighestChanID returns the "highest" known channel ID in the channel graph.
525
// This represents the "newest" channel from the PoV of the chain. This method
526
// can be used by peers to quickly determine if their graphs are in sync.
527
//
528
// NOTE: This is part of the V1Store interface.
NEW
529
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
530
        var highestChanID uint64
×
531
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
532
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
533
                if errors.Is(err, sql.ErrNoRows) {
×
534
                        return nil
×
535
                } else if err != nil {
×
536
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
537
                                err)
×
538
                }
×
539

540
                highestChanID = byteOrder.Uint64(chanID)
×
541

×
542
                return nil
×
543
        }, sqldb.NoOpReset)
544
        if err != nil {
×
545
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
546
        }
×
547

548
        return highestChanID, nil
×
549
}
550

551
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
552
// within the database for the referenced channel. The `flags` attribute within
553
// the ChannelEdgePolicy determines which of the directed edges are being
554
// updated. If the flag is 1, then the first node's information is being
555
// updated, otherwise it's the second node's information. The node ordering is
556
// determined by the lexicographical ordering of the identity public keys of the
557
// nodes on either side of the channel.
558
//
559
// NOTE: part of the V1Store interface.
560
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
561
        edge *models.ChannelEdgePolicy,
562
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
563

×
564
        var (
×
565
                isUpdate1    bool
×
566
                edgeNotFound bool
×
567
                from, to     route.Vertex
×
568
        )
×
569

×
570
        r := &batch.Request[SQLQueries]{
×
571
                Opts: batch.NewSchedulerOptions(opts...),
×
572
                Reset: func() {
×
573
                        isUpdate1 = false
×
574
                        edgeNotFound = false
×
575
                },
×
576
                Do: func(tx SQLQueries) error {
×
577
                        var err error
×
578
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
579
                                ctx, tx, edge,
×
580
                        )
×
581
                        if err != nil {
×
582
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
583
                        }
×
584

585
                        // Silence ErrEdgeNotFound so that the batch can
586
                        // succeed, but propagate the error via local state.
587
                        if errors.Is(err, ErrEdgeNotFound) {
×
588
                                edgeNotFound = true
×
589
                                return nil
×
590
                        }
×
591

592
                        return err
×
593
                },
594
                OnCommit: func(err error) error {
×
595
                        switch {
×
596
                        case err != nil:
×
597
                                return err
×
598
                        case edgeNotFound:
×
599
                                return ErrEdgeNotFound
×
600
                        default:
×
601
                                s.updateEdgeCache(edge, isUpdate1)
×
602
                                return nil
×
603
                        }
604
                },
605
        }
606

607
        err := s.chanScheduler.Execute(ctx, r)
×
608

×
609
        return from, to, err
×
610
}
611

612
// updateEdgeCache updates our reject and channel caches with the new
613
// edge policy information.
614
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
615
        isUpdate1 bool) {
×
616

×
617
        // If an entry for this channel is found in reject cache, we'll modify
×
618
        // the entry with the updated timestamp for the direction that was just
×
619
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
620
        // during the next query for this edge.
×
621
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
622
                if isUpdate1 {
×
623
                        entry.upd1Time = e.LastUpdate.Unix()
×
624
                } else {
×
625
                        entry.upd2Time = e.LastUpdate.Unix()
×
626
                }
×
627
                s.rejectCache.insert(e.ChannelID, entry)
×
628
        }
629

630
        // If an entry for this channel is found in channel cache, we'll modify
631
        // the entry with the updated policy for the direction that was just
632
        // written. If the edge doesn't exist, we'll defer loading the info and
633
        // policies and lazily read from disk during the next query.
634
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
635
                if isUpdate1 {
×
636
                        channel.Policy1 = e
×
637
                } else {
×
638
                        channel.Policy2 = e
×
639
                }
×
640
                s.chanCache.insert(e.ChannelID, channel)
×
641
        }
642
}
643

644
// updateChanEdgePolicy upserts the channel policy info we have stored for
// a channel we already know of.
//
// It returns the public keys of the channel's two nodes, a boolean that is
// true if the policy is for the first node (direction bit unset), and an
// error. If the channel is unknown, ErrEdgeNotFound is returned.
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
	edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
	error) {

	var (
		node1Pub, node2Pub route.Vertex
		isNode1            bool
		// chanIDB is the channel ID serialized as big-endian bytes,
		// matching how SCIDs are stored in the channel table.
		chanIDB            [8]byte
	)
	byteOrder.PutUint64(chanIDB[:], edge.ChannelID)

	// Check that this edge policy refers to a channel that we already
	// know of. We do this explicitly so that we can return the appropriate
	// ErrEdgeNotFound error if the channel doesn't exist, rather than
	// abort the transaction which would abort the entire batch.
	dbChan, err := tx.GetChannelAndNodesBySCID(
		ctx, sqlc.GetChannelAndNodesBySCIDParams{
			Scid:    chanIDB[:],
			Version: int16(ProtocolV1),
		},
	)
	if errors.Is(err, sql.ErrNoRows) {
		return node1Pub, node2Pub, false, ErrEdgeNotFound
	} else if err != nil {
		return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
			"fetch channel(%v): %w", edge.ChannelID, err)
	}

	copy(node1Pub[:], dbChan.Node1PubKey)
	copy(node2Pub[:], dbChan.Node2PubKey)

	// Figure out which node this edge is from: the direction bit of the
	// channel flags is unset for node 1 and set for node 2.
	isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
	nodeID := dbChan.NodeID1
	if !isNode1 {
		nodeID = dbChan.NodeID2
	}

	// The inbound fee fields are optional; they stay NULL in the DB
	// unless the policy actually carries an inbound fee.
	var (
		inboundBase sql.NullInt64
		inboundRate sql.NullInt64
	)
	edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
		inboundRate = sqldb.SQLInt64(fee.FeeRate)
		inboundBase = sqldb.SQLInt64(fee.BaseFee)
	})

	id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
		Version:     int16(ProtocolV1),
		ChannelID:   dbChan.ID,
		NodeID:      nodeID,
		Timelock:    int32(edge.TimeLockDelta),
		FeePpm:      int64(edge.FeeProportionalMillionths),
		BaseFeeMsat: int64(edge.FeeBaseMSat),
		MinHtlcMsat: int64(edge.MinHTLC),
		LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
		Disabled: sql.NullBool{
			Valid: true,
			Bool:  edge.IsDisabled(),
		},
		// MaxHtlcMsat is only valid when the message flags advertise
		// a max-HTLC value.
		MaxHtlcMsat: sql.NullInt64{
			Valid: edge.MessageFlags.HasMaxHtlc(),
			Int64: int64(edge.MaxHTLC),
		},
		InboundBaseFeeMsat:      inboundBase,
		InboundFeeRateMilliMsat: inboundRate,
		Signature:               edge.SigBytes,
	})
	if err != nil {
		return node1Pub, node2Pub, isNode1,
			fmt.Errorf("unable to upsert edge policy: %w", err)
	}

	// Convert the flat extra opaque data into a map of TLV types to
	// values.
	extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
	if err != nil {
		return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
			"marshal extra opaque data: %w", err)
	}

	// Update the channel policy's extra signed fields.
	err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
	if err != nil {
		return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
			"policy extra TLVs: %w", err)
	}

	return node1Pub, node2Pub, isNode1, nil
}
736

737
// getNodeByPubKey attempts to look up a target node by its public key.
738
func getNodeByPubKey(ctx context.Context, db SQLQueries,
739
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
740

×
741
        dbNode, err := db.GetNodeByPubKey(
×
742
                ctx, sqlc.GetNodeByPubKeyParams{
×
743
                        Version: int16(ProtocolV1),
×
744
                        PubKey:  pubKey[:],
×
745
                },
×
746
        )
×
747
        if errors.Is(err, sql.ErrNoRows) {
×
748
                return 0, nil, ErrGraphNodeNotFound
×
749
        } else if err != nil {
×
750
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
751
        }
×
752

753
        node, err := buildNode(ctx, db, &dbNode)
×
754
        if err != nil {
×
755
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
756
        }
×
757

758
        return dbNode.ID, node, nil
×
759
}
760

761
// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
//
// Only V1 records are supported; any other version yields an error. If the
// stored record carries no signature, only a shell node (pub key plus empty
// defaults) is returned without touching any of the auxiliary tables.
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
	*models.LightningNode, error) {

	if dbNode.Version != int16(ProtocolV1) {
		return nil, fmt.Errorf("unsupported node version: %d",
			dbNode.Version)
	}

	var pub [33]byte
	copy(pub[:], dbNode.PubKey)

	// Start with sane defaults: an empty feature vector and the zero Unix
	// time as the last-update timestamp.
	node := &models.LightningNode{
		PubKeyBytes: pub,
		Features:    lnwire.EmptyFeatureVector(),
		LastUpdate:  time.Unix(0, 0),
	}

	// A missing signature means we only ever stored a shell record for
	// this node (no announcement seen yet), so return the bare node
	// without any announcement-derived fields.
	if len(dbNode.Signature) == 0 {
		return node, nil
	}

	node.HaveNodeAnnouncement = true
	node.AuthSigBytes = dbNode.Signature
	node.Alias = dbNode.Alias.String
	node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)

	var err error
	node.Color, err = DecodeHexColor(dbNode.Color.String)
	if err != nil {
		return nil, fmt.Errorf("unable to decode color: %w", err)
	}

	// Fetch the node's features.
	node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node(%d) "+
			"features: %w", dbNode.ID, err)
	}

	// Fetch the node's addresses.
	_, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node(%d) "+
			"addresses: %w", dbNode.ID, err)
	}

	// Fetch the node's extra signed fields.
	extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch node(%d) "+
			"extra signed fields: %w", dbNode.ID, err)
	}

	// Re-serialise the extra TLV fields back into the flat wire-encoding
	// form that LightningNode carries.
	recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
	if err != nil {
		return nil, fmt.Errorf("unable to serialize extra signed "+
			"fields: %w", err)
	}

	if len(recs) != 0 {
		node.ExtraOpaqueData = recs
	}

	return node, nil
}
829

830
// getNodeFeatures fetches the feature bits and constructs the feature vector
831
// for a node with the given DB ID.
832
func getNodeFeatures(ctx context.Context, db SQLQueries,
833
        nodeID int64) (*lnwire.FeatureVector, error) {
×
834

×
835
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
836
        if err != nil {
×
837
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
838
                        nodeID, err)
×
839
        }
×
840

841
        features := lnwire.EmptyFeatureVector()
×
842
        for _, feature := range rows {
×
843
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
844
        }
×
845

846
        return features, nil
×
847
}
848

849
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
850
// given DB ID.
851
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
852
        nodeID int64) (map[uint64][]byte, error) {
×
853

×
854
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
855
        if err != nil {
×
856
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
857
                        "signed fields: %w", nodeID, err)
×
858
        }
×
859

860
        extraFields := make(map[uint64][]byte)
×
861
        for _, field := range fields {
×
862
                extraFields[uint64(field.Type)] = field.Value
×
863
        }
×
864

865
        return extraFields, nil
×
866
}
867

868
// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
//
// For a node without a known announcement, only a shell record (version and
// public key) is written and the auxiliary tables are left untouched.
func upsertNode(ctx context.Context, db SQLQueries,
	node *models.LightningNode) (int64, error) {

	params := sqlc.UpsertNodeParams{
		Version: int16(ProtocolV1),
		PubKey:  node.PubKeyBytes[:],
	}

	// The signed announcement fields are only meaningful once we have
	// actually seen the node announcement.
	if node.HaveNodeAnnouncement {
		params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
		params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
		params.Alias = sqldb.SQLStr(node.Alias)
		params.Signature = node.AuthSigBytes
	}

	nodeID, err := db.UpsertNode(ctx, params)
	if err != nil {
		return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
			err)
	}

	// We can exit here if we don't have the announcement yet.
	if !node.HaveNodeAnnouncement {
		return nodeID, nil
	}

	// Update the node's features.
	err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
	if err != nil {
		return 0, fmt.Errorf("inserting node features: %w", err)
	}

	// Update the node's addresses.
	err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
	if err != nil {
		return 0, fmt.Errorf("inserting node addresses: %w", err)
	}

	// Convert the flat extra opaque data into a map of TLV types to
	// values.
	extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
	if err != nil {
		return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
			err)
	}

	// Update the node's extra signed fields.
	err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
	if err != nil {
		return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
	}

	return nodeID, nil
}
926

927
// upsertNodeFeatures updates the node's features node_features table. This
// includes deleting any feature bits no longer present and inserting any new
// feature bits. If the feature bit does not yet exist in the features table,
// then an entry is created in that table first.
//
// NOTE: a nil feature vector is treated as an empty set, which results in all
// of the node's previously stored feature bits being deleted.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
	features *lnwire.FeatureVector) error {

	// Get any existing features for the node.
	existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return err
	}

	// Copy the nodes latest set of feature bits.
	newFeatures := make(map[int32]struct{})
	if features != nil {
		for feature := range features.Features() {
			newFeatures[int32(feature)] = struct{}{}
		}
	}

	// For any current feature that already exists in the DB, remove it from
	// the in-memory map. For any existing feature that does not exist in
	// the in-memory map, delete it from the database.
	for _, feature := range existingFeatures {
		// The feature is still present, so there are no updates to be
		// made.
		if _, ok := newFeatures[feature.FeatureBit]; ok {
			delete(newFeatures, feature.FeatureBit)
			continue
		}

		// The feature is no longer present, so we remove it from the
		// database.
		err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
			NodeID:     nodeID,
			FeatureBit: feature.FeatureBit,
		})
		if err != nil {
			return fmt.Errorf("unable to delete node(%d) "+
				"feature(%v): %w", nodeID, feature.FeatureBit,
				err)
		}
	}

	// Any remaining entries in newFeatures are new features that need to be
	// added to the database for the first time.
	for feature := range newFeatures {
		err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
			NodeID:     nodeID,
			FeatureBit: feature,
		})
		if err != nil {
			return fmt.Errorf("unable to insert node(%d) "+
				"feature(%v): %w", nodeID, feature, err)
		}
	}

	return nil
}
987

988
// fetchNodeFeatures fetches the features for a node with the given public key.
989
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
990
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
991

×
992
        rows, err := queries.GetNodeFeaturesByPubKey(
×
993
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
994
                        PubKey:  nodePub[:],
×
995
                        Version: int16(ProtocolV1),
×
996
                },
×
997
        )
×
998
        if err != nil {
×
999
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
1000
                        nodePub, err)
×
1001
        }
×
1002

1003
        features := lnwire.EmptyFeatureVector()
×
1004
        for _, bit := range rows {
×
1005
                features.Set(lnwire.FeatureBit(bit))
×
1006
        }
×
1007

1008
        return features, nil
×
1009
}
1010

1011
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialize.
type dbAddressType uint8

const (
	// addressTypeIPv4 is a TCP address with an IPv4 host.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 is a TCP address with an IPv6 host.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 is a v2 Tor onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 is a v3 Tor onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque is an address of an unrecognised type that is
	// stored verbatim (hex-decoded on read, see getNodeAddresses).
	addressTypeOpaque dbAddressType = math.MaxInt8
)
1023

1024
// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
	addresses []net.Addr) error {

	// Delete any existing addresses for the node. This is required since
	// even if the new set of addresses is the same, the ordering may have
	// changed for a given address type.
	err := db.DeleteNodeAddresses(ctx, nodeID)
	if err != nil {
		return fmt.Errorf("unable to delete node(%d) addresses: %w",
			nodeID, err)
	}

	// Copy the nodes latest set of addresses, bucketed by the DB address
	// type they will be stored under.
	newAddresses := map[dbAddressType][]string{
		addressTypeIPv4:   {},
		addressTypeIPv6:   {},
		addressTypeTorV2:  {},
		addressTypeTorV3:  {},
		addressTypeOpaque: {},
	}
	addAddr := func(t dbAddressType, addr net.Addr) {
		newAddresses[t] = append(newAddresses[t], addr.String())
	}

	for _, address := range addresses {
		switch addr := address.(type) {
		case *net.TCPAddr:
			if ip4 := addr.IP.To4(); ip4 != nil {
				addAddr(addressTypeIPv4, addr)
			} else if ip6 := addr.IP.To16(); ip6 != nil {
				addAddr(addressTypeIPv6, addr)
			} else {
				return fmt.Errorf("unhandled IP address: %v",
					addr)
			}

		case *tor.OnionAddr:
			// The service string length distinguishes v2 from v3
			// onion addresses.
			switch len(addr.OnionService) {
			case tor.V2Len:
				addAddr(addressTypeTorV2, addr)
			case tor.V3Len:
				addAddr(addressTypeTorV3, addr)
			default:
				return fmt.Errorf("invalid length for a tor " +
					"address")
			}

		case *lnwire.OpaqueAddrs:
			// NOTE(review): the stored string comes from
			// OpaqueAddrs.String(); getNodeAddresses later
			// hex-decodes it, so String() presumably hex-encodes
			// the payload — confirm against lnwire.
			addAddr(addressTypeOpaque, addr)

		default:
			return fmt.Errorf("unhandled address type: %T", addr)
		}
	}

	// Any remaining entries in newAddresses are new addresses that need to
	// be added to the database for the first time. The per-type slice
	// index is persisted as the position so insertion order survives a
	// round trip.
	for addrType, addrList := range newAddresses {
		for position, addr := range addrList {
			err := db.InsertNodeAddress(
				ctx, sqlc.InsertNodeAddressParams{
					NodeID:   nodeID,
					Type:     int16(addrType),
					Address:  addr,
					Position: int32(position),
				},
			)
			if err != nil {
				return fmt.Errorf("unable to insert "+
					"node(%d) address(%v): %w", nodeID,
					addr, err)
			}
		}
	}

	return nil
}
1107

1108
// getNodeAddresses fetches the addresses for a node with the given public key.
1109
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
1110
        []net.Addr, error) {
×
1111

×
1112
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
1113
        // are returned in the same order as they were inserted.
×
1114
        rows, err := db.GetNodeAddressesByPubKey(
×
1115
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
1116
                        Version: int16(ProtocolV1),
×
1117
                        PubKey:  nodePub,
×
1118
                },
×
1119
        )
×
1120
        if err != nil {
×
1121
                return false, nil, err
×
1122
        }
×
1123

1124
        // GetNodeAddressesByPubKey uses a left join so there should always be
1125
        // at least one row returned if the node exists even if it has no
1126
        // addresses.
1127
        if len(rows) == 0 {
×
1128
                return false, nil, nil
×
1129
        }
×
1130

1131
        addresses := make([]net.Addr, 0, len(rows))
×
1132
        for _, addr := range rows {
×
1133
                if !(addr.Type.Valid && addr.Address.Valid) {
×
1134
                        continue
×
1135
                }
1136

1137
                address := addr.Address.String
×
1138

×
1139
                switch dbAddressType(addr.Type.Int16) {
×
1140
                case addressTypeIPv4:
×
1141
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
1142
                        if err != nil {
×
1143
                                return false, nil, nil
×
1144
                        }
×
1145
                        tcp.IP = tcp.IP.To4()
×
1146

×
1147
                        addresses = append(addresses, tcp)
×
1148

1149
                case addressTypeIPv6:
×
1150
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
1151
                        if err != nil {
×
1152
                                return false, nil, nil
×
1153
                        }
×
1154
                        addresses = append(addresses, tcp)
×
1155

1156
                case addressTypeTorV3, addressTypeTorV2:
×
1157
                        service, portStr, err := net.SplitHostPort(address)
×
1158
                        if err != nil {
×
1159
                                return false, nil, fmt.Errorf("unable to "+
×
1160
                                        "split tor v3 address: %v",
×
1161
                                        addr.Address)
×
1162
                        }
×
1163

1164
                        port, err := strconv.Atoi(portStr)
×
1165
                        if err != nil {
×
1166
                                return false, nil, err
×
1167
                        }
×
1168

1169
                        addresses = append(addresses, &tor.OnionAddr{
×
1170
                                OnionService: service,
×
1171
                                Port:         port,
×
1172
                        })
×
1173

1174
                case addressTypeOpaque:
×
1175
                        opaque, err := hex.DecodeString(address)
×
1176
                        if err != nil {
×
1177
                                return false, nil, fmt.Errorf("unable to "+
×
1178
                                        "decode opaque address: %v", addr)
×
1179
                        }
×
1180

1181
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
1182
                                Payload: opaque,
×
1183
                        })
×
1184

1185
                default:
×
1186
                        return false, nil, fmt.Errorf("unknown address "+
×
1187
                                "type: %v", addr.Type)
×
1188
                }
1189
        }
1190

1191
        return true, addresses, nil
×
1192
}
1193

1194
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
	nodeID int64, extraFields map[uint64][]byte) error {

	// Get any existing extra signed fields for the node.
	existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
	if err != nil {
		return err
	}

	// Make a lookup map of the existing field types so that we can use it
	// to keep track of any fields we should delete.
	m := make(map[uint64]bool)
	for _, field := range existingFields {
		m[uint64(field.Type)] = true
	}

	// For all the new fields, we'll upsert them and remove them from the
	// map of existing fields.
	for tlvType, value := range extraFields {
		err = db.UpsertNodeExtraType(
			ctx, sqlc.UpsertNodeExtraTypeParams{
				NodeID: nodeID,
				Type:   int64(tlvType),
				Value:  value,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to upsert node(%d) extra "+
				"signed field(%v): %w", nodeID, tlvType, err)
		}

		// Remove the field from the map of existing fields if it was
		// present.
		delete(m, tlvType)
	}

	// For all the fields that are left in the map of existing fields, we'll
	// delete them as they are no longer present in the new set of fields.
	for tlvType := range m {
		err = db.DeleteExtraNodeType(
			ctx, sqlc.DeleteExtraNodeTypeParams{
				NodeID: nodeID,
				Type:   int64(tlvType),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to delete node(%d) extra "+
				"signed field(%v): %w", nodeID, tlvType, err)
		}
	}

	return nil
}
1250

1251
// getSourceNode returns the DB node ID and pub key of the source node for the
1252
// specified protocol version.
1253
func getSourceNode(ctx context.Context, db SQLQueries,
1254
        version ProtocolVersion) (int64, route.Vertex, error) {
×
1255

×
1256
        var pubKey route.Vertex
×
1257

×
1258
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
1259
        if err != nil {
×
1260
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
1261
                        err)
×
1262
        }
×
1263

1264
        if len(nodes) == 0 {
×
1265
                return 0, pubKey, ErrSourceNodeNotSet
×
1266
        } else if len(nodes) > 1 {
×
1267
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
1268
                        "protocol %s found", version)
×
1269
        }
×
1270

1271
        copy(pubKey[:], nodes[0].PubKey)
×
1272

×
1273
        return nodes[0].NodeID, pubKey, nil
×
1274
}
1275

1276
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
1277
// This then produces a map from TLV type to value. If the input is not a
1278
// valid TLV stream, then an error is returned.
1279
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
1280
        r := bytes.NewReader(data)
×
1281

×
1282
        tlvStream, err := tlv.NewStream()
×
1283
        if err != nil {
×
1284
                return nil, err
×
1285
        }
×
1286

1287
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
1288
        // pass it into the P2P decoding variant.
1289
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
1290
        if err != nil {
×
1291
                return nil, err
×
1292
        }
×
1293
        if len(parsedTypes) == 0 {
×
1294
                return nil, nil
×
1295
        }
×
1296

1297
        records := make(map[uint64][]byte)
×
1298
        for k, v := range parsedTypes {
×
1299
                records[uint64(k)] = v
×
1300
        }
×
1301

1302
        return records, nil
×
1303
}
1304

1305
// insertChannel inserts a new channel record into the database, along with
// shell entries for both endpoint nodes (if needed), the channel's feature
// bits, and any extra TLV fields from the announcement. Returns
// ErrEdgeAlreadyExist if a channel with the same SCID is already stored.
func insertChannel(ctx context.Context, db SQLQueries,
	edge *models.ChannelEdgeInfo) error {

	var chanIDB [8]byte
	byteOrder.PutUint64(chanIDB[:], edge.ChannelID)

	// Make sure that the channel doesn't already exist. We do this
	// explicitly instead of relying on catching a unique constraint error
	// because relying on SQL to throw that error would abort the entire
	// batch of transactions.
	_, err := db.GetChannelBySCID(
		ctx, sqlc.GetChannelBySCIDParams{
			Scid:    chanIDB[:],
			Version: int16(ProtocolV1),
		},
	)
	if err == nil {
		return ErrEdgeAlreadyExist
	} else if !errors.Is(err, sql.ErrNoRows) {
		return fmt.Errorf("unable to fetch channel: %w", err)
	}

	// Make sure that at least a "shell" entry for each node is present in
	// the nodes table.
	node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
	if err != nil {
		return fmt.Errorf("unable to create shell node: %w", err)
	}

	node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
	if err != nil {
		return fmt.Errorf("unable to create shell node: %w", err)
	}

	// A zero capacity is persisted as SQL NULL rather than the value 0.
	var capacity sql.NullInt64
	if edge.Capacity != 0 {
		capacity = sqldb.SQLInt64(int64(edge.Capacity))
	}

	createParams := sqlc.CreateChannelParams{
		Version:     int16(ProtocolV1),
		Scid:        chanIDB[:],
		NodeID1:     node1DBID,
		NodeID2:     node2DBID,
		Outpoint:    edge.ChannelPoint.String(),
		Capacity:    capacity,
		BitcoinKey1: edge.BitcoinKey1Bytes[:],
		BitcoinKey2: edge.BitcoinKey2Bytes[:],
	}

	// The announcement signatures are only stored if we hold an
	// authentication proof for the channel.
	if edge.AuthProof != nil {
		proof := edge.AuthProof

		createParams.Node1Signature = proof.NodeSig1Bytes
		createParams.Node2Signature = proof.NodeSig2Bytes
		createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
		createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
	}

	// Insert the new channel record.
	dbChanID, err := db.CreateChannel(ctx, createParams)
	if err != nil {
		return err
	}

	// Insert any channel features.
	if len(edge.Features) != 0 {
		chanFeatures := lnwire.NewRawFeatureVector()
		err := chanFeatures.Decode(bytes.NewReader(edge.Features))
		if err != nil {
			return err
		}

		fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
		for feature := range fv.Features() {
			err = db.InsertChannelFeature(
				ctx, sqlc.InsertChannelFeatureParams{
					ChannelID:  dbChanID,
					FeatureBit: int32(feature),
				},
			)
			if err != nil {
				return fmt.Errorf("unable to insert "+
					"channel(%d) feature(%v): %w", dbChanID,
					feature, err)
			}
		}
	}

	// Finally, insert any extra TLV fields in the channel announcement.
	extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
	if err != nil {
		return fmt.Errorf("unable to marshal extra opaque data: %w",
			err)
	}

	for tlvType, value := range extra {
		err := db.CreateChannelExtraType(
			ctx, sqlc.CreateChannelExtraTypeParams{
				ChannelID: dbChanID,
				Type:      int64(tlvType),
				Value:     value,
			},
		)
		if err != nil {
			return fmt.Errorf("unable to upsert channel(%d) extra "+
				"signed field(%v): %w", edge.ChannelID,
				tlvType, err)
		}
	}

	return nil
}
1419

1420
// maybeCreateShellNode checks if a shell node entry exists for the
1421
// given public key. If it does not exist, then a new shell node entry is
1422
// created. The ID of the node is returned. A shell node only has a protocol
1423
// version and public key persisted.
1424
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
1425
        pubKey route.Vertex) (int64, error) {
×
1426

×
1427
        dbNode, err := db.GetNodeByPubKey(
×
1428
                ctx, sqlc.GetNodeByPubKeyParams{
×
1429
                        PubKey:  pubKey[:],
×
1430
                        Version: int16(ProtocolV1),
×
1431
                },
×
1432
        )
×
1433
        // The node exists. Return the ID.
×
1434
        if err == nil {
×
1435
                return dbNode.ID, nil
×
1436
        } else if !errors.Is(err, sql.ErrNoRows) {
×
1437
                return 0, err
×
1438
        }
×
1439

1440
        // Otherwise, the node does not exist, so we create a shell entry for
1441
        // it.
1442
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
1443
                Version: int16(ProtocolV1),
×
1444
                PubKey:  pubKey[:],
×
1445
        })
×
1446
        if err != nil {
×
1447
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
1448
        }
×
1449

1450
        return id, nil
×
1451
}
1452

1453
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
1454
// the database. This includes deleting any existing types and then inserting
1455
// the new types.
1456
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
1457
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
1458

×
1459
        // Delete all existing extra signed fields for the channel policy.
×
1460
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
1461
        if err != nil {
×
1462
                return fmt.Errorf("unable to delete "+
×
1463
                        "existing policy extra signed fields for policy %d: %w",
×
1464
                        chanPolicyID, err)
×
1465
        }
×
1466

1467
        // Insert all new extra signed fields for the channel policy.
1468
        for tlvType, value := range extraFields {
×
1469
                err = db.InsertChanPolicyExtraType(
×
1470
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
1471
                                ChannelPolicyID: chanPolicyID,
×
1472
                                Type:            int64(tlvType),
×
1473
                                Value:           value,
×
1474
                        },
×
1475
                )
×
1476
                if err != nil {
×
1477
                        return fmt.Errorf("unable to insert "+
×
1478
                                "channel_policy(%d) extra signed field(%v): %w",
×
1479
                                chanPolicyID, tlvType, err)
×
1480
                }
×
1481
        }
1482

1483
        return nil
×
1484
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc