• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15847774552

24 Jun 2025 10:18AM UTC coverage: 67.965% (-0.2%) from 68.172%
15847774552

Pull #9936

github

web-flow
Merge 3687171cd into 45c15646c
Pull Request #9936: [12] graph/db: Implement more graph SQLStore methods

8 of 650 new or added lines in 3 files covered. (1.23%)

56 existing lines in 19 files now uncovered.

134725 of 198227 relevant lines covered (67.97%)

22070.37 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/batch"
23
        "github.com/lightningnetwork/lnd/fn/v2"
24
        "github.com/lightningnetwork/lnd/graph/db/models"
25
        "github.com/lightningnetwork/lnd/lnwire"
26
        "github.com/lightningnetwork/lnd/routing/route"
27
        "github.com/lightningnetwork/lnd/sqldb"
28
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
29
        "github.com/lightningnetwork/lnd/tlv"
30
        "github.com/lightningnetwork/lnd/tor"
31
)
32

33
// pageSize caps how many rows a single paginated query may return. The value
// is a starting point and may be tuned once benchmarks are available.
const pageSize = 2000

// ProtocolVersion identifies which gossip protocol version a message belongs
// to.
type ProtocolVersion uint8

const (
	// ProtocolV1 is the gossip protocol version defined in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version in human-readable form, e.g. "V1".
func (v ProtocolVersion) String() string {
	// Convert to the underlying integer explicitly so the formatting is
	// obviously numeric.
	return fmt.Sprintf("V%d", uint8(v))
}
50

51
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
52
// execute queries against the SQL graph tables.
53
//
54
//nolint:ll,interfacebloat
55
type SQLQueries interface {
56
        /*
57
                Node queries.
58
        */
59
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
60
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
61
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
62
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
63
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
64
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
65
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
66

67
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
68
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
69
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
70

71
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
72
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
73
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
74

75
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
76
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
77
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
78
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
79

80
        /*
81
                Source node queries.
82
        */
83
        AddSourceNode(ctx context.Context, nodeID int64) error
84
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
85

86
        /*
87
                Channel queries.
88
        */
89
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
90
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
91
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
92
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
93
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
94
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
95
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
96
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
97
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
98

99
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
100
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
101

102
        /*
103
                Channel Policy table queries.
104
        */
105
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
106
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
107

108
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
109
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
110
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
111
}
112

113
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
114
// database operations.
115
type BatchedSQLQueries interface {
116
        SQLQueries
117
        sqldb.BatchedTx[SQLQueries]
118
}
119

120
// SQLStore is an implementation of the V1Store interface that uses a SQL
121
// database as the backend.
122
//
123
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
124
// implement the V1Store interface incrementally. For any method not
125
// implemented,  things will fall back to the KVStore. This is ONLY the case
126
// for the time being while this struct is purely used in unit tests only.
127
type SQLStore struct {
128
        cfg *SQLStoreConfig
129
        db  BatchedSQLQueries
130

131
        // cacheMu guards all caches (rejectCache and chanCache). If
132
        // this mutex will be acquired at the same time as the DB mutex then
133
        // the cacheMu MUST be acquired first to prevent deadlock.
134
        cacheMu     sync.RWMutex
135
        rejectCache *rejectCache
136
        chanCache   *channelCache
137

138
        chanScheduler batch.Scheduler[SQLQueries]
139
        nodeScheduler batch.Scheduler[SQLQueries]
140

141
        srcNodes  map[ProtocolVersion]*srcNodeInfo
142
        srcNodeMu sync.Mutex
143

144
        // Temporary fall-back to the KVStore so that we can implement the
145
        // interface incrementally.
146
        *KVStore
147
}
148

149
// A compile-time assertion to ensure that SQLStore implements the V1Store
150
// interface.
151
var _ V1Store = (*SQLStore)(nil)
152

153
// SQLStoreConfig holds the configuration for the SQLStore.
154
type SQLStoreConfig struct {
155
        // ChainHash is the genesis hash for the chain that all the gossip
156
        // messages in this store are aimed at.
157
        ChainHash chainhash.Hash
158
}
159

160
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
161
// storage backend.
162
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, kvStore *KVStore,
163
        options ...StoreOptionModifier) (*SQLStore, error) {
×
164

×
165
        opts := DefaultOptions()
×
166
        for _, o := range options {
×
167
                o(opts)
×
168
        }
×
169

170
        if opts.NoMigration {
×
171
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
172
                        "supported for SQL stores")
×
173
        }
×
174

175
        s := &SQLStore{
×
176
                cfg:         cfg,
×
177
                db:          db,
×
178
                KVStore:     kvStore,
×
179
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
180
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
181
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
182
        }
×
183

×
184
        s.chanScheduler = batch.NewTimeScheduler(
×
185
                db, &s.cacheMu, opts.BatchCommitInterval,
×
186
        )
×
187
        s.nodeScheduler = batch.NewTimeScheduler(
×
188
                db, nil, opts.BatchCommitInterval,
×
189
        )
×
190

×
191
        return s, nil
×
192
}
193

194
// AddLightningNode adds a vertex/node to the graph database. If the node is not
195
// in the database from before, this will add a new, unconnected one to the
196
// graph. If it is present from before, this will update that node's
197
// information.
198
//
199
// NOTE: part of the V1Store interface.
200
func (s *SQLStore) AddLightningNode(ctx context.Context,
201
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
202

×
203
        r := &batch.Request[SQLQueries]{
×
204
                Opts: batch.NewSchedulerOptions(opts...),
×
205
                Do: func(queries SQLQueries) error {
×
206
                        _, err := upsertNode(ctx, queries, node)
×
207
                        return err
×
208
                },
×
209
        }
210

211
        return s.nodeScheduler.Execute(ctx, r)
×
212
}
213

214
// FetchLightningNode attempts to look up a target node by its identity public
215
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
216
// returned.
217
//
218
// NOTE: part of the V1Store interface.
219
func (s *SQLStore) FetchLightningNode(ctx context.Context,
220
        pubKey route.Vertex) (*models.LightningNode, error) {
×
221

×
222
        var node *models.LightningNode
×
223
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
224
                var err error
×
225
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
×
226

×
227
                return err
×
228
        }, sqldb.NoOpReset)
×
229
        if err != nil {
×
230
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
231
        }
×
232

233
        return node, nil
×
234
}
235

236
// HasLightningNode determines if the graph has a vertex identified by the
237
// target node identity public key. If the node exists in the database, a
238
// timestamp of when the data for the node was lasted updated is returned along
239
// with a true boolean. Otherwise, an empty time.Time is returned with a false
240
// boolean.
241
//
242
// NOTE: part of the V1Store interface.
243
func (s *SQLStore) HasLightningNode(ctx context.Context,
244
        pubKey [33]byte) (time.Time, bool, error) {
×
245

×
246
        var (
×
247
                exists     bool
×
248
                lastUpdate time.Time
×
249
        )
×
250
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
251
                dbNode, err := db.GetNodeByPubKey(
×
252
                        ctx, sqlc.GetNodeByPubKeyParams{
×
253
                                Version: int16(ProtocolV1),
×
254
                                PubKey:  pubKey[:],
×
255
                        },
×
256
                )
×
257
                if errors.Is(err, sql.ErrNoRows) {
×
258
                        return nil
×
259
                } else if err != nil {
×
260
                        return fmt.Errorf("unable to fetch node: %w", err)
×
261
                }
×
262

263
                exists = true
×
264

×
265
                if dbNode.LastUpdate.Valid {
×
266
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
267
                }
×
268

269
                return nil
×
270
        }, sqldb.NoOpReset)
271
        if err != nil {
×
272
                return time.Time{}, false,
×
273
                        fmt.Errorf("unable to fetch node: %w", err)
×
274
        }
×
275

276
        return lastUpdate, exists, nil
×
277
}
278

279
// AddrsForNode returns all known addresses for the target node public key
280
// that the graph DB is aware of. The returned boolean indicates if the
281
// given node is unknown to the graph DB or not.
282
//
283
// NOTE: part of the V1Store interface.
284
func (s *SQLStore) AddrsForNode(ctx context.Context,
285
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
286

×
287
        var (
×
288
                addresses []net.Addr
×
289
                known     bool
×
290
        )
×
291
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
292
                var err error
×
293
                known, addresses, err = getNodeAddresses(
×
294
                        ctx, db, nodePub.SerializeCompressed(),
×
295
                )
×
296
                if err != nil {
×
297
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
298
                                err)
×
299
                }
×
300

301
                return nil
×
302
        }, sqldb.NoOpReset)
303
        if err != nil {
×
304
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
305
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
306
        }
×
307

308
        return known, addresses, nil
×
309
}
310

311
// DeleteLightningNode starts a new database transaction to remove a vertex/node
312
// from the database according to the node's public key.
313
//
314
// NOTE: part of the V1Store interface.
315
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
316
        pubKey route.Vertex) error {
×
317

×
318
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
319
                res, err := db.DeleteNodeByPubKey(
×
320
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
321
                                Version: int16(ProtocolV1),
×
322
                                PubKey:  pubKey[:],
×
323
                        },
×
324
                )
×
325
                if err != nil {
×
326
                        return err
×
327
                }
×
328

329
                rows, err := res.RowsAffected()
×
330
                if err != nil {
×
331
                        return err
×
332
                }
×
333

334
                if rows == 0 {
×
335
                        return ErrGraphNodeNotFound
×
336
                } else if rows > 1 {
×
337
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
338
                }
×
339

340
                return err
×
341
        }, sqldb.NoOpReset)
342
        if err != nil {
×
343
                return fmt.Errorf("unable to delete node: %w", err)
×
344
        }
×
345

346
        return nil
×
347
}
348

349
// FetchNodeFeatures returns the features of the given node. If no features are
350
// known for the node, an empty feature vector is returned.
351
//
352
// NOTE: this is part of the graphdb.NodeTraverser interface.
353
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
354
        *lnwire.FeatureVector, error) {
×
355

×
356
        ctx := context.TODO()
×
357

×
358
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
359
}
×
360

361
// LookupAlias attempts to return the alias as advertised by the target node.
362
//
363
// NOTE: part of the V1Store interface.
364
func (s *SQLStore) LookupAlias(ctx context.Context,
365
        pub *btcec.PublicKey) (string, error) {
×
366

×
367
        var alias string
×
368
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
369
                dbNode, err := db.GetNodeByPubKey(
×
370
                        ctx, sqlc.GetNodeByPubKeyParams{
×
371
                                Version: int16(ProtocolV1),
×
372
                                PubKey:  pub.SerializeCompressed(),
×
373
                        },
×
374
                )
×
375
                if errors.Is(err, sql.ErrNoRows) {
×
376
                        return ErrNodeAliasNotFound
×
377
                } else if err != nil {
×
378
                        return fmt.Errorf("unable to fetch node: %w", err)
×
379
                }
×
380

381
                if !dbNode.Alias.Valid {
×
382
                        return ErrNodeAliasNotFound
×
383
                }
×
384

385
                alias = dbNode.Alias.String
×
386

×
387
                return nil
×
388
        }, sqldb.NoOpReset)
389
        if err != nil {
×
390
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
391
        }
×
392

393
        return alias, nil
×
394
}
395

396
// SourceNode returns the source node of the graph. The source node is treated
397
// as the center node within a star-graph. This method may be used to kick off
398
// a path finding algorithm in order to explore the reachability of another
399
// node based off the source node.
400
//
401
// NOTE: part of the V1Store interface.
402
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
403
        error) {
×
404

×
405
        var node *models.LightningNode
×
406
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
407
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
408
                if err != nil {
×
409
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
410
                                err)
×
411
                }
×
412

413
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
×
414

×
415
                return err
×
416
        }, sqldb.NoOpReset)
417
        if err != nil {
×
418
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
419
        }
×
420

421
        return node, nil
×
422
}
423

