
lightningnetwork / lnd / build 15298164175 (push via github, committer: web-flow)

28 May 2025 10:45AM UTC coverage: 58.327% (-0.04%) from 58.362%

Merge pull request #9873 from ellemouton/sqldbHelpers

sqldb: re-usable TxOptions and NoOpReset

4 of 38 new or added lines in 4 files covered (10.53%).

96 existing lines in 16 files now uncovered.

97406 of 167000 relevant lines covered (58.33%).

1.81 hits per line.

Source File

/graph/db/sql_store.go (0.0% of lines covered)
package graphdb

import (
        "bytes"
        "context"
        "database/sql"
        "encoding/hex"
        "errors"
        "fmt"
        "math"
        "net"
        "strconv"
        "sync"
        "time"

        "github.com/btcsuite/btcd/btcec/v2"
        "github.com/lightningnetwork/lnd/batch"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
        "github.com/lightningnetwork/lnd/tlv"
        "github.com/lightningnetwork/lnd/tor"
)

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
        ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
        return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
        /*
                Node queries.
        */
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)

        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error

        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

        /*
                Source node queries.
        */
        AddSourceNode(ctx context.Context, nodeID int64) error
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
        SQLQueries
        sqldb.BatchedTx[SQLQueries]
}

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
//
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
// implement the V1Store interface incrementally. For any method not
// implemented, things will fall back to the KVStore. This is ONLY the case
// for the time being while this struct is used in unit tests only.
type SQLStore struct {
        db BatchedSQLQueries

        // cacheMu guards all caches (rejectCache and chanCache). If
        // this mutex is to be acquired at the same time as the DB mutex, then
        // cacheMu MUST be acquired first to prevent deadlock.
        cacheMu     sync.RWMutex
        rejectCache *rejectCache
        chanCache   *channelCache

        chanScheduler batch.Scheduler[SQLQueries]
        nodeScheduler batch.Scheduler[SQLQueries]

        // Temporary fall-back to the KVStore so that we can implement the
        // interface incrementally.
        *KVStore
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(db BatchedSQLQueries, kvStore *KVStore,
        options ...StoreOptionModifier) (*SQLStore, error) {

        opts := DefaultOptions()
        for _, o := range options {
                o(opts)
        }

        if opts.NoMigration {
                return nil, fmt.Errorf("the NoMigration option is not yet " +
                        "supported for SQL stores")
        }

        s := &SQLStore{
                db:          db,
                KVStore:     kvStore,
                rejectCache: newRejectCache(opts.RejectCacheSize),
                chanCache:   newChannelCache(opts.ChannelCacheSize),
        }

        s.chanScheduler = batch.NewTimeScheduler(
                db, &s.cacheMu, opts.BatchCommitInterval,
        )
        s.nodeScheduler = batch.NewTimeScheduler(
                db, nil, opts.BatchCommitInterval,
        )

        return s, nil
}

// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(node *models.LightningNode,
        opts ...batch.SchedulerOption) error {

        ctx := context.TODO()

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Do: func(queries SQLQueries) error {
                        _, err := upsertNode(ctx, queries, node)
                        return err
                },
        }

        return s.nodeScheduler.Execute(ctx, r)
}

// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(pubKey route.Vertex) (
        *models.LightningNode, error) {

        ctx := context.TODO()

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                _, node, err = getNodeByPubKey(ctx, db, pubKey)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        return node, nil
}

// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(pubKey [33]byte) (time.Time, bool,
        error) {

        ctx := context.TODO()

        var (
                exists     bool
                lastUpdate time.Time
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                exists = true

                if dbNode.LastUpdate.Valid {
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, false,
                        fmt.Errorf("unable to fetch node: %w", err)
        }

        return lastUpdate, exists, nil
}

// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr,
        error) {

        ctx := context.TODO()

        var (
                addresses []net.Addr
                known     bool
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                known, addresses, err = getNodeAddresses(
                        ctx, db, nodePub.SerializeCompressed(),
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch node addresses: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, nil, fmt.Errorf("unable to get addresses for "+
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
        }

        return known, addresses, nil
}

// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(pubKey route.Vertex) error {
        ctx := context.TODO()

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteNodeByPubKey(
                        ctx, sqlc.DeleteNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if err != nil {
                        return err
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrGraphNodeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
                }

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to delete node: %w", err)
        }

        return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
        var (
                ctx   = context.TODO()
                alias string
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrNodeAliasNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                if !dbNode.Alias.Valid {
                        return ErrNodeAliasNotFound
                }

                alias = dbNode.Alias.String

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return "", fmt.Errorf("unable to look up alias: %w", err)
        }

        return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
        ctx := context.TODO()

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                _, nodePub, err := getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch V1 source node: %w",
                                err)
                }

                _, node, err = getNodeByPubKey(ctx, db, nodePub)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
        }

        return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
        ctx := context.TODO()

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                id, err := upsertNode(ctx, db, node)
                if err != nil {
                        return fmt.Errorf("unable to upsert source node: %w",
                                err)
                }

                // Make sure that if a source node for this version is already
                // set, then the ID is the same as the one we are about to set.
                dbSourceNodeID, _, err := getSourceNode(ctx, db, ProtocolV1)
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                } else if err == nil {
                        if dbSourceNodeID != id {
                                return fmt.Errorf("v1 source node already "+
                                        "set to a different node: %d vs %d",
                                        dbSourceNodeID, id)
                        }

                        return nil
                }

                return db.AddSourceNode(ctx, id)
        }, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up-to-date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
        endTime time.Time) ([]models.LightningNode, error) {

        ctx := context.TODO()

        var nodes []models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNodes, err := db.GetNodesByLastUpdateRange(
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch nodes: %w", err)
                }

                for _, dbNode := range dbNodes {
                        node, err := buildNode(ctx, db, &dbNode)
                        if err != nil {
                                return fmt.Errorf("unable to build node: %w",
                                        err)
                        }

                        nodes = append(nodes, *node)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nodes, nil
}

// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, db SQLQueries,
        pubKey route.Vertex) (int64, *models.LightningNode, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        Version: int16(ProtocolV1),
                        PubKey:  pubKey[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return 0, nil, ErrGraphNodeNotFound
        } else if err != nil {
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        node, err := buildNode(ctx, db, &dbNode)
        if err != nil {
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
        }

        return dbNode.ID, node, nil
}

// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
        *models.LightningNode, error) {

        if dbNode.Version != int16(ProtocolV1) {
                return nil, fmt.Errorf("unsupported node version: %d",
                        dbNode.Version)
        }

        var pub [33]byte
        copy(pub[:], dbNode.PubKey)

        node := &models.LightningNode{
                PubKeyBytes:     pub,
                Features:        lnwire.EmptyFeatureVector(),
                LastUpdate:      time.Unix(0, 0),
                ExtraOpaqueData: make([]byte, 0),
        }

        if len(dbNode.Signature) == 0 {
                return node, nil
        }

        node.HaveNodeAnnouncement = true
        node.AuthSigBytes = dbNode.Signature
        node.Alias = dbNode.Alias.String
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)

        var err error
        node.Color, err = DecodeHexColor(dbNode.Color.String)
        if err != nil {
                return nil, fmt.Errorf("unable to decode color: %w", err)
        }

        // Fetch the node's features.
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node(%d) "+
                        "features: %w", dbNode.ID, err)
        }

        // Fetch the node's addresses.
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node(%d) "+
                        "addresses: %w", dbNode.ID, err)
        }

        // Fetch the node's extra signed fields.
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node(%d) "+
                        "extra signed fields: %w", dbNode.ID, err)
        }

        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }

        if len(recs) != 0 {
                node.ExtraOpaqueData = recs
        }

        return node, nil
}

// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
        nodeID int64) (*lnwire.FeatureVector, error) {

        rows, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
                        nodeID, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, feature := range rows {
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
        }

        return features, nil
}

// getNodeExtraSignedFields fetches the extra signed fields for a node with the
// given DB ID.
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
        nodeID int64) (map[uint64][]byte, error) {

        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%d) extra "+
                        "signed fields: %w", nodeID, err)
        }

        extraFields := make(map[uint64][]byte)
        for _, field := range fields {
                extraFields[uint64(field.Type)] = field.Value
        }

        return extraFields, nil
}

// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
        node *models.LightningNode) (int64, error) {

        params := sqlc.UpsertNodeParams{
                Version: int16(ProtocolV1),
                PubKey:  node.PubKeyBytes[:],
        }

        if node.HaveNodeAnnouncement {
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
                params.Alias = sqldb.SQLStr(node.Alias)
                params.Signature = node.AuthSigBytes
        }

        nodeID, err := db.UpsertNode(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveNodeAnnouncement {
                return nodeID, nil
        }

        // Update the node's features.
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
        if err != nil {
                return 0, fmt.Errorf("inserting node features: %w", err)
        }

        // Update the node's addresses.
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
        if err != nil {
                return 0, fmt.Errorf("inserting node addresses: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
        if err != nil {
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Update the node's extra signed fields.
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
        if err != nil {
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
        }

        return nodeID, nil
}

// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
        features *lnwire.FeatureVector) error {

        // Get any existing features for the node.
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                return err
        }

        // Copy the node's latest set of feature bits.
        newFeatures := make(map[int32]struct{})
        if features != nil {
                for feature := range features.Features() {
                        newFeatures[int32(feature)] = struct{}{}
                }
        }

        // For any current feature that already exists in the DB, remove it from
        // the in-memory map. For any existing feature that does not exist in
        // the in-memory map, delete it from the database.
        for _, feature := range existingFeatures {
                // The feature is still present, so there are no updates to be
                // made.
                if _, ok := newFeatures[feature.FeatureBit]; ok {
                        delete(newFeatures, feature.FeatureBit)
                        continue
                }

                // The feature is no longer present, so we remove it from the
                // database.
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature.FeatureBit,
                })
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) "+
                                "feature(%v): %w", nodeID, feature.FeatureBit,
                                err)
                }
        }

        // Any remaining entries in newFeatures are new features that need to be
        // added to the database for the first time.
        for feature := range newFeatures {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature,
                })
                if err != nil {
                        return fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        return nil
}

// fetchNodeFeatures fetches the features for a node with the given public key.
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {

        rows, err := queries.GetNodeFeaturesByPubKey(
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
                        PubKey:  nodePub[:],
                        Version: int16(ProtocolV1),
                },
        )
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
                        nodePub, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, bit := range rows {
                features.Set(lnwire.FeatureBit(bit))
        }

        return features, nil
}

// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

const (
        addressTypeIPv4   dbAddressType = 1
        addressTypeIPv6   dbAddressType = 2
        addressTypeTorV2  dbAddressType = 3
        addressTypeTorV3  dbAddressType = 4
        addressTypeOpaque dbAddressType = math.MaxInt8
)

// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
        addresses []net.Addr) error {

        // Delete any existing addresses for the node. This is required since
        // even if the new set of addresses is the same, the ordering may have
        // changed for a given address type.
        err := db.DeleteNodeAddresses(ctx, nodeID)
        if err != nil {
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
                        nodeID, err)
        }

        // Copy the node's latest set of addresses.
        newAddresses := map[dbAddressType][]string{
                addressTypeIPv4:   {},
                addressTypeIPv6:   {},
                addressTypeTorV2:  {},
                addressTypeTorV3:  {},
                addressTypeOpaque: {},
        }
        addAddr := func(t dbAddressType, addr net.Addr) {
                newAddresses[t] = append(newAddresses[t], addr.String())
        }

        for _, address := range addresses {
                switch addr := address.(type) {
                case *net.TCPAddr:
                        if ip4 := addr.IP.To4(); ip4 != nil {
                                addAddr(addressTypeIPv4, addr)
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
                                addAddr(addressTypeIPv6, addr)
                        } else {
                                return fmt.Errorf("unhandled IP address: %v",
                                        addr)
                        }

                case *tor.OnionAddr:
                        switch len(addr.OnionService) {
                        case tor.V2Len:
                                addAddr(addressTypeTorV2, addr)
                        case tor.V3Len:
                                addAddr(addressTypeTorV3, addr)
                        default:
                                return fmt.Errorf("invalid length for a tor " +
                                        "address")
                        }

                case *lnwire.OpaqueAddrs:
                        addAddr(addressTypeOpaque, addr)

                default:
                        return fmt.Errorf("unhandled address type: %T", addr)
                }
        }

        // Any remaining entries in newAddresses are new addresses that need to
        // be added to the database for the first time.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.InsertNodeAddress(
                                ctx, sqlc.InsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        return nil
}

// getNodeAddresses fetches the addresses for a node with the given public key.
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
        []net.Addr, error) {

        // GetNodeAddressesByPubKey ensures that the addresses for a given type
        // are returned in the same order as they were inserted.
        rows, err := db.GetNodeAddressesByPubKey(
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
                        Version: int16(ProtocolV1),
                        PubKey:  nodePub,
                },
        )
        if err != nil {
                return false, nil, err
        }

        // GetNodeAddressesByPubKey uses a left join so there should always be
        // at least one row returned if the node exists even if it has no
        // addresses.
        if len(rows) == 0 {
                return false, nil, nil
        }

        addresses := make([]net.Addr, 0, len(rows))
        for _, addr := range rows {
                if !(addr.Type.Valid && addr.Address.Valid) {
                        continue
                }

                address := addr.Address.String

                switch dbAddressType(addr.Type.Int16) {
                case addressTypeIPv4:
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
                        if err != nil {
                                return false, nil, nil
                        }
                        tcp.IP = tcp.IP.To4()

                        addresses = append(addresses, tcp)

                case addressTypeIPv6:
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
                        if err != nil {
                                return false, nil, nil
                        }
                        addresses = append(addresses, tcp)

                case addressTypeTorV3, addressTypeTorV2:
                        service, portStr, err := net.SplitHostPort(address)
                        if err != nil {
                                return false, nil, fmt.Errorf("unable to "+
                                        "split tor v3 address: %v",
                                        addr.Address)
                        }

                        port, err := strconv.Atoi(portStr)
                        if err != nil {
                                return false, nil, err
                        }

                        addresses = append(addresses, &tor.OnionAddr{
                                OnionService: service,
                                Port:         port,
                        })

                case addressTypeOpaque:
                        opaque, err := hex.DecodeString(address)
                        if err != nil {
                                return false, nil, fmt.Errorf("unable to "+
                                        "decode opaque address: %v", addr)
                        }

                        addresses = append(addresses, &lnwire.OpaqueAddrs{
                                Payload: opaque,
                        })

                default:
                        return false, nil, fmt.Errorf("unknown address "+
                                "type: %v", addr.Type)
                }
        }

        return true, addresses, nil
}

// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
        nodeID int64, extraFields map[uint64][]byte) error {

        // Get any existing extra signed fields for the node.
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
        if err != nil {
                return err
        }

        // Make a lookup map of the existing field types so that we can use it
        // to keep track of any fields we should delete.
        m := make(map[uint64]bool)
        for _, field := range existingFields {
                m[uint64(field.Type)] = true
        }

        // For all the new fields, we'll upsert them and remove them from the
        // map of existing fields.
        for tlvType, value := range extraFields {
                err = db.UpsertNodeExtraType(
                        ctx, sqlc.UpsertNodeExtraTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                                Value:  value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }

                // Remove the field from the map of existing fields if it was
                // present.
                delete(m, tlvType)
        }

        // For all the fields that are left in the map of existing fields, we'll
        // delete them as they are no longer present in the new set of fields.
        for tlvType := range m {
                err = db.DeleteExtraNodeType(
                        ctx, sqlc.DeleteExtraNodeTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }
        }

        return nil
}

// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func getSourceNode(ctx context.Context, db SQLQueries,
        version ProtocolVersion) (int64, route.Vertex, error) {

        var pubKey route.Vertex

        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
        if err != nil {
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
                        err)
        }

        if len(nodes) == 0 {
                return 0, pubKey, ErrSourceNodeNotSet
        } else if len(nodes) > 1 {
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
                        "protocol %s found", version)
        }

        copy(pubKey[:], nodes[0].PubKey)

        return nodes[0].NodeID, pubKey, nil
}

// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV stream.
// This then produces a map from TLV type to value. If the input is not a
// valid TLV stream, then an error is returned.
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
        r := bytes.NewReader(data)

        tlvStream, err := tlv.NewStream()
        if err != nil {
                return nil, err
        }

        // Since ExtraOpaqueData is provided by a potentially malicious peer,
        // pass it into the P2P decoding variant.
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
        if err != nil {
                return nil, err
        }
        if len(parsedTypes) == 0 {
                return nil, nil
        }

        records := make(map[uint64][]byte)
        for k, v := range parsedTypes {
                records[uint64(k)] = v
        }

        return records, nil
}
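
The listing above leans on the helpers that pull request #9873 factors out of the sqldb package: the shared transaction options returned by sqldb.ReadTxOpt and sqldb.WriteTxOpt, and the sqldb.NoOpReset callback passed to ExecTx when the transaction closure has no captured state that would need resetting before a retry. The sketch below shows that read-only pattern in isolation. It is illustrative only and not part of sql_store.go: countV1SourceNodes is a hypothetical helper, and it assumes the same package, imports, and the BatchedSQLQueries/SQLQueries interfaces defined above.

// countV1SourceNodes is a hypothetical example of the ExecTx pattern used
// throughout SQLStore: open a read-only transaction with the re-usable
// sqldb.ReadTxOpt options, run queries against the transaction-scoped
// SQLQueries instance, and pass sqldb.NoOpReset because this closure keeps no
// intermediate state that would need to be reset before a retry.
func countV1SourceNodes(ctx context.Context, db BatchedSQLQueries) (int,
        error) {

        var count int
        err := db.ExecTx(ctx, sqldb.ReadTxOpt(), func(q SQLQueries) error {
                nodes, err := q.GetSourceNodesByVersion(ctx, int16(ProtocolV1))
                if err != nil {
                        return fmt.Errorf("unable to fetch source nodes: %w",
                                err)
                }

                count = len(nodes)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to count source nodes: %w", err)
        }

        return count, nil
}

The write path is identical except that sqldb.WriteTxOpt() is passed instead, as DeleteLightningNode and SetSourceNode do above.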