lightningnetwork / lnd / 15736305184

18 Jun 2025 02:55PM UTC coverage: 68.145% (-0.1%) from 68.248%

Pull Request #9935: [11] graph/db: Implement various "ForEach" methods on the graph SQLStore
Merge ece157b40 into 31c74f20f

4 of 317 new or added lines in 3 files covered. (1.26%)
71 existing lines in 19 files now uncovered.
134474 of 197335 relevant lines covered (68.15%)
22161.53 hits per line

Source File: /graph/db/sql_store.go (0.0%)
package graphdb

import (
        "bytes"
        "context"
        "database/sql"
        "encoding/hex"
        "errors"
        "fmt"
        "math"
        "net"
        "strconv"
        "sync"
        "time"

        "github.com/btcsuite/btcd/btcec/v2"
        "github.com/btcsuite/btcd/btcutil"
        "github.com/btcsuite/btcd/chaincfg/chainhash"
        "github.com/btcsuite/btcd/wire"
        "github.com/lightningnetwork/lnd/batch"
        "github.com/lightningnetwork/lnd/fn/v2"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
        "github.com/lightningnetwork/lnd/tlv"
        "github.com/lightningnetwork/lnd/tor"
)

// pageSize is the limit for the number of records that can be returned
// in a paginated query. This can be tuned after some benchmarks.
const pageSize = 2000

// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8

const (
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
        ProtocolV1 ProtocolVersion = 1
)

// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
        return fmt.Sprintf("V%d", v)
}

// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
        /*
                Node queries.
        */
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)

        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error

        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error

        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error

        /*
                Source node queries.
        */
        AddSourceNode(ctx context.Context, nodeID int64) error
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)

        /*
                Channel queries.
        */
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)

        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error

        /*
                Channel Policy table queries.
        */
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)

        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
        SQLQueries
        sqldb.BatchedTx[SQLQueries]
}

// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
//
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
// implement the V1Store interface incrementally. For any method not
// implemented, things will fall back to the KVStore. This is ONLY the case
// for the time being while this struct is purely used in unit tests only.
type SQLStore struct {
        cfg *SQLStoreConfig
        db  BatchedSQLQueries

        // cacheMu guards all caches (rejectCache and chanCache). If
        // this mutex will be acquired at the same time as the DB mutex then
        // the cacheMu MUST be acquired first to prevent deadlock.
        cacheMu     sync.RWMutex
        rejectCache *rejectCache
        chanCache   *channelCache

        chanScheduler batch.Scheduler[SQLQueries]
        nodeScheduler batch.Scheduler[SQLQueries]

        srcNodes  map[ProtocolVersion]*srcNodeInfo
        srcNodeMu sync.Mutex

        // Temporary fall-back to the KVStore so that we can implement the
        // interface incrementally.
        *KVStore
}

// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)

// SQLStoreConfig holds the configuration for the SQLStore.
type SQLStoreConfig struct {
        // ChainHash is the genesis hash for the chain that all the gossip
        // messages in this store are aimed at.
        ChainHash chainhash.Hash
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, kvStore *KVStore,
        options ...StoreOptionModifier) (*SQLStore, error) {

        opts := DefaultOptions()
        for _, o := range options {
                o(opts)
        }

        if opts.NoMigration {
                return nil, fmt.Errorf("the NoMigration option is not yet " +
                        "supported for SQL stores")
        }

        s := &SQLStore{
                cfg:         cfg,
                db:          db,
                KVStore:     kvStore,
                rejectCache: newRejectCache(opts.RejectCacheSize),
                chanCache:   newChannelCache(opts.ChannelCacheSize),
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
        }

        s.chanScheduler = batch.NewTimeScheduler(
                db, &s.cacheMu, opts.BatchCommitInterval,
        )
        s.nodeScheduler = batch.NewTimeScheduler(
                db, nil, opts.BatchCommitInterval,
        )

        return s, nil
}

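// The following is an illustrative usage sketch, not part of this file: it
// shows how a SQLStore could be wired up, assuming the caller already has a
// BatchedSQLQueries implementation and a KVStore to fall back to. The
// `queries` and `kvStore` variables are hypothetical, and chaincfg refers to
// github.com/btcsuite/btcd/chaincfg.
//
//      cfg := &SQLStoreConfig{
//              ChainHash: *chaincfg.MainNetParams.GenesisHash,
//      }
//      store, err := NewSQLStore(cfg, queries, kvStore)
//      if err != nil {
//              return fmt.Errorf("unable to create SQL graph store: %w", err)
//      }
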
// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(ctx context.Context,
        node *models.LightningNode, opts ...batch.SchedulerOption) error {

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Do: func(queries SQLQueries) error {
                        _, err := upsertNode(ctx, queries, node)
                        return err
                },
        }

        return s.nodeScheduler.Execute(ctx, r)
}

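// Illustrative usage sketch, not part of this file: announcing a node through
// the batched scheduler. The pubKey value is a hypothetical 33-byte compressed
// public key and `store` is an existing *SQLStore; only fields that this file
// itself reads and writes are set here.
//
//      node := &models.LightningNode{
//              PubKeyBytes:          pubKey,
//              HaveNodeAnnouncement: true,
//              LastUpdate:           time.Now(),
//              Alias:                "example-node",
//              Features:             lnwire.EmptyFeatureVector(),
//      }
//      if err := store.AddLightningNode(ctx, node); err != nil {
//              return fmt.Errorf("unable to add node: %w", err)
//      }
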
// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
        pubKey route.Vertex) (*models.LightningNode, error) {

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                _, node, err = getNodeByPubKey(ctx, db, pubKey)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        return node, nil
}

// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
        pubKey [33]byte) (time.Time, bool, error) {

        var (
                exists     bool
                lastUpdate time.Time
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                exists = true

                if dbNode.LastUpdate.Valid {
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return time.Time{}, false,
                        fmt.Errorf("unable to fetch node: %w", err)
        }

        return lastUpdate, exists, nil
}

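// Illustrative usage sketch, not part of this file: checking whether a node
// announcement is already known before processing a fresh one. `store` is an
// existing *SQLStore, pubKey is a hypothetical 33-byte public key and
// newUpdate is a hypothetical timestamp taken from the incoming announcement.
//
//      lastUpdate, exists, err := store.HasLightningNode(ctx, pubKey)
//      if err != nil {
//              return err
//      }
//      if exists && !newUpdate.After(lastUpdate) {
//              // The stored announcement is at least as recent; skip it.
//              return nil
//      }
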
// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

        var (
                addresses []net.Addr
                known     bool
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                known, addresses, err = getNodeAddresses(
                        ctx, db, nodePub.SerializeCompressed(),
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch node addresses: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, nil, fmt.Errorf("unable to get addresses for "+
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
        }

        return known, addresses, nil
}

// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
        pubKey route.Vertex) error {

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.DeleteNodeByPubKey(
                        ctx, sqlc.DeleteNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pubKey[:],
                        },
                )
                if err != nil {
                        return err
                }

                rows, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if rows == 0 {
                        return ErrGraphNodeNotFound
                } else if rows > 1 {
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
                }

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to delete node: %w", err)
        }

        return nil
}

// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
        var (
                ctx   = context.TODO()
                alias string
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  pub.SerializeCompressed(),
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrNodeAliasNotFound
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                if !dbNode.Alias.Valid {
                        return ErrNodeAliasNotFound
                }

                alias = dbNode.Alias.String

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return "", fmt.Errorf("unable to look up alias: %w", err)
        }

        return alias, nil
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
        ctx := context.TODO()

        var node *models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch V1 source node: %w",
                                err)
                }

                _, node, err = getNodeByPubKey(ctx, db, nodePub)

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
        }

        return node, nil
}

// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
        ctx := context.TODO()

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                id, err := upsertNode(ctx, db, node)
                if err != nil {
                        return fmt.Errorf("unable to upsert source node: %w",
                                err)
                }

                // Make sure that if a source node for this version is already
                // set, then the ID is the same as the one we are about to set.
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                } else if err == nil {
                        if dbSourceNodeID != id {
                                return fmt.Errorf("v1 source node already "+
                                        "set to a different node: %d vs %d",
                                        dbSourceNodeID, id)
                        }

                        return nil
                }

                return db.AddSourceNode(ctx, id)
        }, sqldb.NoOpReset)
}

// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
        endTime time.Time) ([]models.LightningNode, error) {

        ctx := context.TODO()

        var nodes []models.LightningNode
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNodes, err := db.GetNodesByLastUpdateRange(
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch nodes: %w", err)
                }

                for _, dbNode := range dbNodes {
                        node, err := buildNode(ctx, db, &dbNode)
                        if err != nil {
                                return fmt.Errorf("unable to build node: %w",
                                        err)
                        }

                        nodes = append(nodes, *node)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nodes, nil
}

// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(edge *models.ChannelEdgeInfo,
        opts ...batch.SchedulerOption) error {

        ctx := context.TODO()

        var alreadyExists bool
        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        alreadyExists = false
                },
                Do: func(tx SQLQueries) error {
                        err := insertChannel(ctx, tx, edge)

                        // Silence ErrEdgeAlreadyExist so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeAlreadyExist) {
                                alreadyExists = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case alreadyExists:
                                return ErrEdgeAlreadyExist
                        default:
                                s.rejectCache.remove(edge.ChannelID)
                                s.chanCache.remove(edge.ChannelID)
                                return nil
                        }
                },
        }

        return s.chanScheduler.Execute(ctx, r)
}

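// Illustrative usage sketch, not part of this file: persisting a newly
// validated channel announcement. Only a subset of models.ChannelEdgeInfo
// fields is shown, and chanID, outPoint, capacity, node1 and node2 are
// hypothetical values prepared by the caller.
//
//      edge := &models.ChannelEdgeInfo{
//              ChannelID:     chanID,
//              NodeKey1Bytes: node1,
//              NodeKey2Bytes: node2,
//              ChannelPoint:  outPoint,
//              Capacity:      capacity,
//      }
//      err := store.AddChannelEdge(edge)
//      if errors.Is(err, ErrEdgeAlreadyExist) {
//              // A duplicate announcement is not fatal.
//              return nil
//      }
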
// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID() (uint64, error) {
        ctx := context.TODO()

        var highestChanID uint64
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
                                err)
                }

                highestChanID = byteOrder.Uint64(chanID)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
        }

        return highestChanID, nil
}

// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute within
// the ChannelEdgePolicy determines which of the directed edges are being
// updated. If the flag is 1, then the first node's information is being
// updated, otherwise it's the second node's information. The node ordering is
// determined by the lexicographical ordering of the identity public keys of the
// nodes on either side of the channel.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy,
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {

        ctx := context.TODO()

        var (
                isUpdate1    bool
                edgeNotFound bool
                from, to     route.Vertex
        )

        r := &batch.Request[SQLQueries]{
                Opts: batch.NewSchedulerOptions(opts...),
                Reset: func() {
                        isUpdate1 = false
                        edgeNotFound = false
                },
                Do: func(tx SQLQueries) error {
                        var err error
                        from, to, isUpdate1, err = updateChanEdgePolicy(
                                ctx, tx, edge,
                        )
                        if err != nil {
                                log.Errorf("UpdateEdgePolicy failed: %v", err)
                        }

                        // Silence ErrEdgeNotFound so that the batch can
                        // succeed, but propagate the error via local state.
                        if errors.Is(err, ErrEdgeNotFound) {
                                edgeNotFound = true
                                return nil
                        }

                        return err
                },
                OnCommit: func(err error) error {
                        switch {
                        case err != nil:
                                return err
                        case edgeNotFound:
                                return ErrEdgeNotFound
                        default:
                                s.updateEdgeCache(edge, isUpdate1)
                                return nil
                        }
                },
        }

        err := s.chanScheduler.Execute(ctx, r)

        return from, to, err
}

// updateEdgeCache updates our reject and channel caches with the new
// edge policy information.
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
        isUpdate1 bool) {

        // If an entry for this channel is found in reject cache, we'll modify
        // the entry with the updated timestamp for the direction that was just
        // written. If the edge doesn't exist, we'll load the cache entry lazily
        // during the next query for this edge.
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        entry.upd1Time = e.LastUpdate.Unix()
                } else {
                        entry.upd2Time = e.LastUpdate.Unix()
                }
                s.rejectCache.insert(e.ChannelID, entry)
        }

        // If an entry for this channel is found in channel cache, we'll modify
        // the entry with the updated policy for the direction that was just
        // written. If the edge doesn't exist, we'll defer loading the info and
        // policies and lazily read from disk during the next query.
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
                if isUpdate1 {
                        channel.Policy1 = e
                } else {
                        channel.Policy2 = e
                }
                s.chanCache.insert(e.ChannelID, channel)
        }
}

// ForEachSourceNodeChannel iterates through all channels of the source node,
// executing the passed callback on each. The call-back is provided with the
// channel's outpoint, whether we have a policy for the channel and the channel
// peer's node information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
        havePolicy bool, otherNode *models.LightningNode) error) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
                if err != nil {
                        return fmt.Errorf("unable to fetch source node: %w",
                                err)
                }

                return forEachNodeChannel(
                        ctx, db, s.cfg.ChainHash, nodeID,
                        func(info *models.ChannelEdgeInfo,
                                outPolicy *models.ChannelEdgePolicy,
                                _ *models.ChannelEdgePolicy) error {

                                // Fetch the other node.
                                var (
                                        otherNodePub [33]byte
                                        node1        = info.NodeKey1Bytes
                                        node2        = info.NodeKey2Bytes
                                )
                                switch {
                                case bytes.Equal(node1[:], nodePub[:]):
                                        otherNodePub = node2
                                case bytes.Equal(node2[:], nodePub[:]):
                                        otherNodePub = node1
                                default:
                                        return fmt.Errorf("node not " +
                                                "participating in this channel")
                                }

                                _, otherNode, err := getNodeByPubKey(
                                        ctx, db, otherNodePub,
                                )
                                if err != nil {
                                        return fmt.Errorf("unable to fetch "+
                                                "other node(%x): %w",
                                                otherNodePub, err)
                                }

                                return cb(
                                        info.ChannelPoint, outPolicy != nil,
                                        otherNode,
                                )
                        },
                )
        }, sqldb.NoOpReset)
}

// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early. Any operations performed on the NodeTx passed to the call-back are
// executed under the same read transaction and so, methods on the NodeTx object
// _MUST_ only be called from within the call-back.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
        var (
                ctx          = context.TODO()
                lastID int64 = 0
        )

        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
                node, err := buildNode(ctx, db, &dbNode)
                if err != nil {
                        return fmt.Errorf("unable to build node(id=%d): %w",
                                dbNode.ID, err)
                }

                err = cb(
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
                )
                if err != nil {
                        return fmt.Errorf("callback failed for node(id=%d): %w",
                                dbNode.ID, err)
                }

                return nil
        }

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                for {
                        nodes, err := db.ListNodesPaginated(
                                ctx, sqlc.ListNodesPaginatedParams{
                                        Version: int16(ProtocolV1),
                                        ID:      lastID,
                                        Limit:   pageSize,
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to fetch nodes: %w",
                                        err)
                        }

                        if len(nodes) == 0 {
                                break
                        }

                        for _, dbNode := range nodes {
                                err = handleNode(db, dbNode)
                                if err != nil {
                                        return err
                                }

                                lastID = dbNode.ID
                        }
                }

                return nil
        }, sqldb.NoOpReset)
}

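// Illustrative usage sketch, not part of this file: walking every node in the
// graph and counting its channels via the NodeRTx handle, which stays scoped
// to the same read transaction. `store` is an existing *SQLStore.
//
//      err := store.ForEachNode(func(tx NodeRTx) error {
//              var numChans int
//              err := tx.ForEachChannel(func(info *models.ChannelEdgeInfo,
//                      _, _ *models.ChannelEdgePolicy) error {
//
//                      numChans++
//                      return nil
//              })
//              if err != nil {
//                      return err
//              }
//
//              log.Infof("node %x has %d channels",
//                      tx.Node().PubKeyBytes, numChans)
//
//              return nil
//      })
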
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction.
type sqlGraphNodeTx struct {
        db    SQLQueries
        id    int64
        node  *models.LightningNode
        chain chainhash.Hash
}

// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)

func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
        id int64, node *models.LightningNode) *sqlGraphNodeTx {

        return &sqlGraphNodeTx{
                db:    db,
                chain: chain,
                id:    id,
                node:  node,
        }
}

// Node returns the raw information of the node.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
        return s.node
}

// ForEachChannel can be used to iterate over the node's channels under the same
// transaction used to fetch the node.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {

        ctx := context.TODO()

        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
}

// FetchNode fetches the node with the given pub key under the same transaction
// used to fetch the current node. The returned node is also a NodeRTx and any
// operations on that NodeRTx will also be done under the same transaction.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
        ctx := context.TODO()

        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
        if err != nil {
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
                        nodePub, err)
        }

        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
}

// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
        }, sqldb.NoOpReset)
}

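// Illustrative usage sketch, not part of this file: inspecting the directed
// view of a node's channels, for instance to sum the capacity of channels that
// already have an outgoing policy set. `nodePub` is a hypothetical
// route.Vertex and `store` is an existing *SQLStore.
//
//      var total btcutil.Amount
//      err := store.ForEachNodeDirectedChannel(nodePub,
//              func(c *DirectedChannel) error {
//                      if c.OutPolicySet {
//                              total += c.Capacity
//                      }
//                      return nil
//              },
//      )
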
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
//
// NOTE: This is a part of the V1Store interface.
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
        *lnwire.FeatureVector) error) error {

        ctx := context.TODO()

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
                        nodePub route.Vertex) error {

                        features, err := getNodeFeatures(ctx, db, nodeID)
                        if err != nil {
                                return fmt.Errorf("unable to fetch node "+
                                        "features: %w", err)
                        }

                        return cb(nodePub, features)
                })
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to fetch nodes: %w", err)
        }

        return nil
}

// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                dbNode, err := db.GetNodeByPubKey(
                        ctx, sqlc.GetNodeByPubKeyParams{
                                Version: int16(ProtocolV1),
                                PubKey:  nodePub[:],
                        },
                )
                if errors.Is(err, sql.ErrNoRows) {
                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to fetch node: %w", err)
                }

                return forEachNodeChannel(
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
                )
        }, sqldb.NoOpReset)
}

// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
// returned.
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {

        toNodeCallback := func() route.Vertex {
                return nodePub
        }

        dbID, err := db.GetNodeIDByPubKey(
                ctx, sqlc.GetNodeIDByPubKeyParams{
                        Version: int16(ProtocolV1),
                        PubKey:  nodePub[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return nil
        } else if err != nil {
                return fmt.Errorf("unable to fetch node: %w", err)
        }

        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(ProtocolV1),
                        NodeID1: dbID,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Exit early if there are no channels for this node so we don't
        // do the unnecessary feature fetching.
        if len(rows) == 0 {
                return nil
        }

        features, err := getNodeFeatures(ctx, db, dbID)
        if err != nil {
                return fmt.Errorf("unable to fetch node features: %w", err)
        }

        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge, err := buildCacheableChannelInfo(
                        row.Channel, node1, node2,
                )
                if err != nil {
                        return err
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                var p1, p2 *models.CachedEdgePolicy
                if dbPol1 != nil {
                        policy1, err := buildChanPolicy(
                                *dbPol1, edge.ChannelID, nil, node2, true,
                        )
                        if err != nil {
                                return err
                        }

                        p1 = models.NewCachedPolicy(policy1)
                }
                if dbPol2 != nil {
                        policy2, err := buildChanPolicy(
                                *dbPol2, edge.ChannelID, nil, node1, false,
                        )
                        if err != nil {
                                return err
                        }

                        p2 = models.NewCachedPolicy(policy2)
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                outPolicy, inPolicy := p1, p2
                if p1 != nil && node2 == nodePub {
                        outPolicy, inPolicy = p2, p1
                } else if p2 != nil && node1 != nodePub {
                        outPolicy, inPolicy = p2, p1
                }

                var cachedInPolicy *models.CachedEdgePolicy
                if inPolicy != nil {
                        cachedInPolicy = inPolicy
                        cachedInPolicy.ToNodePubKey = toNodeCallback
                        cachedInPolicy.ToNodeFeatures = features
                }

                directedChannel := &DirectedChannel{
                        ChannelID:    edge.ChannelID,
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
                        OtherNode:    edge.NodeKey2Bytes,
                        Capacity:     edge.Capacity,
                        OutPolicySet: outPolicy != nil,
                        InPolicy:     cachedInPolicy,
                }
                if outPolicy != nil {
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                                directedChannel.InboundFee = fee
                        })
                }

                if nodePub == edge.NodeKey2Bytes {
                        directedChannel.OtherNode = edge.NodeKey1Bytes
                }

                if err := cb(directedChannel); err != nil {
                        return err
                }
        }

        return nil
}

// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
// and executes the provided callback for each node.
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
        cb func(nodeID int64, nodePub route.Vertex) error) error {

        var lastID int64

        for {
                nodes, err := db.ListNodeIDsAndPubKeys(
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                Version: int16(ProtocolV1),
                                ID:      lastID,
                                Limit:   pageSize,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch nodes: %w", err)
                }

                if len(nodes) == 0 {
                        break
                }

                for _, node := range nodes {
                        var pub route.Vertex
                        copy(pub[:], node.PubKey)

                        if err := cb(node.ID, pub); err != nil {
                                return fmt.Errorf("forEachNodeCacheable "+
                                        "callback failed for node(id=%d): %w",
                                        node.ID, err)
                        }

                        lastID = node.ID
                }
        }

        return nil
}

// forEachNodeChannel iterates through all channels of a node, executing
// the passed callback on each. The call-back is provided with the channel's
// edge information, the outgoing policy and the incoming policy for the
// channel and node combo.
func forEachNodeChannel(ctx context.Context, db SQLQueries,
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        // Get all the V1 channels for this node.
        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(ProtocolV1),
                        NodeID1: id,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Call the call-back for each channel and its known policies.
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge, err := getAndBuildEdgeInfo(
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
                        node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := getAndBuildChanPolicies(
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                p1ToNode := row.Channel.NodeID2
                p2ToNode := row.Channel.NodeID1
                outPolicy, inPolicy := p1, p2
                if (p1 != nil && p1ToNode == id) ||
                        (p2 != nil && p2ToNode != id) {

                        outPolicy, inPolicy = p2, p1
                }

                if err := cb(edge, outPolicy, inPolicy); err != nil {
                        return err
                }
        }

        return nil
}

1174
// updateChanEdgePolicy upserts the channel policy info we have stored for
1175
// a channel we already know of.
1176
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
1177
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
1178
        error) {
×
1179

×
1180
        var (
×
1181
                node1Pub, node2Pub route.Vertex
×
1182
                isNode1            bool
×
1183
                chanIDB            [8]byte
×
1184
        )
×
1185
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
1186

×
1187
        // Check that this edge policy refers to a channel that we already
×
1188
        // know of. We do this explicitly so that we can return the appropriate
×
1189
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
1190
        // abort the transaction which would abort the entire batch.
×
1191
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
1192
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
1193
                        Scid:    chanIDB[:],
×
1194
                        Version: int16(ProtocolV1),
×
1195
                },
×
1196
        )
×
1197
        if errors.Is(err, sql.ErrNoRows) {
×
1198
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
1199
        } else if err != nil {
×
1200
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
1201
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
1202
        }
×
1203

1204
        copy(node1Pub[:], dbChan.Node1PubKey)
×
1205
        copy(node2Pub[:], dbChan.Node2PubKey)
×
1206

×
1207
        // Figure out which node this edge is from.
×
1208
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
1209
        nodeID := dbChan.NodeID1
×
1210
        if !isNode1 {
×
1211
                nodeID = dbChan.NodeID2
×
1212
        }
×
1213

1214
        var (
×
1215
                inboundBase sql.NullInt64
×
1216
                inboundRate sql.NullInt64
×
1217
        )
×
1218
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
1219
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
1220
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
1221
        })
×
1222

1223
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
1224
                Version:     int16(ProtocolV1),
×
1225
                ChannelID:   dbChan.ID,
×
1226
                NodeID:      nodeID,
×
1227
                Timelock:    int32(edge.TimeLockDelta),
×
1228
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
1229
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
1230
                MinHtlcMsat: int64(edge.MinHTLC),
×
1231
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
1232
                Disabled: sql.NullBool{
×
1233
                        Valid: true,
×
1234
                        Bool:  edge.IsDisabled(),
×
1235
                },
×
1236
                MaxHtlcMsat: sql.NullInt64{
×
1237
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
1238
                        Int64: int64(edge.MaxHTLC),
×
1239
                },
×
1240
                InboundBaseFeeMsat:      inboundBase,
×
1241
                InboundFeeRateMilliMsat: inboundRate,
×
1242
                Signature:               edge.SigBytes,
×
1243
        })
×
1244
        if err != nil {
×
1245
                return node1Pub, node2Pub, isNode1,
×
1246
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
1247
        }
×
1248

1249
        // Convert the flat extra opaque data into a map of TLV types to
1250
        // values.
1251
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
1252
        if err != nil {
×
1253
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
1254
                        "marshal extra opaque data: %w", err)
×
1255
        }
×
1256

1257
        // Update the channel policy's extra signed fields.
1258
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
1259
        if err != nil {
×
1260
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
1261
                        "policy extra TLVs: %w", err)
×
1262
        }
×
1263

1264
        return node1Pub, node2Pub, isNode1, nil
×
1265
}
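// NOTE: The following helper is an illustrative sketch (not part of the
// original file) of how the lnwire.ChanUpdateDirection bit in a policy's
// channel flags selects which node of the channel a policy belongs to,
// mirroring the isNode1 computation in the policy upsert logic above. The
// helper name is hypothetical.
func policyNodeID(flags lnwire.ChanUpdateChanFlags, nodeID1,
        nodeID2 int64) int64 {

        // A cleared direction bit means the update was issued by node 1.
        if flags&lnwire.ChanUpdateDirection == 0 {
                return nodeID1
        }

        return nodeID2
}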
1266

1267
// getNodeByPubKey attempts to look up a target node by its public key.
1268
func getNodeByPubKey(ctx context.Context, db SQLQueries,
1269
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
1270

×
1271
        dbNode, err := db.GetNodeByPubKey(
×
1272
                ctx, sqlc.GetNodeByPubKeyParams{
×
1273
                        Version: int16(ProtocolV1),
×
1274
                        PubKey:  pubKey[:],
×
1275
                },
×
1276
        )
×
1277
        if errors.Is(err, sql.ErrNoRows) {
×
1278
                return 0, nil, ErrGraphNodeNotFound
×
1279
        } else if err != nil {
×
1280
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
1281
        }
×
1282

1283
        node, err := buildNode(ctx, db, &dbNode)
×
1284
        if err != nil {
×
1285
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
1286
        }
×
1287

1288
        return dbNode.ID, node, nil
×
1289
}
1290

1291
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
1292
// provided database channel row and the public keys of the two nodes
1293
// involved in the channel.
1294
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
NEW
1295
        node2Pub route.Vertex) (*models.CachedEdgeInfo, error) {
×
NEW
1296

×
NEW
1297
        return &models.CachedEdgeInfo{
×
NEW
1298
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
NEW
1299
                NodeKey1Bytes: node1Pub,
×
NEW
1300
                NodeKey2Bytes: node2Pub,
×
NEW
1301
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
NEW
1302
        }, nil
×
NEW
1303
}
×
1304

1305
// buildNode constructs a LightningNode instance from the given database node
1306
// record. The node's features, addresses and extra signed fields are also
1307
// fetched from the database and set on the node.
1308
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
1309
        *models.LightningNode, error) {
×
1310

×
1311
        if dbNode.Version != int16(ProtocolV1) {
×
1312
                return nil, fmt.Errorf("unsupported node version: %d",
×
1313
                        dbNode.Version)
×
1314
        }
×
1315

1316
        var pub [33]byte
×
1317
        copy(pub[:], dbNode.PubKey)
×
1318

×
1319
        node := &models.LightningNode{
×
1320
                PubKeyBytes: pub,
×
1321
                Features:    lnwire.EmptyFeatureVector(),
×
1322
                LastUpdate:  time.Unix(0, 0),
×
1323
        }
×
1324

×
1325
        if len(dbNode.Signature) == 0 {
×
1326
                return node, nil
×
1327
        }
×
1328

1329
        node.HaveNodeAnnouncement = true
×
1330
        node.AuthSigBytes = dbNode.Signature
×
1331
        node.Alias = dbNode.Alias.String
×
1332
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
1333

×
1334
        var err error
×
NEW
1335
        if dbNode.Color.Valid {
×
NEW
1336
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
NEW
1337
                if err != nil {
×
NEW
1338
                        return nil, fmt.Errorf("unable to decode color: %w",
×
NEW
1339
                                err)
×
NEW
1340
                }
×
1341
        }
1342

1343
        // Fetch the node's features.
1344
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
1345
        if err != nil {
×
1346
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
1347
                        "features: %w", dbNode.ID, err)
×
1348
        }
×
1349

1350
        // Fetch the node's addresses.
1351
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
1352
        if err != nil {
×
1353
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
1354
                        "addresses: %w", dbNode.ID, err)
×
1355
        }
×
1356

1357
        // Fetch the node's extra signed fields.
1358
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
1359
        if err != nil {
×
1360
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
1361
                        "extra signed fields: %w", dbNode.ID, err)
×
1362
        }
×
1363

1364
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
1365
        if err != nil {
×
1366
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
1367
                        "fields: %w", err)
×
1368
        }
×
1369

1370
        if len(recs) != 0 {
×
1371
                node.ExtraOpaqueData = recs
×
1372
        }
×
1373

1374
        return node, nil
×
1375
}
1376

1377
// getNodeFeatures fetches the feature bits and constructs the feature vector
1378
// for a node with the given DB ID.
1379
func getNodeFeatures(ctx context.Context, db SQLQueries,
1380
        nodeID int64) (*lnwire.FeatureVector, error) {
×
1381

×
1382
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
1383
        if err != nil {
×
1384
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
1385
                        nodeID, err)
×
1386
        }
×
1387

1388
        features := lnwire.EmptyFeatureVector()
×
1389
        for _, feature := range rows {
×
1390
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
1391
        }
×
1392

1393
        return features, nil
×
1394
}
1395

1396
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
1397
// given DB ID.
1398
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
1399
        nodeID int64) (map[uint64][]byte, error) {
×
1400

×
1401
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
1402
        if err != nil {
×
1403
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
1404
                        "signed fields: %w", nodeID, err)
×
1405
        }
×
1406

1407
        extraFields := make(map[uint64][]byte)
×
1408
        for _, field := range fields {
×
1409
                extraFields[uint64(field.Type)] = field.Value
×
1410
        }
×
1411

1412
        return extraFields, nil
×
1413
}
1414

1415
// upsertNode upserts the node record into the database. If the node already
1416
// exists, then the node's information is updated. If the node doesn't exist,
1417
// then a new node is created. The node's features, addresses and extra TLV
1418
// types are also updated. The node's DB ID is returned.
1419
func upsertNode(ctx context.Context, db SQLQueries,
1420
        node *models.LightningNode) (int64, error) {
×
1421

×
1422
        params := sqlc.UpsertNodeParams{
×
1423
                Version: int16(ProtocolV1),
×
1424
                PubKey:  node.PubKeyBytes[:],
×
1425
        }
×
1426

×
1427
        if node.HaveNodeAnnouncement {
×
1428
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
1429
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
1430
                params.Alias = sqldb.SQLStr(node.Alias)
×
1431
                params.Signature = node.AuthSigBytes
×
1432
        }
×
1433

1434
        nodeID, err := db.UpsertNode(ctx, params)
×
1435
        if err != nil {
×
1436
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
1437
                        err)
×
1438
        }
×
1439

1440
        // We can exit here if we don't have the announcement yet.
1441
        if !node.HaveNodeAnnouncement {
×
1442
                return nodeID, nil
×
1443
        }
×
1444

1445
        // Update the node's features.
1446
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
1447
        if err != nil {
×
1448
                return 0, fmt.Errorf("inserting node features: %w", err)
×
1449
        }
×
1450

1451
        // Update the node's addresses.
1452
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
1453
        if err != nil {
×
1454
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
1455
        }
×
1456

1457
        // Convert the flat extra opaque data into a map of TLV types to
1458
        // values.
1459
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
1460
        if err != nil {
×
1461
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
1462
                        err)
×
1463
        }
×
1464

1465
        // Update the node's extra signed fields.
1466
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
1467
        if err != nil {
×
1468
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
1469
        }
×
1470

1471
        return nodeID, nil
×
1472
}
1473

1474
// upsertNodeFeatures updates the node's features node_features table. This
1475
// includes deleting any feature bits no longer present and inserting any new
1476
// feature bits. If the feature bit does not yet exist in the features table,
1477
// then an entry is created in that table first.
1478
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
1479
        features *lnwire.FeatureVector) error {
×
1480

×
1481
        // Get any existing features for the node.
×
1482
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
1483
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1484
                return err
×
1485
        }
×
1486

1487
        // Copy the node's latest set of feature bits.
1488
        newFeatures := make(map[int32]struct{})
×
1489
        if features != nil {
×
1490
                for feature := range features.Features() {
×
1491
                        newFeatures[int32(feature)] = struct{}{}
×
1492
                }
×
1493
        }
1494

1495
        // For any current feature that already exists in the DB, remove it from
1496
        // the in-memory map. For any existing feature that does not exist in
1497
        // the in-memory map, delete it from the database.
1498
        for _, feature := range existingFeatures {
×
1499
                // The feature is still present, so there are no updates to be
×
1500
                // made.
×
1501
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
1502
                        delete(newFeatures, feature.FeatureBit)
×
1503
                        continue
×
1504
                }
1505

1506
                // The feature is no longer present, so we remove it from the
1507
                // database.
1508
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
1509
                        NodeID:     nodeID,
×
1510
                        FeatureBit: feature.FeatureBit,
×
1511
                })
×
1512
                if err != nil {
×
1513
                        return fmt.Errorf("unable to delete node(%d) "+
×
1514
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
1515
                                err)
×
1516
                }
×
1517
        }
1518

1519
        // Any remaining entries in newFeatures are new features that need to be
1520
        // added to the database for the first time.
1521
        for feature := range newFeatures {
×
1522
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
1523
                        NodeID:     nodeID,
×
1524
                        FeatureBit: feature,
×
1525
                })
×
1526
                if err != nil {
×
1527
                        return fmt.Errorf("unable to insert node(%d) "+
×
1528
                                "feature(%v): %w", nodeID, feature, err)
×
1529
                }
×
1530
        }
1531

1532
        return nil
×
1533
}
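// NOTE: The following helper is an illustrative sketch (not part of the
// original file) that isolates the diff-and-apply pattern used by
// upsertNodeFeatures above: given the feature bits currently stored for a
// node and its latest feature vector, it returns the bits to delete and the
// bits to insert. The helper name is hypothetical.
func diffFeatureBits(existing []int32,
        latest *lnwire.FeatureVector) (toDelete, toInsert []int32) {

        newBits := make(map[int32]struct{})
        if latest != nil {
                for bit := range latest.Features() {
                        newBits[int32(bit)] = struct{}{}
                }
        }

        for _, bit := range existing {
                // The bit is still advertised, so nothing needs to change.
                if _, ok := newBits[bit]; ok {
                        delete(newBits, bit)
                        continue
                }

                // The bit is no longer advertised and should be deleted.
                toDelete = append(toDelete, bit)
        }

        // Whatever remains was not yet stored and must be inserted.
        for bit := range newBits {
                toInsert = append(toInsert, bit)
        }

        return toDelete, toInsert
}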
1534

1535
// fetchNodeFeatures fetches the features for a node with the given public key.
1536
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
1537
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
1538

×
1539
        rows, err := queries.GetNodeFeaturesByPubKey(
×
1540
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
1541
                        PubKey:  nodePub[:],
×
1542
                        Version: int16(ProtocolV1),
×
1543
                },
×
1544
        )
×
1545
        if err != nil {
×
1546
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
1547
                        nodePub, err)
×
1548
        }
×
1549

1550
        features := lnwire.EmptyFeatureVector()
×
1551
        for _, bit := range rows {
×
1552
                features.Set(lnwire.FeatureBit(bit))
×
1553
        }
×
1554

1555
        return features, nil
×
1556
}
1557

1558
// dbAddressType is an enum type that represents the different address types
1559
// that we store in the node_addresses table. The address type determines how
1560
// the address is to be serialised/deserialize.
1561
type dbAddressType uint8
1562

1563
const (
1564
        addressTypeIPv4   dbAddressType = 1
1565
        addressTypeIPv6   dbAddressType = 2
1566
        addressTypeTorV2  dbAddressType = 3
1567
        addressTypeTorV3  dbAddressType = 4
1568
        addressTypeOpaque dbAddressType = math.MaxInt8
1569
)
1570

1571
// upsertNodeAddresses updates the node's addresses in the database. This
1572
// includes deleting any existing addresses and inserting the new set of
1573
// addresses. The deletion is necessary since the ordering of the addresses may
1574
// change, and we need to ensure that the database reflects the latest set of
1575
// addresses so that at the time of reconstructing the node announcement, the
1576
// order is preserved and the signature over the message remains valid.
1577
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
1578
        addresses []net.Addr) error {
×
1579

×
1580
        // Delete any existing addresses for the node. This is required since
×
1581
        // even if the new set of addresses is the same, the ordering may have
×
1582
        // changed for a given address type.
×
1583
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
1584
        if err != nil {
×
1585
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
1586
                        nodeID, err)
×
1587
        }
×
1588

1589
        // Copy the node's latest set of addresses.
1590
        newAddresses := map[dbAddressType][]string{
×
1591
                addressTypeIPv4:   {},
×
1592
                addressTypeIPv6:   {},
×
1593
                addressTypeTorV2:  {},
×
1594
                addressTypeTorV3:  {},
×
1595
                addressTypeOpaque: {},
×
1596
        }
×
1597
        addAddr := func(t dbAddressType, addr net.Addr) {
×
1598
                newAddresses[t] = append(newAddresses[t], addr.String())
×
1599
        }
×
1600

1601
        for _, address := range addresses {
×
1602
                switch addr := address.(type) {
×
1603
                case *net.TCPAddr:
×
1604
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
1605
                                addAddr(addressTypeIPv4, addr)
×
1606
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
1607
                                addAddr(addressTypeIPv6, addr)
×
1608
                        } else {
×
1609
                                return fmt.Errorf("unhandled IP address: %v",
×
1610
                                        addr)
×
1611
                        }
×
1612

1613
                case *tor.OnionAddr:
×
1614
                        switch len(addr.OnionService) {
×
1615
                        case tor.V2Len:
×
1616
                                addAddr(addressTypeTorV2, addr)
×
1617
                        case tor.V3Len:
×
1618
                                addAddr(addressTypeTorV3, addr)
×
1619
                        default:
×
1620
                                return fmt.Errorf("invalid length for a tor " +
×
1621
                                        "address")
×
1622
                        }
1623

1624
                case *lnwire.OpaqueAddrs:
×
1625
                        addAddr(addressTypeOpaque, addr)
×
1626

1627
                default:
×
1628
                        return fmt.Errorf("unhandled address type: %T", addr)
×
1629
                }
1630
        }
1631

1632
        // Insert the new set of addresses, preserving each address's position
        // within its address type.
1634
        for addrType, addrList := range newAddresses {
×
1635
                for position, addr := range addrList {
×
1636
                        err := db.InsertNodeAddress(
×
1637
                                ctx, sqlc.InsertNodeAddressParams{
×
1638
                                        NodeID:   nodeID,
×
1639
                                        Type:     int16(addrType),
×
1640
                                        Address:  addr,
×
1641
                                        Position: int32(position),
×
1642
                                },
×
1643
                        )
×
1644
                        if err != nil {
×
1645
                                return fmt.Errorf("unable to insert "+
×
1646
                                        "node(%d) address(%v): %w", nodeID,
×
1647
                                        addr, err)
×
1648
                        }
×
1649
                }
1650
        }
1651

1652
        return nil
×
1653
}
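// NOTE: The following helper is an illustrative sketch (not part of the
// original file) that isolates the address classification performed by
// upsertNodeAddresses above: it maps a net.Addr to the dbAddressType under
// which it would be stored. The helper name is hypothetical.
func classifyAddress(address net.Addr) (dbAddressType, error) {
        switch addr := address.(type) {
        case *net.TCPAddr:
                if addr.IP.To4() != nil {
                        return addressTypeIPv4, nil
                }
                if addr.IP.To16() != nil {
                        return addressTypeIPv6, nil
                }

                return 0, fmt.Errorf("unhandled IP address: %v", addr)

        case *tor.OnionAddr:
                switch len(addr.OnionService) {
                case tor.V2Len:
                        return addressTypeTorV2, nil
                case tor.V3Len:
                        return addressTypeTorV3, nil
                default:
                        return 0, fmt.Errorf("invalid length for a tor " +
                                "address")
                }

        case *lnwire.OpaqueAddrs:
                return addressTypeOpaque, nil

        default:
                return 0, fmt.Errorf("unhandled address type: %T", address)
        }
}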
1654

1655
// getNodeAddresses fetches the addresses for a node with the given public key.
1656
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
1657
        []net.Addr, error) {
×
1658

×
1659
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
1660
        // are returned in the same order as they were inserted.
×
1661
        rows, err := db.GetNodeAddressesByPubKey(
×
1662
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
1663
                        Version: int16(ProtocolV1),
×
1664
                        PubKey:  nodePub,
×
1665
                },
×
1666
        )
×
1667
        if err != nil {
×
1668
                return false, nil, err
×
1669
        }
×
1670

1671
        // GetNodeAddressesByPubKey uses a left join so there should always be
1672
        // at least one row returned if the node exists even if it has no
1673
        // addresses.
1674
        if len(rows) == 0 {
×
1675
                return false, nil, nil
×
1676
        }
×
1677

1678
        addresses := make([]net.Addr, 0, len(rows))
×
1679
        for _, addr := range rows {
×
1680
                if !(addr.Type.Valid && addr.Address.Valid) {
×
1681
                        continue
×
1682
                }
1683

1684
                address := addr.Address.String
×
1685

×
1686
                switch dbAddressType(addr.Type.Int16) {
×
1687
                case addressTypeIPv4:
×
1688
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
1689
                        if err != nil {
×
1690
                                return false, nil, err
×
1691
                        }
×
1692
                        tcp.IP = tcp.IP.To4()
×
1693

×
1694
                        addresses = append(addresses, tcp)
×
1695

1696
                case addressTypeIPv6:
×
1697
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
1698
                        if err != nil {
×
1699
                                return false, nil, err
×
1700
                        }
×
1701
                        addresses = append(addresses, tcp)
×
1702

1703
                case addressTypeTorV3, addressTypeTorV2:
×
1704
                        service, portStr, err := net.SplitHostPort(address)
×
1705
                        if err != nil {
×
1706
                                return false, nil, fmt.Errorf("unable to "+
×
1707
                                        "split tor v3 address: %v",
×
1708
                                        addr.Address)
×
1709
                        }
×
1710

1711
                        port, err := strconv.Atoi(portStr)
×
1712
                        if err != nil {
×
1713
                                return false, nil, err
×
1714
                        }
×
1715

1716
                        addresses = append(addresses, &tor.OnionAddr{
×
1717
                                OnionService: service,
×
1718
                                Port:         port,
×
1719
                        })
×
1720

1721
                case addressTypeOpaque:
×
1722
                        opaque, err := hex.DecodeString(address)
×
1723
                        if err != nil {
×
1724
                                return false, nil, fmt.Errorf("unable to "+
×
1725
                                        "decode opaque address: %v", addr)
×
1726
                        }
×
1727

1728
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
1729
                                Payload: opaque,
×
1730
                        })
×
1731

1732
                default:
×
1733
                        return false, nil, fmt.Errorf("unknown address "+
×
1734
                                "type: %v", addr.Type)
×
1735
                }
1736
        }
1737

1738
        return true, addresses, nil
×
1739
}
1740

1741
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
1742
// database. This includes updating any existing types, inserting any new types,
1743
// and deleting any types that are no longer present.
1744
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
1745
        nodeID int64, extraFields map[uint64][]byte) error {
×
1746

×
1747
        // Get any existing extra signed fields for the node.
×
1748
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
1749
        if err != nil {
×
1750
                return err
×
1751
        }
×
1752

1753
        // Make a lookup map of the existing field types so that we can use it
1754
        // to keep track of any fields we should delete.
1755
        m := make(map[uint64]bool)
×
1756
        for _, field := range existingFields {
×
1757
                m[uint64(field.Type)] = true
×
1758
        }
×
1759

1760
        // For all the new fields, we'll upsert them and remove them from the
1761
        // map of existing fields.
1762
        for tlvType, value := range extraFields {
×
1763
                err = db.UpsertNodeExtraType(
×
1764
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
1765
                                NodeID: nodeID,
×
1766
                                Type:   int64(tlvType),
×
1767
                                Value:  value,
×
1768
                        },
×
1769
                )
×
1770
                if err != nil {
×
1771
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
1772
                                "signed field(%v): %w", nodeID, tlvType, err)
×
1773
                }
×
1774

1775
                // Remove the field from the map of existing fields if it was
1776
                // present.
1777
                delete(m, tlvType)
×
1778
        }
1779

1780
        // For all the fields that are left in the map of existing fields, we'll
1781
        // delete them as they are no longer present in the new set of fields.
1782
        for tlvType := range m {
×
1783
                err = db.DeleteExtraNodeType(
×
1784
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
1785
                                NodeID: nodeID,
×
1786
                                Type:   int64(tlvType),
×
1787
                        },
×
1788
                )
×
1789
                if err != nil {
×
1790
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
1791
                                "signed field(%v): %w", nodeID, tlvType, err)
×
1792
                }
×
1793
        }
1794

1795
        return nil
×
1796
}
1797

1798
// srcNodeInfo holds the information about the source node of the graph.
1799
type srcNodeInfo struct {
1800
        // id is the DB level ID of the source node entry in the "nodes" table.
1801
        id int64
1802

1803
        // pub is the public key of the source node.
1804
        pub route.Vertex
1805
}
1806

1807
// getSourceNode returns the DB node ID and pub key of the source node for the
1808
// specified protocol version.
1809
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
1810
        version ProtocolVersion) (int64, route.Vertex, error) {
×
1811

×
1812
        s.srcNodeMu.Lock()
×
1813
        defer s.srcNodeMu.Unlock()
×
1814

×
1815
        // If we already have the source node ID and pub key cached, then
×
1816
        // return them.
×
1817
        if info, ok := s.srcNodes[version]; ok {
×
1818
                return info.id, info.pub, nil
×
1819
        }
×
1820

1821
        var pubKey route.Vertex
×
1822

×
1823
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
1824
        if err != nil {
×
1825
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
1826
                        err)
×
1827
        }
×
1828

1829
        if len(nodes) == 0 {
×
1830
                return 0, pubKey, ErrSourceNodeNotSet
×
1831
        } else if len(nodes) > 1 {
×
1832
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
1833
                        "protocol %s found", version)
×
1834
        }
×
1835

1836
        copy(pubKey[:], nodes[0].PubKey)
×
1837

×
1838
        s.srcNodes[version] = &srcNodeInfo{
×
1839
                id:  nodes[0].NodeID,
×
1840
                pub: pubKey,
×
1841
        }
×
1842

×
1843
        return nodes[0].NodeID, pubKey, nil
×
1844
}
1845

1846
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
// stream. This then produces a map from TLV type to value. If the input is
// not a valid TLV stream, then an error is returned.
1849
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
1850
        r := bytes.NewReader(data)
×
1851

×
1852
        tlvStream, err := tlv.NewStream()
×
1853
        if err != nil {
×
1854
                return nil, err
×
1855
        }
×
1856

1857
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
1858
        // pass it into the P2P decoding variant.
1859
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
1860
        if err != nil {
×
NEW
1861
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
1862
        }
×
1863
        if len(parsedTypes) == 0 {
×
1864
                return nil, nil
×
1865
        }
×
1866

1867
        records := make(map[uint64][]byte)
×
1868
        for k, v := range parsedTypes {
×
1869
                records[uint64(k)] = v
×
1870
        }
×
1871

1872
        return records, nil
×
1873
}
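// NOTE: The following function is an illustrative sketch (not part of the
// original file) of how the map produced by marshalExtraOpaqueData is turned
// back into a flat byte slice elsewhere in this file via
// lnwire.CustomRecords. The function name is hypothetical; whether the bytes
// round-trip exactly depends on the input already being a canonical,
// type-ordered TLV stream.
func roundTripExtraOpaqueData(data []byte) ([]byte, error) {
        // Parse the flat TLV stream into a type => value map.
        extra, err := marshalExtraOpaqueData(data)
        if err != nil {
                return nil, err
        }

        // Re-serialize the map of TLV records.
        return lnwire.CustomRecords(extra).Serialize()
}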
1874

1875
// insertChannel inserts a new channel record into the database.
1876
func insertChannel(ctx context.Context, db SQLQueries,
1877
        edge *models.ChannelEdgeInfo) error {
×
1878

×
1879
        var chanIDB [8]byte
×
1880
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
1881

×
1882
        // Make sure that the channel doesn't already exist. We do this
×
1883
        // explicitly instead of relying on catching a unique constraint error
×
1884
        // because relying on SQL to throw that error would abort the entire
×
1885
        // batch of transactions.
×
1886
        _, err := db.GetChannelBySCID(
×
1887
                ctx, sqlc.GetChannelBySCIDParams{
×
1888
                        Scid:    chanIDB[:],
×
1889
                        Version: int16(ProtocolV1),
×
1890
                },
×
1891
        )
×
1892
        if err == nil {
×
1893
                return ErrEdgeAlreadyExist
×
1894
        } else if !errors.Is(err, sql.ErrNoRows) {
×
1895
                return fmt.Errorf("unable to fetch channel: %w", err)
×
1896
        }
×
1897

1898
        // Make sure that at least a "shell" entry for each node is present in
1899
        // the nodes table.
1900
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
1901
        if err != nil {
×
1902
                return fmt.Errorf("unable to create shell node: %w", err)
×
1903
        }
×
1904

1905
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
1906
        if err != nil {
×
1907
                return fmt.Errorf("unable to create shell node: %w", err)
×
1908
        }
×
1909

1910
        var capacity sql.NullInt64
×
1911
        if edge.Capacity != 0 {
×
1912
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
1913
        }
×
1914

1915
        createParams := sqlc.CreateChannelParams{
×
1916
                Version:     int16(ProtocolV1),
×
1917
                Scid:        chanIDB[:],
×
1918
                NodeID1:     node1DBID,
×
1919
                NodeID2:     node2DBID,
×
1920
                Outpoint:    edge.ChannelPoint.String(),
×
1921
                Capacity:    capacity,
×
1922
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
1923
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
1924
        }
×
1925

×
1926
        if edge.AuthProof != nil {
×
1927
                proof := edge.AuthProof
×
1928

×
1929
                createParams.Node1Signature = proof.NodeSig1Bytes
×
1930
                createParams.Node2Signature = proof.NodeSig2Bytes
×
1931
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
1932
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
1933
        }
×
1934

1935
        // Insert the new channel record.
1936
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
1937
        if err != nil {
×
1938
                return err
×
1939
        }
×
1940

1941
        // Insert any channel features.
1942
        if len(edge.Features) != 0 {
×
1943
                chanFeatures := lnwire.NewRawFeatureVector()
×
1944
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
1945
                if err != nil {
×
1946
                        return err
×
1947
                }
×
1948

1949
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
1950
                for feature := range fv.Features() {
×
1951
                        err = db.InsertChannelFeature(
×
1952
                                ctx, sqlc.InsertChannelFeatureParams{
×
1953
                                        ChannelID:  dbChanID,
×
1954
                                        FeatureBit: int32(feature),
×
1955
                                },
×
1956
                        )
×
1957
                        if err != nil {
×
1958
                                return fmt.Errorf("unable to insert "+
×
1959
                                        "channel(%d) feature(%v): %w", dbChanID,
×
1960
                                        feature, err)
×
1961
                        }
×
1962
                }
1963
        }
1964

1965
        // Finally, insert any extra TLV fields in the channel announcement.
1966
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
1967
        if err != nil {
×
1968
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
1969
                        err)
×
1970
        }
×
1971

1972
        for tlvType, value := range extra {
×
1973
                err := db.CreateChannelExtraType(
×
1974
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
1975
                                ChannelID: dbChanID,
×
1976
                                Type:      int64(tlvType),
×
1977
                                Value:     value,
×
1978
                        },
×
1979
                )
×
1980
                if err != nil {
×
1981
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
1982
                                "signed field(%v): %w", edge.ChannelID,
×
1983
                                tlvType, err)
×
1984
                }
×
1985
        }
1986

1987
        return nil
×
1988
}
1989

1990
// maybeCreateShellNode checks if a node entry exists for the given public
// key. If it does not exist, then a new shell node entry is created. The ID
// of the node is returned. A shell node only has a protocol version and
// public key persisted.
1994
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
1995
        pubKey route.Vertex) (int64, error) {
×
1996

×
1997
        dbNode, err := db.GetNodeByPubKey(
×
1998
                ctx, sqlc.GetNodeByPubKeyParams{
×
1999
                        PubKey:  pubKey[:],
×
2000
                        Version: int16(ProtocolV1),
×
2001
                },
×
2002
        )
×
2003
        // The node exists. Return the ID.
×
2004
        if err == nil {
×
2005
                return dbNode.ID, nil
×
2006
        } else if !errors.Is(err, sql.ErrNoRows) {
×
2007
                return 0, err
×
2008
        }
×
2009

2010
        // Otherwise, the node does not exist, so we create a shell entry for
2011
        // it.
2012
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
2013
                Version: int16(ProtocolV1),
×
2014
                PubKey:  pubKey[:],
×
2015
        })
×
2016
        if err != nil {
×
2017
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
2018
        }
×
2019

2020
        return id, nil
×
2021
}
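// NOTE: The following helper is an illustrative sketch (not part of the
// original file): a shell node only has its protocol version and public key
// persisted, which buildNode above detects via the missing signature. The
// helper name is hypothetical.
func isShellNode(dbNode sqlc.Node) bool {
        // A node we only know about from channel announcements, and for which
        // we have not yet seen a node announcement, has no signature stored.
        return len(dbNode.Signature) == 0
}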
2022

2023
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
2024
// the database. This includes deleting any existing types and then inserting
2025
// the new types.
2026
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
2027
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
2028

×
2029
        // Delete all existing extra signed fields for the channel policy.
×
2030
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
2031
        if err != nil {
×
2032
                return fmt.Errorf("unable to delete "+
×
2033
                        "existing policy extra signed fields for policy %d: %w",
×
2034
                        chanPolicyID, err)
×
2035
        }
×
2036

2037
        // Insert all new extra signed fields for the channel policy.
2038
        for tlvType, value := range extraFields {
×
2039
                err = db.InsertChanPolicyExtraType(
×
2040
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
2041
                                ChannelPolicyID: chanPolicyID,
×
2042
                                Type:            int64(tlvType),
×
2043
                                Value:           value,
×
2044
                        },
×
2045
                )
×
2046
                if err != nil {
×
2047
                        return fmt.Errorf("unable to insert "+
×
2048
                                "channel_policy(%d) extra signed field(%v): %w",
×
2049
                                chanPolicyID, tlvType, err)
×
2050
                }
×
2051
        }
2052

2053
        return nil
×
2054
}
2055

2056
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
2057
// provided database channel row and also fetches any other required
// information to construct the edge info.
2059
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
2060
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
2061
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
2062

×
2063
        fv, extras, err := getChanFeaturesAndExtras(
×
2064
                ctx, db, dbChanID,
×
2065
        )
×
2066
        if err != nil {
×
2067
                return nil, err
×
2068
        }
×
2069

2070
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
2071
        if err != nil {
×
2072
                return nil, err
×
2073
        }
×
2074

2075
        var featureBuf bytes.Buffer
×
2076
        if err := fv.Encode(&featureBuf); err != nil {
×
2077
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
2078
        }
×
2079

2080
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
2081
        if err != nil {
×
2082
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
2083
                        "fields: %w", err)
×
2084
        }
×
2085
        if recs == nil {
×
2086
                recs = make([]byte, 0)
×
2087
        }
×
2088

2089
        var btcKey1, btcKey2 route.Vertex
×
2090
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
2091
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
2092

×
2093
        channel := &models.ChannelEdgeInfo{
×
2094
                ChainHash:        chain,
×
2095
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
2096
                NodeKey1Bytes:    node1,
×
2097
                NodeKey2Bytes:    node2,
×
2098
                BitcoinKey1Bytes: btcKey1,
×
2099
                BitcoinKey2Bytes: btcKey2,
×
2100
                ChannelPoint:     *op,
×
2101
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
2102
                Features:         featureBuf.Bytes(),
×
2103
                ExtraOpaqueData:  recs,
×
2104
        }
×
2105

×
NEW
2106
        // We always set all the signatures at the same time, so we can
×
NEW
2107
        // safely check if one signature is present to determine if we have the
×
NEW
2108
        // rest of the signatures for the auth proof.
×
NEW
2109
        if len(dbChan.Bitcoin1Signature) > 0 {
×
2110
                channel.AuthProof = &models.ChannelAuthProof{
×
2111
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
2112
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
2113
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
2114
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
2115
                }
×
2116
        }
×
2117

2118
        return channel, nil
×
2119
}
2120

2121
// buildNodeVertices is a helper that converts raw node public keys
2122
// into route.Vertex instances.
2123
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
2124
        route.Vertex, error) {
×
2125

×
2126
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
2127
        if err != nil {
×
2128
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
2129
                        "create vertex from node1 pubkey: %w", err)
×
2130
        }
×
2131

2132
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
2133
        if err != nil {
×
2134
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
2135
                        "create vertex from node2 pubkey: %w", err)
×
2136
        }
×
2137

2138
        return node1Vertex, node2Vertex, nil
×
2139
}
2140

2141
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
2142
// for a channel with the given ID.
2143
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
2144
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
2145

×
2146
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
2147
        if err != nil {
×
2148
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
2149
                        "features and extras: %w", err)
×
2150
        }
×
2151

2152
        var (
×
2153
                fv     = lnwire.EmptyFeatureVector()
×
2154
                extras = make(map[uint64][]byte)
×
2155
        )
×
2156
        for _, row := range rows {
×
2157
                if row.IsFeature {
×
2158
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
2159

×
2160
                        continue
×
2161
                }
2162

2163
                tlvType, ok := row.ExtraKey.(int64)
×
2164
                if !ok {
×
2165
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
2166
                                "TLV type: %T", row.ExtraKey)
×
2167
                }
×
2168

2169
                valueBytes, ok := row.Value.([]byte)
×
2170
                if !ok {
×
2171
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
2172
                                "Value: %T", row.Value)
×
2173
                }
×
2174

2175
                extras[uint64(tlvType)] = valueBytes
×
2176
        }
2177

2178
        return fv, extras, nil
×
2179
}
2180

2181
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy records and also
// retrieves all the extra info required to build the complete
// models.ChannelEdgePolicy types. It returns two policies, which may be nil
// if the provided sqlc.ChannelPolicy records are nil.
2185
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
2186
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
2187
        node2 route.Vertex) (*models.ChannelEdgePolicy,
2188
        *models.ChannelEdgePolicy, error) {
×
2189

×
2190
        if dbPol1 == nil && dbPol2 == nil {
×
2191
                return nil, nil, nil
×
2192
        }
×
2193

2194
        var (
×
2195
                policy1ID int64
×
2196
                policy2ID int64
×
2197
        )
×
2198
        if dbPol1 != nil {
×
2199
                policy1ID = dbPol1.ID
×
2200
        }
×
2201
        if dbPol2 != nil {
×
2202
                policy2ID = dbPol2.ID
×
2203
        }
×
2204
        rows, err := db.GetChannelPolicyExtraTypes(
×
2205
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
2206
                        ID:   policy1ID,
×
2207
                        ID_2: policy2ID,
×
2208
                },
×
2209
        )
×
2210
        if err != nil {
×
2211
                return nil, nil, err
×
2212
        }
×
2213

2214
        var (
×
2215
                dbPol1Extras = make(map[uint64][]byte)
×
2216
                dbPol2Extras = make(map[uint64][]byte)
×
2217
        )
×
2218
        for _, row := range rows {
×
2219
                switch row.PolicyID {
×
2220
                case policy1ID:
×
2221
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
2222
                case policy2ID:
×
2223
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
2224
                default:
×
2225
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
2226
                                "in row: %v", row.PolicyID, row)
×
2227
                }
2228
        }
2229

2230
        var pol1, pol2 *models.ChannelEdgePolicy
×
2231
        if dbPol1 != nil {
×
2232
                pol1, err = buildChanPolicy(
×
2233
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
2234
                )
×
2235
                if err != nil {
×
2236
                        return nil, nil, err
×
2237
                }
×
2238
        }
2239
        if dbPol2 != nil {
×
2240
                pol2, err = buildChanPolicy(
×
2241
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
2242
                )
×
2243
                if err != nil {
×
2244
                        return nil, nil, err
×
2245
                }
×
2246
        }
2247

2248
        return pol1, pol2, nil
×
2249
}
2250

2251
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
2252
// provided sqlc.ChannelPolicy and other required information.
2253
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
2254
        extras map[uint64][]byte, toNode route.Vertex,
2255
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
2256

×
2257
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
2258
        if err != nil {
×
2259
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
2260
                        "fields: %w", err)
×
2261
        }
×
2262

2263
        var msgFlags lnwire.ChanUpdateMsgFlags
×
2264
        if dbPolicy.MaxHtlcMsat.Valid {
×
2265
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
2266
        }
×
2267

2268
        var chanFlags lnwire.ChanUpdateChanFlags
×
2269
        if !isNode1 {
×
2270
                chanFlags |= lnwire.ChanUpdateDirection
×
2271
        }
×
2272
        if dbPolicy.Disabled.Bool {
×
2273
                chanFlags |= lnwire.ChanUpdateDisabled
×
2274
        }
×
2275

2276
        var inboundFee fn.Option[lnwire.Fee]
×
2277
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
2278
                dbPolicy.InboundBaseFeeMsat.Valid {
×
2279

×
2280
                inboundFee = fn.Some(lnwire.Fee{
×
2281
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
2282
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
2283
                })
×
2284
        }
×
2285

2286
        return &models.ChannelEdgePolicy{
×
2287
                SigBytes:  dbPolicy.Signature,
×
2288
                ChannelID: channelID,
×
2289
                LastUpdate: time.Unix(
×
2290
                        dbPolicy.LastUpdate.Int64, 0,
×
2291
                ),
×
2292
                MessageFlags:  msgFlags,
×
2293
                ChannelFlags:  chanFlags,
×
2294
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
2295
                MinHTLC: lnwire.MilliSatoshi(
×
2296
                        dbPolicy.MinHtlcMsat,
×
2297
                ),
×
2298
                MaxHTLC: lnwire.MilliSatoshi(
×
2299
                        dbPolicy.MaxHtlcMsat.Int64,
×
2300
                ),
×
2301
                FeeBaseMSat: lnwire.MilliSatoshi(
×
2302
                        dbPolicy.BaseFeeMsat,
×
2303
                ),
×
2304
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
2305
                ToNode:                    toNode,
×
2306
                InboundFee:                inboundFee,
×
2307
                ExtraOpaqueData:           recs,
×
2308
        }, nil
×
2309
}
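// NOTE: The following helper is an illustrative sketch (not part of the
// original file) of the invariant maintained between buildChanPolicy and the
// policy upsert logic above: the nullable DB columns map onto the wire-level
// flag bits, so a policy read from the DB and written back preserves both.
// The helper name is hypothetical.
func policyFlagsMatchColumns(pol *models.ChannelEdgePolicy,
        row sqlc.ChannelPolicy) bool {

        // max_htlc_msat is only persisted when the announcement carried the
        // optional max HTLC field.
        if pol.MessageFlags.HasMaxHtlc() != row.MaxHtlcMsat.Valid {
                return false
        }

        // The disabled column mirrors the disabled bit of the channel flags.
        return pol.IsDisabled() == row.Disabled.Bool
}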
2310

2311
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the
// given row, which is expected to be a sqlc type that contains channel policy
// information. It returns two policies, which may be nil if the policy
// information is not present in the row.
2315
//
2316
//nolint:ll
2317
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
2318
        error) {
×
2319

×
2320
        var policy1, policy2 *sqlc.ChannelPolicy
×
2321
        switch r := row.(type) {
×
2322
        case sqlc.ListChannelsByNodeIDRow:
×
2323
                if r.Policy1ID.Valid {
×
2324
                        policy1 = &sqlc.ChannelPolicy{
×
2325
                                ID:                      r.Policy1ID.Int64,
×
2326
                                Version:                 r.Policy1Version.Int16,
×
2327
                                ChannelID:               r.Channel.ID,
×
2328
                                NodeID:                  r.Policy1NodeID.Int64,
×
2329
                                Timelock:                r.Policy1Timelock.Int32,
×
2330
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
2331
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
2332
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
2333
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
2334
                                LastUpdate:              r.Policy1LastUpdate,
×
2335
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
2336
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
2337
                                Disabled:                r.Policy1Disabled,
×
2338
                                Signature:               r.Policy1Signature,
×
2339
                        }
×
2340
                }
×
2341
                if r.Policy2ID.Valid {
×
2342
                        policy2 = &sqlc.ChannelPolicy{
×
2343
                                ID:                      r.Policy2ID.Int64,
×
2344
                                Version:                 r.Policy2Version.Int16,
×
2345
                                ChannelID:               r.Channel.ID,
×
2346
                                NodeID:                  r.Policy2NodeID.Int64,
×
2347
                                Timelock:                r.Policy2Timelock.Int32,
×
2348
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
2349
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
2350
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
2351
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
2352
                                LastUpdate:              r.Policy2LastUpdate,
×
2353
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
2354
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
2355
                                Disabled:                r.Policy2Disabled,
×
2356
                                Signature:               r.Policy2Signature,
×
2357
                        }
×
2358
                }
×
2359

2360
                return policy1, policy2, nil
×
2361
        default:
×
2362
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
2363
                        "extractChannelPolicies: %T", r)
×
2364
        }
2365
}