424
// SetSourceNode sets the source node within the graph database. The source
425
// node is to be used as the center of a star-graph within path finding
426
// algorithms.
427
//
428
// NOTE: part of the V1Store interface.
429
func (s *SQLStore) SetSourceNode(ctx context.Context,
430
        node *models.LightningNode) error {
×
431

×
432
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
433
                id, err := upsertNode(ctx, db, node)
×
434
                if err != nil {
×
435
                        return fmt.Errorf("unable to upsert source node: %w",
×
436
                                err)
×
437
                }
×
438

439
                // Make sure that if a source node for this version is already
440
                // set, then the ID is the same as the one we are about to set.
441
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
442
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
443
                        return fmt.Errorf("unable to fetch source node: %w",
×
444
                                err)
×
445
                } else if err == nil {
×
446
                        if dbSourceNodeID != id {
×
447
                                return fmt.Errorf("v1 source node already "+
×
448
                                        "set to a different node: %d vs %d",
×
449
                                        dbSourceNodeID, id)
×
450
                        }
×
451

452
                        return nil
×
453
                }
454

455
                return db.AddSourceNode(ctx, id)
×
456
        }, sqldb.NoOpReset)
457
}
458

459
// NodeUpdatesInHorizon returns all the known lightning node which have an
460
// update timestamp within the passed range. This method can be used by two
461
// nodes to quickly determine if they have the same set of up to date node
462
// announcements.
463
//
464
// NOTE: This is part of the V1Store interface.
465
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
466
        endTime time.Time) ([]models.LightningNode, error) {
×
467

×
468
        ctx := context.TODO()
×
469

×
470
        var nodes []models.LightningNode
×
471
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
472
                dbNodes, err := db.GetNodesByLastUpdateRange(
×
473
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
×
474
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
475
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
476
                        },
×
477
                )
×
478
                if err != nil {
×
479
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
480
                }
×
481

482
                for _, dbNode := range dbNodes {
×
483
                        node, err := buildNode(ctx, db, &dbNode)
×
484
                        if err != nil {
×
485
                                return fmt.Errorf("unable to build node: %w",
×
486
                                        err)
×
487
                        }
×
488

489
                        nodes = append(nodes, *node)
×
490
                }
491

492
                return nil
×
493
        }, sqldb.NoOpReset)
494
        if err != nil {
×
495
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
496
        }
×
497

498
        return nodes, nil
×
499
}
500

501
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
502
// undirected edge from the two target nodes are created. The information stored
503
// denotes the static attributes of the channel, such as the channelID, the keys
504
// involved in creation of the channel, and the set of features that the channel
505
// supports. The chanPoint and chanID are used to uniquely identify the edge
506
// globally within the database.
507
//
508
// NOTE: part of the V1Store interface.
509
func (s *SQLStore) AddChannelEdge(ctx context.Context,
510
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
511

×
512
        var alreadyExists bool
×
513
        r := &batch.Request[SQLQueries]{
×
514
                Opts: batch.NewSchedulerOptions(opts...),
×
515
                Reset: func() {
×
516
                        alreadyExists = false
×
517
                },
×
518
                Do: func(tx SQLQueries) error {
×
519
                        err := insertChannel(ctx, tx, edge)
×
520

×
521
                        // Silence ErrEdgeAlreadyExist so that the batch can
×
522
                        // succeed, but propagate the error via local state.
×
523
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
524
                                alreadyExists = true
×
525
                                return nil
×
526
                        }
×
527

528
                        return err
×
529
                },
530
                OnCommit: func(err error) error {
×
531
                        switch {
×
532
                        case err != nil:
×
533
                                return err
×
534
                        case alreadyExists:
×
535
                                return ErrEdgeAlreadyExist
×
536
                        default:
×
537
                                s.rejectCache.remove(edge.ChannelID)
×
538
                                s.chanCache.remove(edge.ChannelID)
×
539
                                return nil
×
540
                        }
541
                },
542
        }
543

544
        return s.chanScheduler.Execute(ctx, r)
×
545
}
546

547
// HighestChanID returns the "highest" known channel ID in the channel graph.
548
// This represents the "newest" channel from the PoV of the chain. This method
549
// can be used by peers to quickly determine if their graphs are in sync.
550
//
551
// NOTE: This is part of the V1Store interface.
552
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
553
        var highestChanID uint64
×
554
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
555
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
×
556
                if errors.Is(err, sql.ErrNoRows) {
×
557
                        return nil
×
558
                } else if err != nil {
×
559
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
560
                                err)
×
561
                }
×
562

563
                highestChanID = byteOrder.Uint64(chanID)
×
564

×
565
                return nil
×
566
        }, sqldb.NoOpReset)
567
        if err != nil {
×
568
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
569
        }
×
570

571
        return highestChanID, nil
×
572
}
573

574
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
575
// within the database for the referenced channel. The `flags` attribute within
576
// the ChannelEdgePolicy determines which of the directed edges are being
577
// updated. If the flag is 1, then the first node's information is being
578
// updated, otherwise it's the second node's information. The node ordering is
579
// determined by the lexicographical ordering of the identity public keys of the
580
// nodes on either side of the channel.
581
//
582
// NOTE: part of the V1Store interface.
583
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
584
        edge *models.ChannelEdgePolicy,
585
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
586

×
587
        var (
×
588
                isUpdate1    bool
×
589
                edgeNotFound bool
×
590
                from, to     route.Vertex
×
591
        )
×
592

×
593
        r := &batch.Request[SQLQueries]{
×
594
                Opts: batch.NewSchedulerOptions(opts...),
×
595
                Reset: func() {
×
596
                        isUpdate1 = false
×
597
                        edgeNotFound = false
×
598
                },
×
599
                Do: func(tx SQLQueries) error {
×
600
                        var err error
×
601
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
602
                                ctx, tx, edge,
×
603
                        )
×
604
                        if err != nil {
×
605
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
606
                        }
×
607

608
                        // Silence ErrEdgeNotFound so that the batch can
609
                        // succeed, but propagate the error via local state.
610
                        if errors.Is(err, ErrEdgeNotFound) {
×
611
                                edgeNotFound = true
×
612
                                return nil
×
613
                        }
×
614

615
                        return err
×
616
                },
617
                OnCommit: func(err error) error {
×
618
                        switch {
×
619
                        case err != nil:
×
620
                                return err
×
621
                        case edgeNotFound:
×
622
                                return ErrEdgeNotFound
×
623
                        default:
×
624
                                s.updateEdgeCache(edge, isUpdate1)
×
625
                                return nil
×
626
                        }
627
                },
628
        }
629

630
        err := s.chanScheduler.Execute(ctx, r)
×
631

×
632
        return from, to, err
×
633
}
634

635
// updateEdgeCache updates our reject and channel caches with the new
636
// edge policy information.
637
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
638
        isUpdate1 bool) {
×
639

×
640
        // If an entry for this channel is found in reject cache, we'll modify
×
641
        // the entry with the updated timestamp for the direction that was just
×
642
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
643
        // during the next query for this edge.
×
644
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
645
                if isUpdate1 {
×
646
                        entry.upd1Time = e.LastUpdate.Unix()
×
647
                } else {
×
648
                        entry.upd2Time = e.LastUpdate.Unix()
×
649
                }
×
650
                s.rejectCache.insert(e.ChannelID, entry)
×
651
        }
652

653
        // If an entry for this channel is found in channel cache, we'll modify
654
        // the entry with the updated policy for the direction that was just
655
        // written. If the edge doesn't exist, we'll defer loading the info and
656
        // policies and lazily read from disk during the next query.
657
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
658
                if isUpdate1 {
×
659
                        channel.Policy1 = e
×
660
                } else {
×
661
                        channel.Policy2 = e
×
662
                }
×
663
                s.chanCache.insert(e.ChannelID, channel)
×
664
        }
665
}
666

667
// ForEachSourceNodeChannel iterates through all channels of the source node,
668
// executing the passed callback on each. The call-back is provided with the
669
// channel's outpoint, whether we have a policy for the channel and the channel
670
// peer's node information.
671
//
672
// NOTE: part of the V1Store interface.
673
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
674
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
675

×
676
        var ctx = context.TODO()
×
677

×
678
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
679
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
680
                if err != nil {
×
681
                        return fmt.Errorf("unable to fetch source node: %w",
×
682
                                err)
×
683
                }
×
684

685
                return forEachNodeChannel(
×
686
                        ctx, db, s.cfg.ChainHash, nodeID,
×
687
                        func(info *models.ChannelEdgeInfo,
×
688
                                outPolicy *models.ChannelEdgePolicy,
×
689
                                _ *models.ChannelEdgePolicy) error {
×
690

×
691
                                // Fetch the other node.
×
692
                                var (
×
693
                                        otherNodePub [33]byte
×
694
                                        node1        = info.NodeKey1Bytes
×
695
                                        node2        = info.NodeKey2Bytes
×
696
                                )
×
697
                                switch {
×
698
                                case bytes.Equal(node1[:], nodePub[:]):
×
699
                                        otherNodePub = node2
×
700
                                case bytes.Equal(node2[:], nodePub[:]):
×
701
                                        otherNodePub = node1
×
702
                                default:
×
703
                                        return fmt.Errorf("node not " +
×
704
                                                "participating in this channel")
×
705
                                }
706

707
                                _, otherNode, err := getNodeByPubKey(
×
708
                                        ctx, db, otherNodePub,
×
709
                                )
×
710
                                if err != nil {
×
711
                                        return fmt.Errorf("unable to fetch "+
×
712
                                                "other node(%x): %w",
×
713
                                                otherNodePub, err)
×
714
                                }
×
715

716
                                return cb(
×
717
                                        info.ChannelPoint, outPolicy != nil,
×
718
                                        otherNode,
×
719
                                )
×
720
                        },
721
                )
722
        }, sqldb.NoOpReset)
723
}
724

725
// ForEachNode iterates through all the stored vertices/nodes in the graph,
726
// executing the passed callback with each node encountered. If the callback
727
// returns an error, then the transaction is aborted and the iteration stops
728
// early. Any operations performed on the NodeTx passed to the call-back are
729
// executed under the same read transaction and so, methods on the NodeTx object
730
// _MUST_ only be called from within the call-back.
731
//
732
// NOTE: part of the V1Store interface.
733
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
×
734
        var (
×
735
                ctx          = context.TODO()
×
736
                lastID int64 = 0
×
737
        )
×
738

×
739
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
×
740
                node, err := buildNode(ctx, db, &dbNode)
×
741
                if err != nil {
×
742
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
743
                                dbNode.ID, err)
×
744
                }
×
745

746
                err = cb(
×
747
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
748
                )
×
749
                if err != nil {
×
750
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
751
                                dbNode.ID, err)
×
752
                }
×
753

754
                return nil
×
755
        }
756

757
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
758
                for {
×
759
                        nodes, err := db.ListNodesPaginated(
×
760
                                ctx, sqlc.ListNodesPaginatedParams{
×
761
                                        Version: int16(ProtocolV1),
×
762
                                        ID:      lastID,
×
763
                                        Limit:   pageSize,
×
764
                                },
×
765
                        )
×
766
                        if err != nil {
×
767
                                return fmt.Errorf("unable to fetch nodes: %w",
×
768
                                        err)
×
769
                        }
×
770

771
                        if len(nodes) == 0 {
×
772
                                break
×
773
                        }
774

775
                        for _, dbNode := range nodes {
×
776
                                err = handleNode(db, dbNode)
×
777
                                if err != nil {
×
778
                                        return err
×
779
                                }
×
780

781
                                lastID = dbNode.ID
×
782
                        }
783
                }
784

785
                return nil
×
786
        }, sqldb.NoOpReset)
787
}
788

789
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
790
// SQLStore and a SQL transaction.
791
type sqlGraphNodeTx struct {
792
        db    SQLQueries
793
        id    int64
794
        node  *models.LightningNode
795
        chain chainhash.Hash
796
}
797

798
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
799
// interface.
800
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
801

802
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
803
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
804

×
805
        return &sqlGraphNodeTx{
×
806
                db:    db,
×
807
                chain: chain,
×
808
                id:    id,
×
809
                node:  node,
×
810
        }
×
811
}
×
812

813
// Node returns the raw information of the node.
814
//
815
// NOTE: This is a part of the NodeRTx interface.
816
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
817
        return s.node
×
818
}
×
819

820
// ForEachChannel can be used to iterate over the node's channels under the same
821
// transaction used to fetch the node.
822
//
823
// NOTE: This is a part of the NodeRTx interface.
824
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
825
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
826

×
827
        ctx := context.TODO()
×
828

×
829
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
830
}
×
831

832
// FetchNode fetches the node with the given pub key under the same transaction
833
// used to fetch the current node. The returned node is also a NodeRTx and any
834
// operations on that NodeRTx will also be done under the same transaction.
835
//
836
// NOTE: This is a part of the NodeRTx interface.
837
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
838
        ctx := context.TODO()
×
839

×
840
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
841
        if err != nil {
×
842
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
843
                        nodePub, err)
×
844
        }
×
845

846
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
847
}
848

849
// ForEachNodeDirectedChannel iterates through all channels of a given node,
850
// executing the passed callback on the directed edge representing the channel
851
// and its incoming policy. If the callback returns an error, then the iteration
852
// is halted with the error propagated back up to the caller.
853
//
854
// Unknown policies are passed into the callback as nil values.
855
//
856
// NOTE: this is part of the graphdb.NodeTraverser interface.
857
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
858
        cb func(channel *DirectedChannel) error) error {
×
859

×
860
        var ctx = context.TODO()
×
861

×
862
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
863
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
864
        }, sqldb.NoOpReset)
×
865
}
866

867
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
868
// graph, executing the passed callback with each node encountered. If the
869
// callback returns an error, then the transaction is aborted and the iteration
870
// stops early.
871
//
872
// NOTE: This is a part of the V1Store interface.
873
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
874
        *lnwire.FeatureVector) error) error {
×
875

×
876
        ctx := context.TODO()
×
877

×
878
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
879
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
880
                        nodePub route.Vertex) error {
×
881

×
882
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
883
                        if err != nil {
×
884
                                return fmt.Errorf("unable to fetch node "+
×
885
                                        "features: %w", err)
×
886
                        }
×
887

888
                        return cb(nodePub, features)
×
889
                })
890
        }, sqldb.NoOpReset)
891
        if err != nil {
×
892
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
893
        }
×
894

895
        return nil
×
896
}
897

898
// ForEachNodeChannel iterates through all channels of the given node,
899
// executing the passed callback with an edge info structure and the policies
900
// of each end of the channel. The first edge policy is the outgoing edge *to*
901
// the connecting node, while the second is the incoming edge *from* the
902
// connecting node. If the callback returns an error, then the iteration is
903
// halted with the error propagated back up to the caller.
904
//
905
// Unknown policies are passed into the callback as nil values.
906
//
907
// NOTE: part of the V1Store interface.
908
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
909
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
910
                *models.ChannelEdgePolicy) error) error {
×
911

×
912
        var ctx = context.TODO()
×
913

×
914
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
915
                dbNode, err := db.GetNodeByPubKey(
×
916
                        ctx, sqlc.GetNodeByPubKeyParams{
×
917
                                Version: int16(ProtocolV1),
×
918
                                PubKey:  nodePub[:],
×
919
                        },
×
920
                )
×
921
                if errors.Is(err, sql.ErrNoRows) {
×
922
                        return nil
×
923
                } else if err != nil {
×
924
                        return fmt.Errorf("unable to fetch node: %w", err)
×
925
                }
×
926

927
                return forEachNodeChannel(
×
928
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
929
                )
×
930
        }, sqldb.NoOpReset)
931
}
932

933
// ChanUpdatesInHorizon returns all the known channel edges which have at least
934
// one edge that has an update timestamp within the specified horizon.
935
//
936
// NOTE: This is part of the V1Store interface.
937
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
NEW
938
        endTime time.Time) ([]ChannelEdge, error) {
×
NEW
939

×
NEW
940
        s.cacheMu.Lock()
×
NEW
941
        defer s.cacheMu.Unlock()
×
NEW
942

×
NEW
943
        var (
×
NEW
944
                ctx = context.TODO()
×
NEW
945
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
NEW
946
                // an additional map to keep track of the edges already seen to
×
NEW
947
                // prevent re-adding it.
×
NEW
948
                edgesSeen    = make(map[uint64]struct{})
×
NEW
949
                edgesToCache = make(map[uint64]ChannelEdge)
×
NEW
950
                edges        []ChannelEdge
×
NEW
951
                hits         int
×
NEW
952
        )
×
NEW
953
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
954
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
NEW
955
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
NEW
956
                                Version:   int16(ProtocolV1),
×
NEW
957
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
NEW
958
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
NEW
959
                        },
×
NEW
960
                )
×
NEW
961
                if err != nil {
×
NEW
962
                        return err
×
NEW
963
                }
×
964

NEW
965
                for _, row := range rows {
×
NEW
966
                        // If we've already retrieved the info and policies for
×
NEW
967
                        // this edge, then we can skip it as we don't need to do
×
NEW
968
                        // so again.
×
NEW
969
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
NEW
970
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
NEW
971
                                continue
×
972
                        }
973

NEW
974
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
NEW
975
                                hits++
×
NEW
976
                                edgesSeen[chanIDInt] = struct{}{}
×
NEW
977
                                edges = append(edges, channel)
×
NEW
978

×
NEW
979
                                continue
×
980
                        }
981

NEW
982
                        node1, node2, err := buildNodes(
×
NEW
983
                                ctx, db, row.Node, row.Node_2,
×
NEW
984
                        )
×
NEW
985
                        if err != nil {
×
NEW
986
                                return err
×
NEW
987
                        }
×
988

NEW
989
                        channel, err := getAndBuildEdgeInfo(
×
NEW
990
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
NEW
991
                                row.Channel, node1.PubKeyBytes,
×
NEW
992
                                node2.PubKeyBytes,
×
NEW
993
                        )
×
NEW
994
                        if err != nil {
×
NEW
995
                                return fmt.Errorf("unable to build channel "+
×
NEW
996
                                        "info: %w", err)
×
NEW
997
                        }
×
998

NEW
999
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
1000
                        if err != nil {
×
NEW
1001
                                return fmt.Errorf("unable to extract channel "+
×
NEW
1002
                                        "policies: %w", err)
×
NEW
1003
                        }
×
1004

NEW
1005
                        p1, p2, err := getAndBuildChanPolicies(
×
NEW
1006
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
NEW
1007
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
NEW
1008
                        )
×
NEW
1009
                        if err != nil {
×
NEW
1010
                                return fmt.Errorf("unable to build channel "+
×
NEW
1011
                                        "policies: %w", err)
×
NEW
1012
                        }
×
1013

NEW
1014
                        edgesSeen[chanIDInt] = struct{}{}
×
NEW
1015
                        chanEdge := ChannelEdge{
×
NEW
1016
                                Info:    channel,
×
NEW
1017
                                Policy1: p1,
×
NEW
1018
                                Policy2: p2,
×
NEW
1019
                                Node1:   node1,
×
NEW
1020
                                Node2:   node2,
×
NEW
1021
                        }
×
NEW
1022
                        edges = append(edges, chanEdge)
×
NEW
1023
                        edgesToCache[chanIDInt] = chanEdge
×
1024
                }
1025

NEW
1026
                return nil
×
NEW
1027
        }, func() {
×
NEW
1028
                edgesSeen = make(map[uint64]struct{})
×
NEW
1029
                edgesToCache = make(map[uint64]ChannelEdge)
×
NEW
1030
                edges = nil
×
NEW
1031
        })
×
NEW
1032
        if err != nil {
×
NEW
1033
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
NEW
1034
        }
×
1035

1036
        // Insert any edges loaded from disk into the cache.
NEW
1037
        for chanid, channel := range edgesToCache {
×
NEW
1038
                s.chanCache.insert(chanid, channel)
×
NEW
1039
        }
×
1040

NEW
1041
        if len(edges) > 0 {
×
NEW
1042
                log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
×
NEW
1043
                        float64(hits)/float64(len(edges)), hits, len(edges))
×
NEW
1044
        } else {
×
NEW
1045
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
NEW
1046
                        "horizon (%s, %s)", startTime, endTime)
×
NEW
1047
        }
×
1048

NEW
1049
        return edges, nil
×
1050
}
1051

1052
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1053
// data to the call-back.
1054
//
1055
// NOTE: The callback contents MUST not be modified.
1056
//
1057
// NOTE: part of the V1Store interface.
1058
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
NEW
1059
        chans map[uint64]*DirectedChannel) error) error {
×
NEW
1060

×
NEW
1061
        var ctx = context.TODO()
×
NEW
1062

×
NEW
1063
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1064
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
NEW
1065
                        nodePub route.Vertex) error {
×
NEW
1066

×
NEW
1067
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
NEW
1068
                        if err != nil {
×
NEW
1069
                                return fmt.Errorf("unable to fetch "+
×
NEW
1070
                                        "node(id=%d) features: %w", nodeID, err)
×
NEW
1071
                        }
×
1072

NEW
1073
                        toNodeCallback := func() route.Vertex {
×
NEW
1074
                                return nodePub
×
NEW
1075
                        }
×
1076

NEW
1077
                        rows, err := db.ListChannelsByNodeID(
×
NEW
1078
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
1079
                                        Version: int16(ProtocolV1),
×
NEW
1080
                                        NodeID1: nodeID,
×
NEW
1081
                                },
×
NEW
1082
                        )
×
NEW
1083
                        if err != nil {
×
NEW
1084
                                return fmt.Errorf("unable to fetch channels "+
×
NEW
1085
                                        "of node(id=%d): %w", nodeID, err)
×
NEW
1086
                        }
×
1087

NEW
1088
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
NEW
1089
                        for _, row := range rows {
×
NEW
1090
                                node1, node2, err := buildNodeVertices(
×
NEW
1091
                                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1092
                                )
×
NEW
1093
                                if err != nil {
×
NEW
1094
                                        return err
×
NEW
1095
                                }
×
1096

NEW
1097
                                e, err := getAndBuildEdgeInfo(
×
NEW
1098
                                        ctx, db, s.cfg.ChainHash,
×
NEW
1099
                                        row.Channel.ID, row.Channel, node1,
×
NEW
1100
                                        node2,
×
NEW
1101
                                )
×
NEW
1102
                                if err != nil {
×
NEW
1103
                                        return fmt.Errorf("unable to build "+
×
NEW
1104
                                                "channel info: %w", err)
×
NEW
1105
                                }
×
1106

NEW
1107
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
NEW
1108
                                        row,
×
NEW
1109
                                )
×
NEW
1110
                                if err != nil {
×
NEW
1111
                                        return fmt.Errorf("unable to "+
×
NEW
1112
                                                "extract channel "+
×
NEW
1113
                                                "policies: %w", err)
×
NEW
1114
                                }
×
1115

NEW
1116
                                p1, p2, err := getAndBuildChanPolicies(
×
NEW
1117
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
NEW
1118
                                        node1, node2,
×
NEW
1119
                                )
×
NEW
1120
                                if err != nil {
×
NEW
1121
                                        return fmt.Errorf("unable to "+
×
NEW
1122
                                                "build channel policies: %w",
×
NEW
1123
                                                err)
×
NEW
1124
                                }
×
1125

1126
                                // Determine the outgoing and incoming policy
1127
                                // for this channel and node combo.
NEW
1128
                                outPolicy, inPolicy := p1, p2
×
NEW
1129
                                if p1 != nil && p1.ToNode == nodePub {
×
NEW
1130
                                        outPolicy, inPolicy = p2, p1
×
NEW
1131
                                } else if p2 != nil && p2.ToNode != nodePub {
×
NEW
1132
                                        outPolicy, inPolicy = p2, p1
×
NEW
1133
                                }
×
1134

NEW
1135
                                var cachedInPolicy *models.CachedEdgePolicy
×
NEW
1136
                                if inPolicy != nil {
×
NEW
1137
                                        cachedInPolicy = models.NewCachedPolicy(
×
NEW
1138
                                                p2,
×
NEW
1139
                                        )
×
NEW
1140
                                        cachedInPolicy.ToNodePubKey =
×
NEW
1141
                                                toNodeCallback
×
NEW
1142
                                        cachedInPolicy.ToNodeFeatures =
×
NEW
1143
                                                features
×
NEW
1144
                                }
×
1145

NEW
1146
                                var inboundFee lnwire.Fee
×
NEW
1147
                                outPolicy.InboundFee.WhenSome(
×
NEW
1148
                                        func(fee lnwire.Fee) {
×
NEW
1149
                                                inboundFee = fee
×
NEW
1150
                                        },
×
1151
                                )
1152

NEW
1153
                                directedChannel := &DirectedChannel{
×
NEW
1154
                                        ChannelID: e.ChannelID,
×
NEW
1155
                                        IsNode1: nodePub ==
×
NEW
1156
                                                e.NodeKey1Bytes,
×
NEW
1157
                                        OtherNode:    e.NodeKey2Bytes,
×
NEW
1158
                                        Capacity:     e.Capacity,
×
NEW
1159
                                        OutPolicySet: p1 != nil,
×
NEW
1160
                                        InPolicy:     cachedInPolicy,
×
NEW
1161
                                        InboundFee:   inboundFee,
×
NEW
1162
                                }
×
NEW
1163

×
NEW
1164
                                if nodePub == e.NodeKey2Bytes {
×
NEW
1165
                                        directedChannel.OtherNode =
×
NEW
1166
                                                e.NodeKey1Bytes
×
NEW
1167
                                }
×
1168

NEW
1169
                                channels[e.ChannelID] = directedChannel
×
1170
                        }
1171

NEW
1172
                        return cb(nodePub, channels)
×
1173
                })
1174
        }, sqldb.NoOpReset)
1175
}
1176

1177
// ForEachChannel iterates through all the channel edges stored within the
1178
// graph and invokes the passed callback for each edge. The callback takes two
1179
// edges as since this is a directed graph, both the in/out edges are visited.
1180
// If the callback returns an error, then the transaction is aborted and the
1181
// iteration stops early.
1182
//
1183
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1184
// for that particular channel edge routing policy will be passed into the
1185
// callback.
1186
//
1187
// NOTE: part of the V1Store interface.
1188
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
NEW
1189
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
NEW
1190

×
NEW
1191
        ctx := context.TODO()
×
NEW
1192

×
NEW
1193
        handleChannel := func(db SQLQueries,
×
NEW
1194
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
NEW
1195

×
NEW
1196
                node1, node2, err := buildNodeVertices(
×
NEW
1197
                        row.Node1Pubkey, row.Node2Pubkey,
×
NEW
1198
                )
×
NEW
1199
                if err != nil {
×
NEW
1200
                        return fmt.Errorf("unable to build node vertices: %w",
×
NEW
1201
                                err)
×
NEW
1202
                }
×
1203

NEW
1204
                edge, err := getAndBuildEdgeInfo(
×
NEW
1205
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
NEW
1206
                        node1, node2,
×
NEW
1207
                )
×
NEW
1208
                if err != nil {
×
NEW
1209
                        return fmt.Errorf("unable to build channel info: %w",
×
NEW
1210
                                err)
×
NEW
1211
                }
×
1212

NEW
1213
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
NEW
1214
                if err != nil {
×
NEW
1215
                        return fmt.Errorf("unable to extract channel "+
×
NEW
1216
                                "policies: %w", err)
×
NEW
1217
                }
×
1218

NEW
1219
                p1, p2, err := getAndBuildChanPolicies(
×
NEW
1220
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
NEW
1221
                )
×
NEW
1222
                if err != nil {
×
NEW
1223
                        return fmt.Errorf("unable to build channel "+
×
NEW
1224
                                "policies: %w", err)
×
NEW
1225
                }
×
1226

NEW
1227
                err = cb(edge, p1, p2)
×
NEW
1228
                if err != nil {
×
NEW
1229
                        return fmt.Errorf("callback failed for channel "+
×
NEW
1230
                                "id=%d: %w", edge.ChannelID, err)
×
NEW
1231
                }
×
1232

NEW
1233
                return nil
×
1234
        }
1235

NEW
1236
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1237
                var lastID int64
×
NEW
1238
                for {
×
NEW
1239
                        //nolint:ll
×
NEW
1240
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
NEW
1241
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
NEW
1242
                                        Version: int16(ProtocolV1),
×
NEW
1243
                                        ID:      lastID,
×
NEW
1244
                                        Limit:   pageSize,
×
NEW
1245
                                },
×
NEW
1246
                        )
×
NEW
1247
                        if err != nil {
×
NEW
1248
                                return err
×
NEW
1249
                        }
×
1250

NEW
1251
                        if len(rows) == 0 {
×
NEW
1252
                                break
×
1253
                        }
1254

NEW
1255
                        for _, row := range rows {
×
NEW
1256
                                err := handleChannel(db, row)
×
NEW
1257
                                if err != nil {
×
NEW
1258
                                        return err
×
NEW
1259
                                }
×
1260

NEW
1261
                                lastID = row.Channel.ID
×
1262
                        }
1263
                }
1264

NEW
1265
                return nil
×
1266
        }, sqldb.NoOpReset)
1267
}
1268

1269
// FilterChannelRange returns the channel ID's of all known channels which were
1270
// mined in a block height within the passed range. The channel IDs are grouped
1271
// by their common block height. This method can be used to quickly share with a
1272
// peer the set of channels we know of within a particular range to catch them
1273
// up after a period of time offline. If withTimestamps is true then the
1274
// timestamp info of the latest received channel update messages of the channel
1275
// will be included in the response.
1276
//
1277
// NOTE: This is part of the V1Store interface.
1278
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
NEW
1279
        withTimestamps bool) ([]BlockChannelRange, error) {
×
NEW
1280

×
NEW
1281
        var (
×
NEW
1282
                ctx       = context.TODO()
×
NEW
1283
                startSCID = &lnwire.ShortChannelID{
×
NEW
1284
                        BlockHeight: startHeight,
×
NEW
1285
                }
×
NEW
1286
                endSCID = lnwire.ShortChannelID{
×
NEW
1287
                        BlockHeight: endHeight,
×
NEW
1288
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
NEW
1289
                        TxPosition:  math.MaxUint16,
×
NEW
1290
                }
×
NEW
1291
        )
×
NEW
1292

×
NEW
1293
        var chanIDStart [8]byte
×
NEW
1294
        byteOrder.PutUint64(chanIDStart[:], startSCID.ToUint64())
×
NEW
1295
        var chanIDEnd [8]byte
×
NEW
1296
        byteOrder.PutUint64(chanIDEnd[:], endSCID.ToUint64())
×
NEW
1297

×
NEW
1298
        // 1) get all channels where channelID is between start and end chan ID.
×
NEW
1299
        // 2) skip if not public (ie, no channel_proof)
×
NEW
1300
        // 3) collect that channel.
×
NEW
1301
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
NEW
1302
        //    and add those timestamps to the collected channel.
×
NEW
1303
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
NEW
1304
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1305
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
NEW
1306
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
NEW
1307
                                StartScid: chanIDStart[:],
×
NEW
1308
                                EndScid:   chanIDEnd[:],
×
NEW
1309
                        },
×
NEW
1310
                )
×
NEW
1311
                if err != nil {
×
NEW
1312
                        return fmt.Errorf("unable to fetch channel range: %w",
×
NEW
1313
                                err)
×
NEW
1314
                }
×
1315

NEW
1316
                for _, dbChan := range dbChans {
×
NEW
1317
                        cid := lnwire.NewShortChanIDFromInt(
×
NEW
1318
                                byteOrder.Uint64(dbChan.Scid),
×
NEW
1319
                        )
×
NEW
1320
                        chanInfo := NewChannelUpdateInfo(
×
NEW
1321
                                cid, time.Time{}, time.Time{},
×
NEW
1322
                        )
×
NEW
1323

×
NEW
1324
                        if !withTimestamps {
×
NEW
1325
                                channelsPerBlock[cid.BlockHeight] = append(
×
NEW
1326
                                        channelsPerBlock[cid.BlockHeight],
×
NEW
1327
                                        chanInfo,
×
NEW
1328
                                )
×
NEW
1329

×
NEW
1330
                                continue
×
1331
                        }
1332

1333
                        //nolint:ll
NEW
1334
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1335
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1336
                                        Version:   int16(ProtocolV1),
×
NEW
1337
                                        ChannelID: dbChan.ID,
×
NEW
1338
                                        NodeID:    dbChan.NodeID1,
×
NEW
1339
                                },
×
NEW
1340
                        )
×
NEW
1341
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1342
                                return fmt.Errorf("unable to fetch node1 "+
×
NEW
1343
                                        "policy: %w", err)
×
NEW
1344
                        } else if err == nil {
×
NEW
1345
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
NEW
1346
                                        node1Policy.LastUpdate.Int64, 0,
×
NEW
1347
                                )
×
NEW
1348
                        }
×
1349

1350
                        //nolint:ll
NEW
1351
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
1352
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1353
                                        Version:   int16(ProtocolV1),
×
NEW
1354
                                        ChannelID: dbChan.ID,
×
NEW
1355
                                        NodeID:    dbChan.NodeID2,
×
NEW
1356
                                },
×
NEW
1357
                        )
×
NEW
1358
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
1359
                                return fmt.Errorf("unable to fetch node2 "+
×
NEW
1360
                                        "policy: %w", err)
×
NEW
1361
                        } else if err == nil {
×
NEW
1362
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
NEW
1363
                                        node2Policy.LastUpdate.Int64, 0,
×
NEW
1364
                                )
×
NEW
1365
                        }
×
1366

NEW
1367
                        channelsPerBlock[cid.BlockHeight] = append(
×
NEW
1368
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
NEW
1369
                        )
×
1370
                }
1371

NEW
1372
                return nil
×
NEW
1373
        }, func() {
×
NEW
1374
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
NEW
1375
        })
×
NEW
1376
        if err != nil {
×
NEW
1377
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
NEW
1378
        }
×
1379

NEW
1380
        if len(channelsPerBlock) == 0 {
×
NEW
1381
                return nil, nil
×
NEW
1382
        }
×
1383

1384
        // Return the channel ranges in ascending block height order.
NEW
1385
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
NEW
1386
        slices.Sort(blocks)
×
NEW
1387

×
NEW
1388
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
NEW
1389
                return BlockChannelRange{
×
NEW
1390
                        Height:   block,
×
NEW
1391
                        Channels: channelsPerBlock[block],
×
NEW
1392
                }
×
NEW
1393
        }), nil
×
1394
}
1395

1396
// forEachNodeDirectedChannel iterates through all channels of a given
1397
// node, executing the passed callback on the directed edge representing the
1398
// channel and its incoming policy. If the node is not found, no error is
1399
// returned.
1400
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
1401
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
1402

×
1403
        toNodeCallback := func() route.Vertex {
×
1404
                return nodePub
×
1405
        }
×
1406

1407
        dbID, err := db.GetNodeIDByPubKey(
×
1408
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
1409
                        Version: int16(ProtocolV1),
×
1410
                        PubKey:  nodePub[:],
×
1411
                },
×
1412
        )
×
1413
        if errors.Is(err, sql.ErrNoRows) {
×
1414
                return nil
×
1415
        } else if err != nil {
×
1416
                return fmt.Errorf("unable to fetch node: %w", err)
×
1417
        }
×
1418

1419
        rows, err := db.ListChannelsByNodeID(
×
1420
                ctx, sqlc.ListChannelsByNodeIDParams{
×
1421
                        Version: int16(ProtocolV1),
×
1422
                        NodeID1: dbID,
×
1423
                },
×
1424
        )
×
1425
        if err != nil {
×
1426
                return fmt.Errorf("unable to fetch channels: %w", err)
×
1427
        }
×
1428

1429
        // Exit early if there are no channels for this node so we don't
1430
        // do the unnecessary feature fetching.
1431
        if len(rows) == 0 {
×
1432
                return nil
×
1433
        }
×
1434

1435
        features, err := getNodeFeatures(ctx, db, dbID)
×
1436
        if err != nil {
×
1437
                return fmt.Errorf("unable to fetch node features: %w", err)
×
1438
        }
×
1439

1440
        for _, row := range rows {
×
1441
                node1, node2, err := buildNodeVertices(
×
1442
                        row.Node1Pubkey, row.Node2Pubkey,
×
1443
                )
×
1444
                if err != nil {
×
1445
                        return fmt.Errorf("unable to build node vertices: %w",
×
1446
                                err)
×
1447
                }
×
1448

NEW
1449
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
1450

×
1451
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1452
                if err != nil {
×
1453
                        return err
×
1454
                }
×
1455

1456
                var p1, p2 *models.CachedEdgePolicy
×
1457
                if dbPol1 != nil {
×
1458
                        policy1, err := buildChanPolicy(
×
1459
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
1460
                        )
×
1461
                        if err != nil {
×
1462
                                return err
×
1463
                        }
×
1464

1465
                        p1 = models.NewCachedPolicy(policy1)
×
1466
                }
1467
                if dbPol2 != nil {
×
1468
                        policy2, err := buildChanPolicy(
×
1469
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
1470
                        )
×
1471
                        if err != nil {
×
1472
                                return err
×
1473
                        }
×
1474

1475
                        p2 = models.NewCachedPolicy(policy2)
×
1476
                }
1477

1478
                // Determine the outgoing and incoming policy for this
1479
                // channel and node combo.
1480
                outPolicy, inPolicy := p1, p2
×
1481
                if p1 != nil && node2 == nodePub {
×
1482
                        outPolicy, inPolicy = p2, p1
×
1483
                } else if p2 != nil && node1 != nodePub {
×
1484
                        outPolicy, inPolicy = p2, p1
×
1485
                }
×
1486

1487
                var cachedInPolicy *models.CachedEdgePolicy
×
1488
                if inPolicy != nil {
×
1489
                        cachedInPolicy = inPolicy
×
1490
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
1491
                        cachedInPolicy.ToNodeFeatures = features
×
1492
                }
×
1493

1494
                directedChannel := &DirectedChannel{
×
1495
                        ChannelID:    edge.ChannelID,
×
1496
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
1497
                        OtherNode:    edge.NodeKey2Bytes,
×
1498
                        Capacity:     edge.Capacity,
×
1499
                        OutPolicySet: outPolicy != nil,
×
1500
                        InPolicy:     cachedInPolicy,
×
1501
                }
×
1502
                if outPolicy != nil {
×
1503
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
1504
                                directedChannel.InboundFee = fee
×
1505
                        })
×
1506
                }
1507

1508
                if nodePub == edge.NodeKey2Bytes {
×
1509
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
1510
                }
×
1511

1512
                if err := cb(directedChannel); err != nil {
×
1513
                        return err
×
1514
                }
×
1515
        }
1516

1517
        return nil
×
1518
}
1519

1520
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
1521
// and executes the provided callback for each node.
1522
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
1523
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
1524

×
1525
        var lastID int64
×
1526

×
1527
        for {
×
1528
                nodes, err := db.ListNodeIDsAndPubKeys(
×
1529
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1530
                                Version: int16(ProtocolV1),
×
1531
                                ID:      lastID,
×
1532
                                Limit:   pageSize,
×
1533
                        },
×
1534
                )
×
1535
                if err != nil {
×
1536
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
1537
                }
×
1538

1539
                if len(nodes) == 0 {
×
1540
                        break
×
1541
                }
1542

1543
                for _, node := range nodes {
×
1544
                        var pub route.Vertex
×
1545
                        copy(pub[:], node.PubKey)
×
1546

×
1547
                        if err := cb(node.ID, pub); err != nil {
×
1548
                                return fmt.Errorf("forEachNodeCacheable "+
×
1549
                                        "callback failed for node(id=%d): %w",
×
1550
                                        node.ID, err)
×
1551
                        }
×
1552

1553
                        lastID = node.ID
×
1554
                }
1555
        }
1556

1557
        return nil
×
1558
}
1559

1560
// forEachNodeChannel iterates through all channels of a node, executing
1561
// the passed callback on each. The call-back is provided with the channel's
1562
// edge information, the outgoing policy and the incoming policy for the
1563
// channel and node combo.
1564
func forEachNodeChannel(ctx context.Context, db SQLQueries,
1565
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
1566
                *models.ChannelEdgePolicy,
1567
                *models.ChannelEdgePolicy) error) error {
×
1568

×
1569
        // Get all the V1 channels for this node.Add commentMore actions
×
1570
        rows, err := db.ListChannelsByNodeID(
×
1571
                ctx, sqlc.ListChannelsByNodeIDParams{
×
1572
                        Version: int16(ProtocolV1),
×
1573
                        NodeID1: id,
×
1574
                },
×
1575
        )
×
1576
        if err != nil {
×
1577
                return fmt.Errorf("unable to fetch channels: %w", err)
×
1578
        }
×
1579

1580
        // Call the call-back for each channel and its known policies.
1581
        for _, row := range rows {
×
1582
                node1, node2, err := buildNodeVertices(
×
1583
                        row.Node1Pubkey, row.Node2Pubkey,
×
1584
                )
×
1585
                if err != nil {
×
1586
                        return fmt.Errorf("unable to build node vertices: %w",
×
1587
                                err)
×
1588
                }
×
1589

1590
                edge, err := getAndBuildEdgeInfo(
×
1591
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
1592
                        node2,
×
1593
                )
×
1594
                if err != nil {
×
1595
                        return fmt.Errorf("unable to build channel info: %w",
×
1596
                                err)
×
1597
                }
×
1598

1599
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1600
                if err != nil {
×
1601
                        return fmt.Errorf("unable to extract channel "+
×
1602
                                "policies: %w", err)
×
1603
                }
×
1604

1605
                p1, p2, err := getAndBuildChanPolicies(
×
1606
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1607
                )
×
1608
                if err != nil {
×
1609
                        return fmt.Errorf("unable to build channel "+
×
1610
                                "policies: %w", err)
×
1611
                }
×
1612

1613
                // Determine the outgoing and incoming policy for this
1614
                // channel and node combo.
1615
                p1ToNode := row.Channel.NodeID2
×
1616
                p2ToNode := row.Channel.NodeID1
×
1617
                outPolicy, inPolicy := p1, p2
×
1618
                if (p1 != nil && p1ToNode == id) ||
×
1619
                        (p2 != nil && p2ToNode != id) {
×
1620

×
1621
                        outPolicy, inPolicy = p2, p1
×
1622
                }
×
1623

1624
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
1625
                        return err
×
1626
                }
×
1627
        }
1628

1629
        return nil
×
1630
}
1631

1632
// updateChanEdgePolicy upserts the channel policy info we have stored for
1633
// a channel we already know of.
1634
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
1635
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
1636
        error) {
×
1637

×
1638
        var (
×
1639
                node1Pub, node2Pub route.Vertex
×
1640
                isNode1            bool
×
1641
                chanIDB            [8]byte
×
1642
        )
×
1643
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
1644

×
1645
        // Check that this edge policy refers to a channel that we already
×
1646
        // know of. We do this explicitly so that we can return the appropriate
×
1647
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
1648
        // abort the transaction which would abort the entire batch.
×
1649
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
1650
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
1651
                        Scid:    chanIDB[:],
×
1652
                        Version: int16(ProtocolV1),
×
1653
                },
×
1654
        )
×
1655
        if errors.Is(err, sql.ErrNoRows) {
×
1656
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
1657
        } else if err != nil {
×
1658
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
1659
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
1660
        }
×
1661

1662
        copy(node1Pub[:], dbChan.Node1PubKey)
×
1663
        copy(node2Pub[:], dbChan.Node2PubKey)
×
1664

×
1665
        // Figure out which node this edge is from.
×
1666
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
1667
        nodeID := dbChan.NodeID1
×
1668
        if !isNode1 {
×
1669
                nodeID = dbChan.NodeID2
×
1670
        }
×
1671

1672
        var (
×
1673
                inboundBase sql.NullInt64
×
1674
                inboundRate sql.NullInt64
×
1675
        )
×
1676
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
1677
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
1678
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
1679
        })
×
1680

1681
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
1682
                Version:     int16(ProtocolV1),
×
1683
                ChannelID:   dbChan.ID,
×
1684
                NodeID:      nodeID,
×
1685
                Timelock:    int32(edge.TimeLockDelta),
×
1686
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
1687
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
1688
                MinHtlcMsat: int64(edge.MinHTLC),
×
1689
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
1690
                Disabled: sql.NullBool{
×
1691
                        Valid: true,
×
1692
                        Bool:  edge.IsDisabled(),
×
1693
                },
×
1694
                MaxHtlcMsat: sql.NullInt64{
×
1695
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
1696
                        Int64: int64(edge.MaxHTLC),
×
1697
                },
×
1698
                InboundBaseFeeMsat:      inboundBase,
×
1699
                InboundFeeRateMilliMsat: inboundRate,
×
1700
                Signature:               edge.SigBytes,
×
1701
        })
×
1702
        if err != nil {
×
1703
                return node1Pub, node2Pub, isNode1,
×
1704
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
1705
        }
×
1706

1707
        // Convert the flat extra opaque data into a map of TLV types to
1708
        // values.
1709
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
1710
        if err != nil {
×
1711
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
1712
                        "marshal extra opaque data: %w", err)
×
1713
        }
×
1714

1715
        // Update the channel policy's extra signed fields.
1716
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
1717
        if err != nil {
×
1718
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
1719
                        "policy extra TLVs: %w", err)
×
1720
        }
×
1721

1722
        return node1Pub, node2Pub, isNode1, nil
×
1723
}
1724

1725
// getNodeByPubKey attempts to look up a target node by its public key.
1726
func getNodeByPubKey(ctx context.Context, db SQLQueries,
1727
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
1728

×
1729
        dbNode, err := db.GetNodeByPubKey(
×
1730
                ctx, sqlc.GetNodeByPubKeyParams{
×
1731
                        Version: int16(ProtocolV1),
×
1732
                        PubKey:  pubKey[:],
×
1733
                },
×
1734
        )
×
1735
        if errors.Is(err, sql.ErrNoRows) {
×
1736
                return 0, nil, ErrGraphNodeNotFound
×
1737
        } else if err != nil {
×
1738
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
1739
        }
×
1740

1741
        node, err := buildNode(ctx, db, &dbNode)
×
1742
        if err != nil {
×
1743
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
1744
        }
×
1745

1746
        return dbNode.ID, node, nil
×
1747
}
1748

1749
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
1750
// provided database channel row and the public keys of the two nodes
1751
// involved in the channel.
1752
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
NEW
1753
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
1754

×
1755
        return &models.CachedEdgeInfo{
×
1756
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
1757
                NodeKey1Bytes: node1Pub,
×
1758
                NodeKey2Bytes: node2Pub,
×
1759
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
NEW
1760
        }
×
1761
}
×
1762

1763
// buildNode constructs a LightningNode instance from the given database node
1764
// record. The node's features, addresses and extra signed fields are also
1765
// fetched from the database and set on the node.
1766
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
1767
        *models.LightningNode, error) {
×
1768

×
1769
        if dbNode.Version != int16(ProtocolV1) {
×
1770
                return nil, fmt.Errorf("unsupported node version: %d",
×
1771
                        dbNode.Version)
×
1772
        }
×
1773

1774
        var pub [33]byte
×
1775
        copy(pub[:], dbNode.PubKey)
×
1776

×
1777
        node := &models.LightningNode{
×
1778
                PubKeyBytes: pub,
×
1779
                Features:    lnwire.EmptyFeatureVector(),
×
1780
                LastUpdate:  time.Unix(0, 0),
×
1781
        }
×
1782

×
1783
        if len(dbNode.Signature) == 0 {
×
1784
                return node, nil
×
1785
        }
×
1786

1787
        node.HaveNodeAnnouncement = true
×
1788
        node.AuthSigBytes = dbNode.Signature
×
1789
        node.Alias = dbNode.Alias.String
×
1790
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
1791

×
1792
        var err error
×
1793
        if dbNode.Color.Valid {
×
1794
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
1795
                if err != nil {
×
1796
                        return nil, fmt.Errorf("unable to decode color: %w",
×
1797
                                err)
×
1798
                }
×
1799
        }
1800

1801
        // Fetch the node's features.
1802
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
1803
        if err != nil {
×
1804
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
1805
                        "features: %w", dbNode.ID, err)
×
1806
        }
×
1807

1808
        // Fetch the node's addresses.
1809
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
1810
        if err != nil {
×
1811
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
1812
                        "addresses: %w", dbNode.ID, err)
×
1813
        }
×
1814

1815
        // Fetch the node's extra signed fields.
1816
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
1817
        if err != nil {
×
1818
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
1819
                        "extra signed fields: %w", dbNode.ID, err)
×
1820
        }
×
1821

1822
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
1823
        if err != nil {
×
1824
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
1825
                        "fields: %w", err)
×
1826
        }
×
1827

1828
        if len(recs) != 0 {
×
1829
                node.ExtraOpaqueData = recs
×
1830
        }
×
1831

1832
        return node, nil
×
1833
}
1834

1835
// getNodeFeatures fetches the feature bits and constructs the feature vector
1836
// for a node with the given DB ID.
1837
func getNodeFeatures(ctx context.Context, db SQLQueries,
1838
        nodeID int64) (*lnwire.FeatureVector, error) {
×
1839

×
1840
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
1841
        if err != nil {
×
1842
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
1843
                        nodeID, err)
×
1844
        }
×
1845

1846
        features := lnwire.EmptyFeatureVector()
×
1847
        for _, feature := range rows {
×
1848
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
1849
        }
×
1850

1851
        return features, nil
×
1852
}
1853

1854
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
1855
// given DB ID.
1856
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
1857
        nodeID int64) (map[uint64][]byte, error) {
×
1858

×
1859
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
1860
        if err != nil {
×
1861
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
1862
                        "signed fields: %w", nodeID, err)
×
1863
        }
×
1864

1865
        extraFields := make(map[uint64][]byte)
×
1866
        for _, field := range fields {
×
1867
                extraFields[uint64(field.Type)] = field.Value
×
1868
        }
×
1869

1870
        return extraFields, nil
×
1871
}
1872

1873
// upsertNode upserts the node record into the database. If the node already
1874
// exists, then the node's information is updated. If the node doesn't exist,
1875
// then a new node is created. The node's features, addresses and extra TLV
1876
// types are also updated. The node's DB ID is returned.
1877
func upsertNode(ctx context.Context, db SQLQueries,
1878
        node *models.LightningNode) (int64, error) {
×
1879

×
1880
        params := sqlc.UpsertNodeParams{
×
1881
                Version: int16(ProtocolV1),
×
1882
                PubKey:  node.PubKeyBytes[:],
×
1883
        }
×
1884

×
1885
        if node.HaveNodeAnnouncement {
×
1886
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
1887
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
1888
                params.Alias = sqldb.SQLStr(node.Alias)
×
1889
                params.Signature = node.AuthSigBytes
×
1890
        }
×
1891

1892
        nodeID, err := db.UpsertNode(ctx, params)
×
1893
        if err != nil {
×
1894
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
1895
                        err)
×
1896
        }
×
1897

1898
        // We can exit here if we don't have the announcement yet.
1899
        if !node.HaveNodeAnnouncement {
×
1900
                return nodeID, nil
×
1901
        }
×
1902

1903
        // Update the node's features.
1904
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
1905
        if err != nil {
×
1906
                return 0, fmt.Errorf("inserting node features: %w", err)
×
1907
        }
×
1908

1909
        // Update the node's addresses.
1910
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
1911
        if err != nil {
×
1912
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
1913
        }
×
1914

1915
        // Convert the flat extra opaque data into a map of TLV types to
1916
        // values.
1917
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
1918
        if err != nil {
×
1919
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
1920
                        err)
×
1921
        }
×
1922

1923
        // Update the node's extra signed fields.
1924
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
1925
        if err != nil {
×
1926
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
1927
        }
×
1928

1929
        return nodeID, nil
×
1930
}
1931

1932
// upsertNodeFeatures updates the node's features node_features table. This
1933
// includes deleting any feature bits no longer present and inserting any new
1934
// feature bits. If the feature bit does not yet exist in the features table,
1935
// then an entry is created in that table first.
1936
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
1937
        features *lnwire.FeatureVector) error {
×
1938

×
1939
        // Get any existing features for the node.
×
1940
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
1941
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1942
                return err
×
1943
        }
×
1944

1945
        // Copy the nodes latest set of feature bits.
1946
        newFeatures := make(map[int32]struct{})
×
1947
        if features != nil {
×
1948
                for feature := range features.Features() {
×
1949
                        newFeatures[int32(feature)] = struct{}{}
×
1950
                }
×
1951
        }
1952

1953
        // For any current feature that already exists in the DB, remove it from
1954
        // the in-memory map. For any existing feature that does not exist in
1955
        // the in-memory map, delete it from the database.
1956
        for _, feature := range existingFeatures {
×
1957
                // The feature is still present, so there are no updates to be
×
1958
                // made.
×
1959
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
1960
                        delete(newFeatures, feature.FeatureBit)
×
1961
                        continue
×
1962
                }
1963

1964
                // The feature is no longer present, so we remove it from the
1965
                // database.
1966
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
1967
                        NodeID:     nodeID,
×
1968
                        FeatureBit: feature.FeatureBit,
×
1969
                })
×
1970
                if err != nil {
×
1971
                        return fmt.Errorf("unable to delete node(%d) "+
×
1972
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
1973
                                err)
×
1974
                }
×
1975
        }
1976

1977
        // Any remaining entries in newFeatures are new features that need to be
1978
        // added to the database for the first time.
1979
        for feature := range newFeatures {
×
1980
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
1981
                        NodeID:     nodeID,
×
1982
                        FeatureBit: feature,
×
1983
                })
×
1984
                if err != nil {
×
1985
                        return fmt.Errorf("unable to insert node(%d) "+
×
1986
                                "feature(%v): %w", nodeID, feature, err)
×
1987
                }
×
1988
        }
1989

1990
        return nil
×
1991
}
1992

1993
// fetchNodeFeatures fetches the features for a node with the given public key.
1994
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
1995
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
1996

×
1997
        rows, err := queries.GetNodeFeaturesByPubKey(
×
1998
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
1999
                        PubKey:  nodePub[:],
×
2000
                        Version: int16(ProtocolV1),
×
2001
                },
×
2002
        )
×
2003
        if err != nil {
×
2004
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
2005
                        nodePub, err)
×
2006
        }
×
2007

2008
        features := lnwire.EmptyFeatureVector()
×
2009
        for _, bit := range rows {
×
2010
                features.Set(lnwire.FeatureBit(bit))
×
2011
        }
×
2012

2013
        return features, nil
×
2014
}
2015

2016
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialized/deserialized.
//
// These values are persisted in the database and decoded again on read, so
// existing values must never be changed.
type dbAddressType uint8

const (
	// addressTypeIPv4 is a TCP address with an IPv4 IP.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 is a TCP address with an IPv6 IP.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 is a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 is a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque is used for lnwire.OpaqueAddrs addresses. Its
	// value is math.MaxInt8 (127).
	addressTypeOpaque dbAddressType = math.MaxInt8
)
2028

2029
// upsertNodeAddresses updates the node's addresses in the database. This
2030
// includes deleting any existing addresses and inserting the new set of
2031
// addresses. The deletion is necessary since the ordering of the addresses may
2032
// change, and we need to ensure that the database reflects the latest set of
2033
// addresses so that at the time of reconstructing the node announcement, the
2034
// order is preserved and the signature over the message remains valid.
2035
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
2036
        addresses []net.Addr) error {
×
2037

×
2038
        // Delete any existing addresses for the node. This is required since
×
2039
        // even if the new set of addresses is the same, the ordering may have
×
2040
        // changed for a given address type.
×
2041
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
2042
        if err != nil {
×
2043
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
2044
                        nodeID, err)
×
2045
        }
×
2046

2047
        // Copy the nodes latest set of addresses.
2048
        newAddresses := map[dbAddressType][]string{
×
2049
                addressTypeIPv4:   {},
×
2050
                addressTypeIPv6:   {},
×
2051
                addressTypeTorV2:  {},
×
2052
                addressTypeTorV3:  {},
×
2053
                addressTypeOpaque: {},
×
2054
        }
×
2055
        addAddr := func(t dbAddressType, addr net.Addr) {
×
2056
                newAddresses[t] = append(newAddresses[t], addr.String())
×
2057
        }
×
2058

2059
        for _, address := range addresses {
×
2060
                switch addr := address.(type) {
×
2061
                case *net.TCPAddr:
×
2062
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
2063
                                addAddr(addressTypeIPv4, addr)
×
2064
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
2065
                                addAddr(addressTypeIPv6, addr)
×
2066
                        } else {
×
2067
                                return fmt.Errorf("unhandled IP address: %v",
×
2068
                                        addr)
×
2069
                        }
×
2070

2071
                case *tor.OnionAddr:
×
2072
                        switch len(addr.OnionService) {
×
2073
                        case tor.V2Len:
×
2074
                                addAddr(addressTypeTorV2, addr)
×
2075
                        case tor.V3Len:
×
2076
                                addAddr(addressTypeTorV3, addr)
×
2077
                        default:
×
2078
                                return fmt.Errorf("invalid length for a tor " +
×
2079
                                        "address")
×
2080
                        }
2081

2082
                case *lnwire.OpaqueAddrs:
×
2083
                        addAddr(addressTypeOpaque, addr)
×
2084

2085
                default:
×
2086
                        return fmt.Errorf("unhandled address type: %T", addr)
×
2087
                }
2088
        }
2089

2090
        // Any remaining entries in newAddresses are new addresses that need to
2091
        // be added to the database for the first time.
2092
        for addrType, addrList := range newAddresses {
×
2093
                for position, addr := range addrList {
×
2094
                        err := db.InsertNodeAddress(
×
2095
                                ctx, sqlc.InsertNodeAddressParams{
×
2096
                                        NodeID:   nodeID,
×
2097
                                        Type:     int16(addrType),
×
2098
                                        Address:  addr,
×
2099
                                        Position: int32(position),
×
2100
                                },
×
2101
                        )
×
2102
                        if err != nil {
×
2103
                                return fmt.Errorf("unable to insert "+
×
2104
                                        "node(%d) address(%v): %w", nodeID,
×
2105
                                        addr, err)
×
2106
                        }
×
2107
                }
2108
        }
2109

2110
        return nil
×
2111
}
2112

2113
// getNodeAddresses fetches the addresses for a node with the given public key.
2114
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
2115
        []net.Addr, error) {
×
2116

×
2117
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
2118
        // are returned in the same order as they were inserted.
×
2119
        rows, err := db.GetNodeAddressesByPubKey(
×
2120
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
×
2121
                        Version: int16(ProtocolV1),
×
2122
                        PubKey:  nodePub,
×
2123
                },
×
2124
        )
×
2125
        if err != nil {
×
2126
                return false, nil, err
×
2127
        }
×
2128

2129
        // GetNodeAddressesByPubKey uses a left join so there should always be
2130
        // at least one row returned if the node exists even if it has no
2131
        // addresses.
2132
        if len(rows) == 0 {
×
2133
                return false, nil, nil
×
2134
        }
×
2135

2136
        addresses := make([]net.Addr, 0, len(rows))
×
2137
        for _, addr := range rows {
×
2138
                if !(addr.Type.Valid && addr.Address.Valid) {
×
2139
                        continue
×
2140
                }
2141

2142
                address := addr.Address.String
×
2143

×
2144
                switch dbAddressType(addr.Type.Int16) {
×
2145
                case addressTypeIPv4:
×
2146
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
2147
                        if err != nil {
×
2148
                                return false, nil, nil
×
2149
                        }
×
2150
                        tcp.IP = tcp.IP.To4()
×
2151

×
2152
                        addresses = append(addresses, tcp)
×
2153

2154
                case addressTypeIPv6:
×
2155
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
2156
                        if err != nil {
×
2157
                                return false, nil, nil
×
2158
                        }
×
2159
                        addresses = append(addresses, tcp)
×
2160

2161
                case addressTypeTorV3, addressTypeTorV2:
×
2162
                        service, portStr, err := net.SplitHostPort(address)
×
2163
                        if err != nil {
×
2164
                                return false, nil, fmt.Errorf("unable to "+
×
2165
                                        "split tor v3 address: %v",
×
2166
                                        addr.Address)
×
2167
                        }
×
2168

2169
                        port, err := strconv.Atoi(portStr)
×
2170
                        if err != nil {
×
2171
                                return false, nil, err
×
2172
                        }
×
2173

2174
                        addresses = append(addresses, &tor.OnionAddr{
×
2175
                                OnionService: service,
×
2176
                                Port:         port,
×
2177
                        })
×
2178

2179
                case addressTypeOpaque:
×
2180
                        opaque, err := hex.DecodeString(address)
×
2181
                        if err != nil {
×
2182
                                return false, nil, fmt.Errorf("unable to "+
×
2183
                                        "decode opaque address: %v", addr)
×
2184
                        }
×
2185

2186
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
2187
                                Payload: opaque,
×
2188
                        })
×
2189

2190
                default:
×
2191
                        return false, nil, fmt.Errorf("unknown address "+
×
2192
                                "type: %v", addr.Type)
×
2193
                }
2194
        }
2195

2196
        return true, addresses, nil
×
2197
}
2198

2199
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
2200
// database. This includes updating any existing types, inserting any new types,
2201
// and deleting any types that are no longer present.
2202
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
2203
        nodeID int64, extraFields map[uint64][]byte) error {
×
2204

×
2205
        // Get any existing extra signed fields for the node.
×
2206
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
2207
        if err != nil {
×
2208
                return err
×
2209
        }
×
2210

2211
        // Make a lookup map of the existing field types so that we can use it
2212
        // to keep track of any fields we should delete.
2213
        m := make(map[uint64]bool)
×
2214
        for _, field := range existingFields {
×
2215
                m[uint64(field.Type)] = true
×
2216
        }
×
2217

2218
        // For all the new fields, we'll upsert them and remove them from the
2219
        // map of existing fields.
2220
        for tlvType, value := range extraFields {
×
2221
                err = db.UpsertNodeExtraType(
×
2222
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
2223
                                NodeID: nodeID,
×
2224
                                Type:   int64(tlvType),
×
2225
                                Value:  value,
×
2226
                        },
×
2227
                )
×
2228
                if err != nil {
×
2229
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
2230
                                "signed field(%v): %w", nodeID, tlvType, err)
×
2231
                }
×
2232

2233
                // Remove the field from the map of existing fields if it was
2234
                // present.
2235
                delete(m, tlvType)
×
2236
        }
2237

2238
        // For all the fields that are left in the map of existing fields, we'll
2239
        // delete them as they are no longer present in the new set of fields.
2240
        for tlvType := range m {
×
2241
                err = db.DeleteExtraNodeType(
×
2242
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
2243
                                NodeID: nodeID,
×
2244
                                Type:   int64(tlvType),
×
2245
                        },
×
2246
                )
×
2247
                if err != nil {
×
2248
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
2249
                                "signed field(%v): %w", nodeID, tlvType, err)
×
2250
                }
×
2251
        }
2252

2253
        return nil
×
2254
}
2255

2256
// srcNodeInfo holds the information about the source node of the graph.
2257
type srcNodeInfo struct {
2258
        // id is the DB level ID of the source node entry in the "nodes" table.
2259
        id int64
2260

2261
        // pub is the public key of the source node.
2262
        pub route.Vertex
2263
}
2264

2265
// getSourceNode returns the DB node ID and pub key of the source node for the
2266
// specified protocol version.
2267
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
2268
        version ProtocolVersion) (int64, route.Vertex, error) {
×
2269

×
2270
        s.srcNodeMu.Lock()
×
2271
        defer s.srcNodeMu.Unlock()
×
2272

×
2273
        // If we already have the source node ID and pub key cached, then
×
2274
        // return them.
×
2275
        if info, ok := s.srcNodes[version]; ok {
×
2276
                return info.id, info.pub, nil
×
2277
        }
×
2278

2279
        var pubKey route.Vertex
×
2280

×
2281
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
2282
        if err != nil {
×
2283
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
2284
                        err)
×
2285
        }
×
2286

2287
        if len(nodes) == 0 {
×
2288
                return 0, pubKey, ErrSourceNodeNotSet
×
2289
        } else if len(nodes) > 1 {
×
2290
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
2291
                        "protocol %s found", version)
×
2292
        }
×
2293

2294
        copy(pubKey[:], nodes[0].PubKey)
×
2295

×
2296
        s.srcNodes[version] = &srcNodeInfo{
×
2297
                id:  nodes[0].NodeID,
×
2298
                pub: pubKey,
×
2299
        }
×
2300

×
2301
        return nodes[0].NodeID, pubKey, nil
×
2302
}
2303

2304
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
2305
// This then produces a map from TLV type to value. If the input is not a
2306
// valid TLV stream, then an error is returned.
2307
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
2308
        r := bytes.NewReader(data)
×
2309

×
2310
        tlvStream, err := tlv.NewStream()
×
2311
        if err != nil {
×
2312
                return nil, err
×
2313
        }
×
2314

2315
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
2316
        // pass it into the P2P decoding variant.
2317
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
2318
        if err != nil {
×
2319
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
2320
        }
×
2321
        if len(parsedTypes) == 0 {
×
2322
                return nil, nil
×
2323
        }
×
2324

2325
        records := make(map[uint64][]byte)
×
2326
        for k, v := range parsedTypes {
×
2327
                records[uint64(k)] = v
×
2328
        }
×
2329

2330
        return records, nil
×
2331
}
2332

2333
// insertChannel inserts a new channel record into the database.
2334
func insertChannel(ctx context.Context, db SQLQueries,
2335
        edge *models.ChannelEdgeInfo) error {
×
2336

×
2337
        var chanIDB [8]byte
×
2338
        byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
×
2339

×
2340
        // Make sure that the channel doesn't already exist. We do this
×
2341
        // explicitly instead of relying on catching a unique constraint error
×
2342
        // because relying on SQL to throw that error would abort the entire
×
2343
        // batch of transactions.
×
2344
        _, err := db.GetChannelBySCID(
×
2345
                ctx, sqlc.GetChannelBySCIDParams{
×
2346
                        Scid:    chanIDB[:],
×
2347
                        Version: int16(ProtocolV1),
×
2348
                },
×
2349
        )
×
2350
        if err == nil {
×
2351
                return ErrEdgeAlreadyExist
×
2352
        } else if !errors.Is(err, sql.ErrNoRows) {
×
2353
                return fmt.Errorf("unable to fetch channel: %w", err)
×
2354
        }
×
2355

2356
        // Make sure that at least a "shell" entry for each node is present in
2357
        // the nodes table.
2358
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
2359
        if err != nil {
×
2360
                return fmt.Errorf("unable to create shell node: %w", err)
×
2361
        }
×
2362

2363
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
2364
        if err != nil {
×
2365
                return fmt.Errorf("unable to create shell node: %w", err)
×
2366
        }
×
2367

2368
        var capacity sql.NullInt64
×
2369
        if edge.Capacity != 0 {
×
2370
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
2371
        }
×
2372

2373
        createParams := sqlc.CreateChannelParams{
×
2374
                Version:     int16(ProtocolV1),
×
2375
                Scid:        chanIDB[:],
×
2376
                NodeID1:     node1DBID,
×
2377
                NodeID2:     node2DBID,
×
2378
                Outpoint:    edge.ChannelPoint.String(),
×
2379
                Capacity:    capacity,
×
2380
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
2381
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
2382
        }
×
2383

×
2384
        if edge.AuthProof != nil {
×
2385
                proof := edge.AuthProof
×
2386

×
2387
                createParams.Node1Signature = proof.NodeSig1Bytes
×
2388
                createParams.Node2Signature = proof.NodeSig2Bytes
×
2389
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
2390
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
2391
        }
×
2392

2393
        // Insert the new channel record.
2394
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
2395
        if err != nil {
×
2396
                return err
×
2397
        }
×
2398

2399
        // Insert any channel features.
2400
        if len(edge.Features) != 0 {
×
2401
                chanFeatures := lnwire.NewRawFeatureVector()
×
2402
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
×
2403
                if err != nil {
×
2404
                        return err
×
2405
                }
×
2406

2407
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
×
2408
                for feature := range fv.Features() {
×
2409
                        err = db.InsertChannelFeature(
×
2410
                                ctx, sqlc.InsertChannelFeatureParams{
×
2411
                                        ChannelID:  dbChanID,
×
2412
                                        FeatureBit: int32(feature),
×
2413
                                },
×
2414
                        )
×
2415
                        if err != nil {
×
2416
                                return fmt.Errorf("unable to insert "+
×
2417
                                        "channel(%d) feature(%v): %w", dbChanID,
×
2418
                                        feature, err)
×
2419
                        }
×
2420
                }
2421
        }
2422

2423
        // Finally, insert any extra TLV fields in the channel announcement.
2424
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
2425
        if err != nil {
×
2426
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
2427
                        err)
×
2428
        }
×
2429

2430
        for tlvType, value := range extra {
×
2431
                err := db.CreateChannelExtraType(
×
2432
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
2433
                                ChannelID: dbChanID,
×
2434
                                Type:      int64(tlvType),
×
2435
                                Value:     value,
×
2436
                        },
×
2437
                )
×
2438
                if err != nil {
×
2439
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
2440
                                "signed field(%v): %w", edge.ChannelID,
×
2441
                                tlvType, err)
×
2442
                }
×
2443
        }
2444

2445
        return nil
×
2446
}
2447

2448
// maybeCreateShellNode checks if a shell node entry exists for the
2449
// given public key. If it does not exist, then a new shell node entry is
2450
// created. The ID of the node is returned. A shell node only has a protocol
2451
// version and public key persisted.
2452
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
2453
        pubKey route.Vertex) (int64, error) {
×
2454

×
2455
        dbNode, err := db.GetNodeByPubKey(
×
2456
                ctx, sqlc.GetNodeByPubKeyParams{
×
2457
                        PubKey:  pubKey[:],
×
2458
                        Version: int16(ProtocolV1),
×
2459
                },
×
2460
        )
×
2461
        // The node exists. Return the ID.
×
2462
        if err == nil {
×
2463
                return dbNode.ID, nil
×
2464
        } else if !errors.Is(err, sql.ErrNoRows) {
×
2465
                return 0, err
×
2466
        }
×
2467

2468
        // Otherwise, the node does not exist, so we create a shell entry for
2469
        // it.
2470
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
2471
                Version: int16(ProtocolV1),
×
2472
                PubKey:  pubKey[:],
×
2473
        })
×
2474
        if err != nil {
×
2475
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
2476
        }
×
2477

2478
        return id, nil
×
2479
}
2480

2481
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
2482
// the database. This includes deleting any existing types and then inserting
2483
// the new types.
2484
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
2485
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
2486

×
2487
        // Delete all existing extra signed fields for the channel policy.
×
2488
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
2489
        if err != nil {
×
2490
                return fmt.Errorf("unable to delete "+
×
2491
                        "existing policy extra signed fields for policy %d: %w",
×
2492
                        chanPolicyID, err)
×
2493
        }
×
2494

2495
        // Insert all new extra signed fields for the channel policy.
2496
        for tlvType, value := range extraFields {
×
2497
                err = db.InsertChanPolicyExtraType(
×
2498
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
2499
                                ChannelPolicyID: chanPolicyID,
×
2500
                                Type:            int64(tlvType),
×
2501
                                Value:           value,
×
2502
                        },
×
2503
                )
×
2504
                if err != nil {
×
2505
                        return fmt.Errorf("unable to insert "+
×
2506
                                "channel_policy(%d) extra signed field(%v): %w",
×
2507
                                chanPolicyID, tlvType, err)
×
2508
                }
×
2509
        }
2510

2511
        return nil
×
2512
}
2513

2514
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
2515
// provided dbChanRow and also fetches any other required information
2516
// to construct the edge info.
2517
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
2518
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
2519
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
2520

×
2521
        fv, extras, err := getChanFeaturesAndExtras(
×
2522
                ctx, db, dbChanID,
×
2523
        )
×
2524
        if err != nil {
×
2525
                return nil, err
×
2526
        }
×
2527

2528
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
2529
        if err != nil {
×
2530
                return nil, err
×
2531
        }
×
2532

2533
        var featureBuf bytes.Buffer
×
2534
        if err := fv.Encode(&featureBuf); err != nil {
×
2535
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
2536
        }
×
2537

2538
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
2539
        if err != nil {
×
2540
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
2541
                        "fields: %w", err)
×
2542
        }
×
2543
        if recs == nil {
×
2544
                recs = make([]byte, 0)
×
2545
        }
×
2546

2547
        var btcKey1, btcKey2 route.Vertex
×
2548
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
2549
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
2550

×
2551
        channel := &models.ChannelEdgeInfo{
×
2552
                ChainHash:        chain,
×
2553
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
2554
                NodeKey1Bytes:    node1,
×
2555
                NodeKey2Bytes:    node2,
×
2556
                BitcoinKey1Bytes: btcKey1,
×
2557
                BitcoinKey2Bytes: btcKey2,
×
2558
                ChannelPoint:     *op,
×
2559
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
2560
                Features:         featureBuf.Bytes(),
×
2561
                ExtraOpaqueData:  recs,
×
2562
        }
×
2563

×
2564
        // We always set all the signatures at the same time, so we can
×
2565
        // safely check if one signature is present to determine if we have the
×
2566
        // rest of the signatures for the auth proof.
×
2567
        if len(dbChan.Bitcoin1Signature) > 0 {
×
2568
                channel.AuthProof = &models.ChannelAuthProof{
×
2569
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
2570
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
2571
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
2572
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
2573
                }
×
2574
        }
×
2575

2576
        return channel, nil
×
2577
}
2578

2579
// buildNodeVertices is a helper that converts raw node public keys
2580
// into route.Vertex instances.
2581
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
2582
        route.Vertex, error) {
×
2583

×
2584
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
2585
        if err != nil {
×
2586
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
2587
                        "create vertex from node1 pubkey: %w", err)
×
2588
        }
×
2589

2590
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
2591
        if err != nil {
×
2592
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
2593
                        "create vertex from node2 pubkey: %w", err)
×
2594
        }
×
2595

2596
        return node1Vertex, node2Vertex, nil
×
2597
}
2598

2599
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
2600
// for a channel with the given ID.
2601
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
2602
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
2603

×
2604
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
2605
        if err != nil {
×
2606
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
2607
                        "features and extras: %w", err)
×
2608
        }
×
2609

2610
        var (
×
2611
                fv     = lnwire.EmptyFeatureVector()
×
2612
                extras = make(map[uint64][]byte)
×
2613
        )
×
2614
        for _, row := range rows {
×
2615
                if row.IsFeature {
×
2616
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
2617

×
2618
                        continue
×
2619
                }
2620

2621
                tlvType, ok := row.ExtraKey.(int64)
×
2622
                if !ok {
×
2623
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
2624
                                "TLV type: %T", row.ExtraKey)
×
2625
                }
×
2626

2627
                valueBytes, ok := row.Value.([]byte)
×
2628
                if !ok {
×
2629
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
2630
                                "Value: %T", row.Value)
×
2631
                }
×
2632

2633
                extras[uint64(tlvType)] = valueBytes
×
2634
        }
2635

2636
        return fv, extras, nil
×
2637
}
2638

2639
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
2640
// all the extra info required to build the complete models.ChannelEdgePolicy
2641
// types. It returns two policies, which may be nil if the provided
2642
// sqlc.ChannelPolicy records are nil.
2643
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
2644
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
2645
        node2 route.Vertex) (*models.ChannelEdgePolicy,
2646
        *models.ChannelEdgePolicy, error) {
×
2647

×
2648
        if dbPol1 == nil && dbPol2 == nil {
×
2649
                return nil, nil, nil
×
2650
        }
×
2651

2652
        var (
×
2653
                policy1ID int64
×
2654
                policy2ID int64
×
2655
        )
×
2656
        if dbPol1 != nil {
×
2657
                policy1ID = dbPol1.ID
×
2658
        }
×
2659
        if dbPol2 != nil {
×
2660
                policy2ID = dbPol2.ID
×
2661
        }
×
2662
        rows, err := db.GetChannelPolicyExtraTypes(
×
2663
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
2664
                        ID:   policy1ID,
×
2665
                        ID_2: policy2ID,
×
2666
                },
×
2667
        )
×
2668
        if err != nil {
×
2669
                return nil, nil, err
×
2670
        }
×
2671

2672
        var (
×
2673
                dbPol1Extras = make(map[uint64][]byte)
×
2674
                dbPol2Extras = make(map[uint64][]byte)
×
2675
        )
×
2676
        for _, row := range rows {
×
2677
                switch row.PolicyID {
×
2678
                case policy1ID:
×
2679
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
2680
                case policy2ID:
×
2681
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
2682
                default:
×
2683
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
2684
                                "in row: %v", row.PolicyID, row)
×
2685
                }
2686
        }
2687

2688
        var pol1, pol2 *models.ChannelEdgePolicy
×
2689
        if dbPol1 != nil {
×
2690
                pol1, err = buildChanPolicy(
×
2691
                        *dbPol1, channelID, dbPol1Extras, node2, true,
×
2692
                )
×
2693
                if err != nil {
×
2694
                        return nil, nil, err
×
2695
                }
×
2696
        }
2697
        if dbPol2 != nil {
×
2698
                pol2, err = buildChanPolicy(
×
2699
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
2700
                )
×
2701
                if err != nil {
×
2702
                        return nil, nil, err
×
2703
                }
×
2704
        }
2705

2706
        return pol1, pol2, nil
×
2707
}
2708

2709
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
2710
// provided sqlc.ChannelPolicy and other required information.
2711
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
2712
        extras map[uint64][]byte, toNode route.Vertex,
2713
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
2714

×
2715
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
2716
        if err != nil {
×
2717
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
2718
                        "fields: %w", err)
×
2719
        }
×
2720

2721
        var msgFlags lnwire.ChanUpdateMsgFlags
×
2722
        if dbPolicy.MaxHtlcMsat.Valid {
×
2723
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
2724
        }
×
2725

2726
        var chanFlags lnwire.ChanUpdateChanFlags
×
2727
        if !isNode1 {
×
2728
                chanFlags |= lnwire.ChanUpdateDirection
×
2729
        }
×
2730
        if dbPolicy.Disabled.Bool {
×
2731
                chanFlags |= lnwire.ChanUpdateDisabled
×
2732
        }
×
2733

2734
        var inboundFee fn.Option[lnwire.Fee]
×
2735
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
2736
                dbPolicy.InboundBaseFeeMsat.Valid {
×
2737

×
2738
                inboundFee = fn.Some(lnwire.Fee{
×
2739
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
2740
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
2741
                })
×
2742
        }
×
2743

2744
        return &models.ChannelEdgePolicy{
×
2745
                SigBytes:  dbPolicy.Signature,
×
2746
                ChannelID: channelID,
×
2747
                LastUpdate: time.Unix(
×
2748
                        dbPolicy.LastUpdate.Int64, 0,
×
2749
                ),
×
2750
                MessageFlags:  msgFlags,
×
2751
                ChannelFlags:  chanFlags,
×
2752
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
2753
                MinHTLC: lnwire.MilliSatoshi(
×
2754
                        dbPolicy.MinHtlcMsat,
×
2755
                ),
×
2756
                MaxHTLC: lnwire.MilliSatoshi(
×
2757
                        dbPolicy.MaxHtlcMsat.Int64,
×
2758
                ),
×
2759
                FeeBaseMSat: lnwire.MilliSatoshi(
×
2760
                        dbPolicy.BaseFeeMsat,
×
2761
                ),
×
2762
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
2763
                ToNode:                    toNode,
×
2764
                InboundFee:                inboundFee,
×
2765
                ExtraOpaqueData:           recs,
×
2766
        }, nil
×
2767
}
2768

2769
// buildNodes builds the models.LightningNode instances for the
2770
// given row which is expected to be a sqlc type that contains node information.
2771
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
2772
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
NEW
2773
        error) {
×
NEW
2774

×
NEW
2775
        node1, err := buildNode(ctx, db, &dbNode1)
×
NEW
2776
        if err != nil {
×
NEW
2777
                return nil, nil, err
×
NEW
2778
        }
×
2779

NEW
2780
        node2, err := buildNode(ctx, db, &dbNode2)
×
NEW
2781
        if err != nil {
×
NEW
2782
                return nil, nil, err
×
NEW
2783
        }
×
2784

NEW
2785
        return node1, node2, nil
×
2786
}
2787

2788
// extractChannelPolicies pulls the two optional sqlc.ChannelPolicy records out
// of the given query row.
//
// The following row types are supported:
//   - sqlc.GetChannelsByPolicyLastUpdateRangeRow
//   - sqlc.ListChannelsByNodeIDRow
//   - sqlc.ListChannelsWithPoliciesPaginatedRow
//
// For each row, policy N is only materialised when the row's PolicyNID column
// is valid. Otherwise the corresponding return value is nil. Any other row
// type results in an error.
//
//nolint:ll,dupl
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
	error) {

	switch r := row.(type) {
	case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
		var first, second *sqlc.ChannelPolicy

		// A NULL policy ID means the LEFT JOIN found no policy for this
		// direction, so the policy stays nil.
		if r.Policy1ID.Valid {
			first = &sqlc.ChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.Channel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				Signature:               r.Policy1Signature,
			}
		}

		if r.Policy2ID.Valid {
			second = &sqlc.ChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.Channel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				Signature:               r.Policy2Signature,
			}
		}

		return first, second, nil

	case sqlc.ListChannelsByNodeIDRow:
		var first, second *sqlc.ChannelPolicy

		if r.Policy1ID.Valid {
			first = &sqlc.ChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.Channel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				Signature:               r.Policy1Signature,
			}
		}

		if r.Policy2ID.Valid {
			second = &sqlc.ChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.Channel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				Signature:               r.Policy2Signature,
			}
		}

		return first, second, nil

	case sqlc.ListChannelsWithPoliciesPaginatedRow:
		var first, second *sqlc.ChannelPolicy

		if r.Policy1ID.Valid {
			first = &sqlc.ChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.Channel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				Signature:               r.Policy1Signature,
			}
		}

		if r.Policy2ID.Valid {
			second = &sqlc.ChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.Channel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				Signature:               r.Policy2Signature,
			}
		}

		return first, second, nil
	}

	// None of the supported row types matched.
	return nil, nil, fmt.Errorf("unexpected row type in "+
		"extractChannelPolicies: %T", row)
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc