• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 15874499852

25 Jun 2025 10:57AM UTC coverage: 67.645% (+9.7%) from 57.952%
15874499852

Pull #9925

github

web-flow
Merge a4459cc5a into fb720c174
Pull Request #9925: routing: clean-up & fix blinded path incoming chained channel logic

48 of 49 new or added lines in 2 files covered. (97.96%)

1124 existing lines in 13 files now uncovered.

134989 of 199554 relevant lines covered (67.65%)

21997.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/batch"
23
        "github.com/lightningnetwork/lnd/fn/v2"
24
        "github.com/lightningnetwork/lnd/graph/db/models"
25
        "github.com/lightningnetwork/lnd/lnwire"
26
        "github.com/lightningnetwork/lnd/routing/route"
27
        "github.com/lightningnetwork/lnd/sqldb"
28
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
29
        "github.com/lightningnetwork/lnd/tlv"
30
        "github.com/lightningnetwork/lnd/tor"
31
)
32

33
// pageSize caps how many records a single paginated query may return. The
// value is provisional and may be tuned once benchmarks are available.
const pageSize = 2000
36

37
// ProtocolVersion enumerates the gossip protocol versions that a message can
// belong to.
type ProtocolVersion uint8

const (
	// ProtocolV1 identifies the gossip protocol specified in BOLT #7.
	ProtocolV1 ProtocolVersion = 1
)

// String renders the protocol version as "V<number>", for example "V1".
func (v ProtocolVersion) String() string {
	num := uint8(v)

	return fmt.Sprintf("V%d", num)
}
×
UNCOV
50

×
51
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
52
// execute queries against the SQL graph tables.
53
//
54
//nolint:ll,interfacebloat
55
type SQLQueries interface {
56
        /*
57
                Node queries.
58
        */
59
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
60
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
61
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
62
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
63
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
64
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
65
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
66
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
67

68
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
69
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
70
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
71

72
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
73
        GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
74
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
75

76
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
77
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
78
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
79
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
80

81
        /*
82
                Source node queries.
83
        */
84
        AddSourceNode(ctx context.Context, nodeID int64) error
85
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
86

87
        /*
88
                Channel queries.
89
        */
90
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
91
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
92
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
93
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
94
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
95
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
96
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
97
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
98
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
99
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
100
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
101
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
102
        DeleteChannel(ctx context.Context, id int64) error
103

104
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
105
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
106

107
        /*
108
                Channel Policy table queries.
109
        */
110
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
111
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
112
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
113

114
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
115
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
116
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
117

118
        /*
119
                Zombie index queries.
120
        */
121
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
122
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
123
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
124
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
125
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
126
}
127

128
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
129
// database operations.
130
type BatchedSQLQueries interface {
131
        SQLQueries
132
        sqldb.BatchedTx[SQLQueries]
133
}
134

135
// SQLStore is an implementation of the V1Store interface that uses a SQL
136
// database as the backend.
137
//
138
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
139
// implement the V1Store interface incrementally. For any method not
140
// implemented,  things will fall back to the KVStore. This is ONLY the case
141
// for the time being while this struct is purely used in unit tests only.
142
type SQLStore struct {
143
        cfg *SQLStoreConfig
144
        db  BatchedSQLQueries
145

146
        // cacheMu guards all caches (rejectCache and chanCache). If
147
        // this mutex will be acquired at the same time as the DB mutex then
148
        // the cacheMu MUST be acquired first to prevent deadlock.
149
        cacheMu     sync.RWMutex
150
        rejectCache *rejectCache
151
        chanCache   *channelCache
152

153
        chanScheduler batch.Scheduler[SQLQueries]
154
        nodeScheduler batch.Scheduler[SQLQueries]
155

156
        srcNodes  map[ProtocolVersion]*srcNodeInfo
157
        srcNodeMu sync.Mutex
158

159
        // Temporary fall-back to the KVStore so that we can implement the
160
        // interface incrementally.
161
        *KVStore
162
}
163

164
// A compile-time assertion to ensure that SQLStore implements the V1Store
165
// interface.
166
var _ V1Store = (*SQLStore)(nil)
167

168
// SQLStoreConfig holds the configuration for the SQLStore.
169
type SQLStoreConfig struct {
170
        // ChainHash is the genesis hash for the chain that all the gossip
171
        // messages in this store are aimed at.
172
        ChainHash chainhash.Hash
173
}
174

175
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
176
// storage backend.
177
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, kvStore *KVStore,
178
        options ...StoreOptionModifier) (*SQLStore, error) {
179

180
        opts := DefaultOptions()
181
        for _, o := range options {
182
                o(opts)
183
        }
184

185
        if opts.NoMigration {
186
                return nil, fmt.Errorf("the NoMigration option is not yet " +
187
                        "supported for SQL stores")
188
        }
189

190
        s := &SQLStore{
191
                cfg:         cfg,
×
192
                db:          db,
×
193
                KVStore:     kvStore,
×
194
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
195
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
196
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
197
        }
198

×
199
        s.chanScheduler = batch.NewTimeScheduler(
×
200
                db, &s.cacheMu, opts.BatchCommitInterval,
×
201
        )
×
202
        s.nodeScheduler = batch.NewTimeScheduler(
203
                db, nil, opts.BatchCommitInterval,
×
204
        )
×
205

×
206
        return s, nil
×
UNCOV
207
}
×
UNCOV
208

×
UNCOV
209
// AddLightningNode adds a vertex/node to the graph database. If the node is not
×
UNCOV
210
// in the database from before, this will add a new, unconnected one to the
×
UNCOV
211
// graph. If it is present from before, this will update that node's
×
UNCOV
212
// information.
×
UNCOV
213
//
×
UNCOV
214
// NOTE: part of the V1Store interface.
×
UNCOV
215
func (s *SQLStore) AddLightningNode(ctx context.Context,
×
216
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
217

×
218
        r := &batch.Request[SQLQueries]{
×
219
                Opts: batch.NewSchedulerOptions(opts...),
×
220
                Do: func(queries SQLQueries) error {
221
                        _, err := upsertNode(ctx, queries, node)
222
                        return err
223
                },
224
        }
225

226
        return s.nodeScheduler.Execute(ctx, r)
227
}
228

UNCOV
229
// FetchLightningNode attempts to look up a target node by its identity public
×
UNCOV
230
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
×
UNCOV
231
// returned.
×
UNCOV
232
//
×
UNCOV
233
// NOTE: part of the V1Store interface.
×
UNCOV
234
func (s *SQLStore) FetchLightningNode(ctx context.Context,
×
235
        pubKey route.Vertex) (*models.LightningNode, error) {
×
236

×
237
        var node *models.LightningNode
238
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
239
                var err error
×
240
                _, node, err = getNodeByPubKey(ctx, db, pubKey)
241

242
                return err
243
        }, sqldb.NoOpReset)
244
        if err != nil {
245
                return nil, fmt.Errorf("unable to fetch node: %w", err)
246
        }
247

248
        return node, nil
×
UNCOV
249
}
×
UNCOV
250

×
UNCOV
251
// HasLightningNode determines if the graph has a vertex identified by the
×
UNCOV
252
// target node identity public key. If the node exists in the database, a
×
UNCOV
253
// timestamp of when the data for the node was lasted updated is returned along
×
UNCOV
254
// with a true boolean. Otherwise, an empty time.Time is returned with a false
×
UNCOV
255
// boolean.
×
UNCOV
256
//
×
UNCOV
257
// NOTE: part of the V1Store interface.
×
UNCOV
258
func (s *SQLStore) HasLightningNode(ctx context.Context,
×
259
        pubKey [33]byte) (time.Time, bool, error) {
×
260

261
        var (
×
262
                exists     bool
263
                lastUpdate time.Time
264
        )
265
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
266
                dbNode, err := db.GetNodeByPubKey(
267
                        ctx, sqlc.GetNodeByPubKeyParams{
268
                                Version: int16(ProtocolV1),
269
                                PubKey:  pubKey[:],
270
                        },
271
                )
272
                if errors.Is(err, sql.ErrNoRows) {
×
273
                        return nil
×
274
                } else if err != nil {
×
275
                        return fmt.Errorf("unable to fetch node: %w", err)
×
276
                }
×
UNCOV
277

×
278
                exists = true
×
279

×
280
                if dbNode.LastUpdate.Valid {
×
281
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
282
                }
×
UNCOV
283

×
284
                return nil
×
UNCOV
285
        }, sqldb.NoOpReset)
×
286
        if err != nil {
×
287
                return time.Time{}, false,
×
288
                        fmt.Errorf("unable to fetch node: %w", err)
×
289
        }
×
290

291
        return lastUpdate, exists, nil
×
UNCOV
292
}
×
UNCOV
293

×
UNCOV
294
// AddrsForNode returns all known addresses for the target node public key
×
UNCOV
295
// that the graph DB is aware of. The returned boolean indicates if the
×
296
// given node is unknown to the graph DB or not.
UNCOV
297
//
×
298
// NOTE: part of the V1Store interface.
UNCOV
299
func (s *SQLStore) AddrsForNode(ctx context.Context,
×
300
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
301

×
302
        var (
×
303
                addresses []net.Addr
304
                known     bool
×
305
        )
306
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
307
                var err error
308
                known, addresses, err = getNodeAddresses(
309
                        ctx, db, nodePub.SerializeCompressed(),
310
                )
311
                if err != nil {
312
                        return fmt.Errorf("unable to fetch node addresses: %w",
313
                                err)
×
314
                }
×
UNCOV
315

×
316
                return nil
×
UNCOV
317
        }, sqldb.NoOpReset)
×
318
        if err != nil {
×
319
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
320
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
321
        }
×
UNCOV
322

×
323
        return known, addresses, nil
×
UNCOV
324
}
×
UNCOV
325

×
UNCOV
326
// DeleteLightningNode starts a new database transaction to remove a vertex/node
×
UNCOV
327
// from the database according to the node's public key.
×
328
//
UNCOV
329
// NOTE: part of the V1Store interface.
×
330
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
331
        pubKey route.Vertex) error {
×
332

×
333
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
334
                res, err := db.DeleteNodeByPubKey(
×
335
                        ctx, sqlc.DeleteNodeByPubKeyParams{
336
                                Version: int16(ProtocolV1),
×
337
                                PubKey:  pubKey[:],
338
                        },
339
                )
340
                if err != nil {
341
                        return err
342
                }
343

344
                rows, err := res.RowsAffected()
×
345
                if err != nil {
×
346
                        return err
×
347
                }
×
UNCOV
348

×
349
                if rows == 0 {
×
350
                        return ErrGraphNodeNotFound
×
351
                } else if rows > 1 {
×
352
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
353
                }
×
UNCOV
354

×
355
                return err
×
356
        }, sqldb.NoOpReset)
357
        if err != nil {
×
358
                return fmt.Errorf("unable to delete node: %w", err)
×
359
        }
×
UNCOV
360

×
361
        return nil
UNCOV
362
}
×
UNCOV
363

×
UNCOV
364
// FetchNodeFeatures returns the features of the given node. If no features are
×
UNCOV
365
// known for the node, an empty feature vector is returned.
×
UNCOV
366
//
×
367
// NOTE: this is part of the graphdb.NodeTraverser interface.
UNCOV
368
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
×
369
        *lnwire.FeatureVector, error) {
370

×
371
        ctx := context.TODO()
×
372

×
373
        return fetchNodeFeatures(ctx, s.db, nodePub)
374
}
×
375

376
// DisabledChannelIDs returns the channel ids of disabled channels.
377
// A channel is disabled when two of the associated ChanelEdgePolicies
378
// have their disabled bit on.
379
//
380
// NOTE: part of the V1Store interface.
381
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
382
        var (
×
383
                ctx     = context.TODO()
×
384
                chanIDs []uint64
×
385
        )
×
386
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
387
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
388
                if err != nil {
389
                        return fmt.Errorf("unable to fetch disabled "+
390
                                "channels: %w", err)
391
                }
392

393
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
394

×
395
                return nil
×
UNCOV
396
        }, sqldb.NoOpReset)
×
397
        if err != nil {
×
398
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
399
                        err)
×
400
        }
×
UNCOV
401

×
402
        return chanIDs, nil
×
UNCOV
403
}
×
UNCOV
404

×
405
// LookupAlias attempts to return the alias as advertised by the target node.
UNCOV
406
//
×
UNCOV
407
// NOTE: part of the V1Store interface.
×
UNCOV
408
func (s *SQLStore) LookupAlias(ctx context.Context,
×
409
        pub *btcec.PublicKey) (string, error) {
410

×
411
        var alias string
×
412
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
413
                dbNode, err := db.GetNodeByPubKey(
×
414
                        ctx, sqlc.GetNodeByPubKeyParams{
415
                                Version: int16(ProtocolV1),
×
416
                                PubKey:  pub.SerializeCompressed(),
417
                        },
418
                )
419
                if errors.Is(err, sql.ErrNoRows) {
420
                        return ErrNodeAliasNotFound
421
                } else if err != nil {
422
                        return fmt.Errorf("unable to fetch node: %w", err)
×
423
                }
×
UNCOV
424

×
425
                if !dbNode.Alias.Valid {
×
426
                        return ErrNodeAliasNotFound
×
427
                }
×
UNCOV
428

×
429
                alias = dbNode.Alias.String
×
430

×
431
                return nil
×
UNCOV
432
        }, sqldb.NoOpReset)
×
433
        if err != nil {
×
434
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
435
        }
×
UNCOV
436

×
437
        return alias, nil
UNCOV
438
}
×
UNCOV
439

×
UNCOV
440
// SourceNode returns the source node of the graph. The source node is treated
×
441
// as the center node within a star-graph. This method may be used to kick off
UNCOV
442
// a path finding algorithm in order to explore the reachability of another
×
UNCOV
443
// node based off the source node.
×
UNCOV
444
//
×
445
// NOTE: part of the V1Store interface.
UNCOV
446
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
×
447
        error) {
×
448

×
449
        var node *models.LightningNode
450
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
451
                _, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
452
                if err != nil {
453
                        return fmt.Errorf("unable to fetch V1 source node: %w",
454
                                err)
455
                }
456

457
                _, node, err = getNodeByPubKey(ctx, db, nodePub)
458

459
                return err
UNCOV
460
        }, sqldb.NoOpReset)
×
461
        if err != nil {
×
462
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
463
        }
×
UNCOV
464

×
465
        return node, nil
×
UNCOV
466
}
×
UNCOV
467

×
UNCOV
468
// SetSourceNode sets the source node within the graph database. The source
×
469
// node is to be used as the center of a star-graph within path finding
UNCOV
470
// algorithms.
×
UNCOV
471
//
×
UNCOV
472
// NOTE: part of the V1Store interface.
×
473
func (s *SQLStore) SetSourceNode(ctx context.Context,
474
        node *models.LightningNode) error {
×
475

×
476
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
477
                id, err := upsertNode(ctx, db, node)
478
                if err != nil {
×
479
                        return fmt.Errorf("unable to upsert source node: %w",
480
                                err)
481
                }
482

483
                // Make sure that if a source node for this version is already
484
                // set, then the ID is the same as the one we are about to set.
485
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
486
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
487
                        return fmt.Errorf("unable to fetch source node: %w",
×
488
                                err)
×
489
                } else if err == nil {
×
490
                        if dbSourceNodeID != id {
×
491
                                return fmt.Errorf("v1 source node already "+
×
492
                                        "set to a different node: %d vs %d",
×
493
                                        dbSourceNodeID, id)
×
494
                        }
×
495

496
                        return nil
497
                }
UNCOV
498

×
499
                return db.AddSourceNode(ctx, id)
×
UNCOV
500
        }, sqldb.NoOpReset)
×
UNCOV
501
}
×
UNCOV
502

×
UNCOV
503
// NodeUpdatesInHorizon returns all the known lightning node which have an
×
UNCOV
504
// update timestamp within the passed range. This method can be used by two
×
UNCOV
505
// nodes to quickly determine if they have the same set of up to date node
×
UNCOV
506
// announcements.
×
UNCOV
507
//
×
508
// NOTE: This is part of the V1Store interface.
UNCOV
509
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
×
510
        endTime time.Time) ([]models.LightningNode, error) {
511

512
        ctx := context.TODO()
×
513

514
        var nodes []models.LightningNode
515
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
516
                dbNodes, err := db.GetNodesByLastUpdateRange(
517
                        ctx, sqlc.GetNodesByLastUpdateRangeParams{
518
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
519
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
520
                        },
521
                )
522
                if err != nil {
523
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
524
                }
×
UNCOV
525

×
526
                for _, dbNode := range dbNodes {
×
527
                        node, err := buildNode(ctx, db, &dbNode)
×
528
                        if err != nil {
×
529
                                return fmt.Errorf("unable to build node: %w",
×
530
                                        err)
×
531
                        }
×
UNCOV
532

×
533
                        nodes = append(nodes, *node)
×
UNCOV
534
                }
×
UNCOV
535

×
536
                return nil
×
UNCOV
537
        }, sqldb.NoOpReset)
×
538
        if err != nil {
539
                return nil, fmt.Errorf("unable to fetch nodes: %w", err)
×
540
        }
×
UNCOV
541

×
542
        return nodes, nil
×
UNCOV
543
}
×
UNCOV
544

×
545
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
UNCOV
546
// undirected edge from the two target nodes are created. The information stored
×
547
// denotes the static attributes of the channel, such as the channelID, the keys
548
// involved in creation of the channel, and the set of features that the channel
UNCOV
549
// supports. The chanPoint and chanID are used to uniquely identify the edge
×
550
// globally within the database.
UNCOV
551
//
×
UNCOV
552
// NOTE: part of the V1Store interface.
×
UNCOV
553
func (s *SQLStore) AddChannelEdge(ctx context.Context,
×
554
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
555

×
556
        var alreadyExists bool
557
        r := &batch.Request[SQLQueries]{
558
                Opts: batch.NewSchedulerOptions(opts...),
559
                Reset: func() {
560
                        alreadyExists = false
561
                },
562
                Do: func(tx SQLQueries) error {
563
                        err := insertChannel(ctx, tx, edge)
564

565
                        // Silence ErrEdgeAlreadyExist so that the batch can
566
                        // succeed, but propagate the error via local state.
567
                        if errors.Is(err, ErrEdgeAlreadyExist) {
×
568
                                alreadyExists = true
×
569
                                return nil
×
570
                        }
×
UNCOV
571

×
572
                        return err
×
UNCOV
573
                },
×
574
                OnCommit: func(err error) error {
×
575
                        switch {
×
576
                        case err != nil:
×
577
                                return err
×
578
                        case alreadyExists:
×
579
                                return ErrEdgeAlreadyExist
×
580
                        default:
×
581
                                s.rejectCache.remove(edge.ChannelID)
×
582
                                s.chanCache.remove(edge.ChannelID)
×
583
                                return nil
×
584
                        }
UNCOV
585
                },
×
586
        }
UNCOV
587

×
588
        return s.chanScheduler.Execute(ctx, r)
×
UNCOV
589
}
×
UNCOV
590

×
UNCOV
591
// HighestChanID returns the "highest" known channel ID in the channel graph.
×
UNCOV
592
// This represents the "newest" channel from the PoV of the chain. This method
×
UNCOV
593
// can be used by peers to quickly determine if their graphs are in sync.
×
UNCOV
594
//
×
UNCOV
595
// NOTE: This is part of the V1Store interface.
×
596
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
597
        var highestChanID uint64
598
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
599
                chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
600
                if errors.Is(err, sql.ErrNoRows) {
601
                        return nil
×
602
                } else if err != nil {
603
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
604
                                err)
605
                }
606

607
                highestChanID = byteOrder.Uint64(chanID)
608

609
                return nil
×
UNCOV
610
        }, sqldb.NoOpReset)
×
611
        if err != nil {
×
612
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
613
        }
×
UNCOV
614

×
615
        return highestChanID, nil
×
UNCOV
616
}
×
UNCOV
617

×
UNCOV
618
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
×
619
// within the database for the referenced channel. The `flags` attribute within
UNCOV
620
// the ChannelEdgePolicy determines which of the directed edges are being
×
UNCOV
621
// updated. If the flag is 1, then the first node's information is being
×
UNCOV
622
// updated, otherwise it's the second node's information. The node ordering is
×
623
// determined by the lexicographical ordering of the identity public keys of the
UNCOV
624
// nodes on either side of the channel.
×
UNCOV
625
//
×
UNCOV
626
// NOTE: part of the V1Store interface.
×
627
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
UNCOV
628
        edge *models.ChannelEdgePolicy,
×
629
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
630

631
        var (
632
                isUpdate1    bool
633
                edgeNotFound bool
634
                from, to     route.Vertex
635
        )
636

637
        r := &batch.Request[SQLQueries]{
638
                Opts: batch.NewSchedulerOptions(opts...),
639
                Reset: func() {
640
                        isUpdate1 = false
641
                        edgeNotFound = false
642
                },
×
643
                Do: func(tx SQLQueries) error {
×
644
                        var err error
×
645
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
646
                                ctx, tx, edge,
×
647
                        )
×
648
                        if err != nil {
×
649
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
650
                        }
×
UNCOV
651

×
UNCOV
652
                        // Silence ErrEdgeNotFound so that the batch can
×
UNCOV
653
                        // succeed, but propagate the error via local state.
×
654
                        if errors.Is(err, ErrEdgeNotFound) {
×
655
                                edgeNotFound = true
×
656
                                return nil
×
657
                        }
×
UNCOV
658

×
659
                        return err
×
UNCOV
660
                },
×
661
                OnCommit: func(err error) error {
×
662
                        switch {
×
663
                        case err != nil:
×
664
                                return err
665
                        case edgeNotFound:
666
                                return ErrEdgeNotFound
667
                        default:
×
668
                                s.updateEdgeCache(edge, isUpdate1)
×
669
                                return nil
×
UNCOV
670
                        }
×
671
                },
UNCOV
672
        }
×
673

674
        err := s.chanScheduler.Execute(ctx, r)
×
675

×
676
        return from, to, err
×
UNCOV
677
}
×
UNCOV
678

×
UNCOV
679
// updateEdgeCache updates our reject and channel caches with the new
×
UNCOV
680
// edge policy information.
×
UNCOV
681
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
×
682
        isUpdate1 bool) {
×
683

684
        // If an entry for this channel is found in reject cache, we'll modify
685
        // the entry with the updated timestamp for the direction that was just
686
        // written. If the edge doesn't exist, we'll load the cache entry lazily
687
        // during the next query for this edge.
×
688
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
689
                if isUpdate1 {
×
690
                        entry.upd1Time = e.LastUpdate.Unix()
691
                } else {
692
                        entry.upd2Time = e.LastUpdate.Unix()
693
                }
694
                s.rejectCache.insert(e.ChannelID, entry)
UNCOV
695
        }
×
UNCOV
696

×
UNCOV
697
        // If an entry for this channel is found in channel cache, we'll modify
×
UNCOV
698
        // the entry with the updated policy for the direction that was just
×
UNCOV
699
        // written. If the edge doesn't exist, we'll defer loading the info and
×
UNCOV
700
        // policies and lazily read from disk during the next query.
×
701
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
702
                if isUpdate1 {
×
703
                        channel.Policy1 = e
×
704
                } else {
×
705
                        channel.Policy2 = e
×
706
                }
×
707
                s.chanCache.insert(e.ChannelID, channel)
×
708
        }
709
}
710

711
// ForEachSourceNodeChannel iterates through all channels of the source node,
712
// executing the passed callback on each. The call-back is provided with the
713
// channel's outpoint, whether we have a policy for the channel and the channel
UNCOV
714
// peer's node information.
×
UNCOV
715
//
×
UNCOV
716
// NOTE: part of the V1Store interface.
×
UNCOV
717
func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
×
718
        havePolicy bool, otherNode *models.LightningNode) error) error {
×
719

×
720
        var ctx = context.TODO()
×
721

722
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
723
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
724
                if err != nil {
725
                        return fmt.Errorf("unable to fetch source node: %w",
726
                                err)
727
                }
728

729
                return forEachNodeChannel(
730
                        ctx, db, s.cfg.ChainHash, nodeID,
731
                        func(info *models.ChannelEdgeInfo,
×
732
                                outPolicy *models.ChannelEdgePolicy,
×
733
                                _ *models.ChannelEdgePolicy) error {
×
734

×
735
                                // Fetch the other node.
×
736
                                var (
×
737
                                        otherNodePub [33]byte
×
738
                                        node1        = info.NodeKey1Bytes
×
739
                                        node2        = info.NodeKey2Bytes
×
740
                                )
×
741
                                switch {
742
                                case bytes.Equal(node1[:], nodePub[:]):
×
743
                                        otherNodePub = node2
×
744
                                case bytes.Equal(node2[:], nodePub[:]):
×
745
                                        otherNodePub = node1
×
746
                                default:
×
747
                                        return fmt.Errorf("node not " +
×
748
                                                "participating in this channel")
×
UNCOV
749
                                }
×
UNCOV
750

×
751
                                _, otherNode, err := getNodeByPubKey(
×
752
                                        ctx, db, otherNodePub,
×
753
                                )
×
754
                                if err != nil {
×
755
                                        return fmt.Errorf("unable to fetch "+
×
756
                                                "other node(%x): %w",
×
757
                                                otherNodePub, err)
×
758
                                }
×
UNCOV
759

×
760
                                return cb(
×
761
                                        info.ChannelPoint, outPolicy != nil,
×
762
                                        otherNode,
763
                                )
UNCOV
764
                        },
×
UNCOV
765
                )
×
UNCOV
766
        }, sqldb.NoOpReset)
×
UNCOV
767
}
×
UNCOV
768

×
UNCOV
769
// ForEachNode iterates through all the stored vertices/nodes in the graph,
×
UNCOV
770
// executing the passed callback with each node encountered. If the callback
×
UNCOV
771
// returns an error, then the transaction is aborted and the iteration stops
×
772
// early. Any operations performed on the NodeTx passed to the call-back are
UNCOV
773
// executed under the same read transaction and so, methods on the NodeTx object
×
UNCOV
774
// _MUST_ only be called from within the call-back.
×
UNCOV
775
//
×
UNCOV
776
// NOTE: part of the V1Store interface.
×
777
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
778
        var (
779
                ctx          = context.TODO()
780
                lastID int64 = 0
781
        )
782

783
        handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
784
                node, err := buildNode(ctx, db, &dbNode)
785
                if err != nil {
786
                        return fmt.Errorf("unable to build node(id=%d): %w",
787
                                dbNode.ID, err)
788
                }
789

790
                err = cb(
×
791
                        newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
×
792
                )
×
793
                if err != nil {
×
794
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
795
                                dbNode.ID, err)
×
796
                }
×
UNCOV
797

×
798
                return nil
×
UNCOV
799
        }
×
UNCOV
800

×
801
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
802
                for {
803
                        nodes, err := db.ListNodesPaginated(
×
804
                                ctx, sqlc.ListNodesPaginatedParams{
×
805
                                        Version: int16(ProtocolV1),
×
806
                                        ID:      lastID,
×
807
                                        Limit:   pageSize,
×
808
                                },
×
809
                        )
×
810
                        if err != nil {
811
                                return fmt.Errorf("unable to fetch nodes: %w",
×
812
                                        err)
813
                        }
UNCOV
814

×
815
                        if len(nodes) == 0 {
×
816
                                break
×
UNCOV
817
                        }
×
UNCOV
818

×
819
                        for _, dbNode := range nodes {
×
820
                                err = handleNode(db, dbNode)
×
821
                                if err != nil {
×
822
                                        return err
×
823
                                }
×
UNCOV
824

×
825
                                lastID = dbNode.ID
×
UNCOV
826
                        }
×
827
                }
UNCOV
828

×
829
                return nil
×
830
        }, sqldb.NoOpReset)
831
}
UNCOV
832

×
UNCOV
833
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
×
UNCOV
834
// SQLStore and a SQL transaction.
×
UNCOV
835
type sqlGraphNodeTx struct {
×
UNCOV
836
        db    SQLQueries
×
837
        id    int64
UNCOV
838
        node  *models.LightningNode
×
839
        chain chainhash.Hash
840
}
841

UNCOV
842
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
×
843
// interface.
844
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
845

846
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
847
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
848

849
        return &sqlGraphNodeTx{
850
                db:    db,
851
                chain: chain,
852
                id:    id,
853
                node:  node,
854
        }
855
}
856

857
// Node returns the raw information of the node.
858
//
859
// NOTE: This is a part of the NodeRTx interface.
860
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
861
        return s.node
×
862
}
×
UNCOV
863

×
UNCOV
864
// ForEachChannel can be used to iterate over the node's channels under the same
×
UNCOV
865
// transaction used to fetch the node.
×
UNCOV
866
//
×
UNCOV
867
// NOTE: This is a part of the NodeRTx interface.
×
UNCOV
868
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
×
869
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
870

871
        ctx := context.TODO()
872

873
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
874
}
×
UNCOV
875

×
876
// FetchNode fetches the node with the given pub key under the same transaction
877
// used to fetch the current node. The returned node is also a NodeRTx and any
878
// operations on that NodeRTx will also be done under the same transaction.
879
//
880
// NOTE: This is a part of the NodeRTx interface.
881
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
882
        ctx := context.TODO()
×
883

×
884
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
885
        if err != nil {
×
886
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
887
                        nodePub, err)
×
888
        }
889

890
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
891
}
892

893
// ForEachNodeDirectedChannel iterates through all channels of a given node,
UNCOV
894
// executing the passed callback on the directed edge representing the channel
×
UNCOV
895
// and its incoming policy. If the callback returns an error, then the iteration
×
UNCOV
896
// is halted with the error propagated back up to the caller.
×
UNCOV
897
//
×
UNCOV
898
// Unknown policies are passed into the callback as nil values.
×
UNCOV
899
//
×
UNCOV
900
// NOTE: this is part of the graphdb.NodeTraverser interface.
×
UNCOV
901
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
×
902
        cb func(channel *DirectedChannel) error) error {
903

×
904
        var ctx = context.TODO()
905

906
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
907
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
908
        }, sqldb.NoOpReset)
909
}
910

911
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
912
// graph, executing the passed callback with each node encountered. If the
913
// callback returns an error, then the transaction is aborted and the iteration
914
// stops early.
UNCOV
915
//
×
UNCOV
916
// NOTE: This is a part of the V1Store interface.
×
UNCOV
917
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
×
918
        *lnwire.FeatureVector) error) error {
×
919

×
920
        ctx := context.TODO()
×
921

×
922
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
923
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
924
                        nodePub route.Vertex) error {
925

926
                        features, err := getNodeFeatures(ctx, db, nodeID)
927
                        if err != nil {
928
                                return fmt.Errorf("unable to fetch node "+
929
                                        "features: %w", err)
930
                        }
UNCOV
931

×
932
                        return cb(nodePub, features)
×
UNCOV
933
                })
×
UNCOV
934
        }, sqldb.NoOpReset)
×
935
        if err != nil {
×
936
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
937
        }
×
UNCOV
938

×
939
        return nil
×
UNCOV
940
}
×
UNCOV
941

×
UNCOV
942
// ForEachNodeChannel iterates through all channels of the given node,
×
UNCOV
943
// executing the passed callback with an edge info structure and the policies
×
944
// of each end of the channel. The first edge policy is the outgoing edge *to*
UNCOV
945
// the connecting node, while the second is the incoming edge *from* the
×
946
// connecting node. If the callback returns an error, then the iteration is
947
// halted with the error propagated back up to the caller.
UNCOV
948
//
×
UNCOV
949
// Unknown policies are passed into the callback as nil values.
×
UNCOV
950
//
×
951
// NOTE: part of the V1Store interface.
UNCOV
952
func (s *SQLStore) ForEachNodeChannel(nodePub route.Vertex,
×
953
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
954
                *models.ChannelEdgePolicy) error) error {
955

956
        var ctx = context.TODO()
957

958
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
959
                dbNode, err := db.GetNodeByPubKey(
960
                        ctx, sqlc.GetNodeByPubKeyParams{
961
                                Version: int16(ProtocolV1),
962
                                PubKey:  nodePub[:],
963
                        },
964
                )
965
                if errors.Is(err, sql.ErrNoRows) {
966
                        return nil
967
                } else if err != nil {
×
968
                        return fmt.Errorf("unable to fetch node: %w", err)
×
969
                }
×
UNCOV
970

×
971
                return forEachNodeChannel(
×
972
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
973
                )
×
UNCOV
974
        }, sqldb.NoOpReset)
×
UNCOV
975
}
×
UNCOV
976

×
UNCOV
977
// ChanUpdatesInHorizon returns all the known channel edges which have at least
×
UNCOV
978
// one edge that has an update timestamp within the specified horizon.
×
UNCOV
979
//
×
UNCOV
980
// NOTE: This is part of the V1Store interface.
×
UNCOV
981
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
×
982
        endTime time.Time) ([]ChannelEdge, error) {
×
983

984
        s.cacheMu.Lock()
×
985
        defer s.cacheMu.Unlock()
×
986

×
987
        var (
988
                ctx = context.TODO()
989
                // To ensure we don't return duplicate ChannelEdges, we'll use
990
                // an additional map to keep track of the edges already seen to
991
                // prevent re-adding it.
992
                edgesSeen    = make(map[uint64]struct{})
993
                edgesToCache = make(map[uint64]ChannelEdge)
994
                edges        []ChannelEdge
995
                hits         int
×
996
        )
×
997
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
998
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
999
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1000
                                Version:   int16(ProtocolV1),
×
1001
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1002
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1003
                        },
×
1004
                )
×
1005
                if err != nil {
×
1006
                        return err
×
1007
                }
×
UNCOV
1008

×
1009
                for _, row := range rows {
×
1010
                        // If we've already retrieved the info and policies for
×
1011
                        // this edge, then we can skip it as we don't need to do
×
1012
                        // so again.
×
1013
                        chanIDInt := byteOrder.Uint64(row.Channel.Scid)
×
1014
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1015
                                continue
×
UNCOV
1016
                        }
×
UNCOV
1017

×
1018
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1019
                                hits++
×
1020
                                edgesSeen[chanIDInt] = struct{}{}
×
1021
                                edges = append(edges, channel)
1022

×
1023
                                continue
×
UNCOV
1024
                        }
×
UNCOV
1025

×
1026
                        node1, node2, err := buildNodes(
×
1027
                                ctx, db, row.Node, row.Node_2,
×
1028
                        )
×
1029
                        if err != nil {
1030
                                return err
1031
                        }
×
UNCOV
1032

×
1033
                        channel, err := getAndBuildEdgeInfo(
×
1034
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1035
                                row.Channel, node1.PubKeyBytes,
×
1036
                                node2.PubKeyBytes,
×
1037
                        )
1038
                        if err != nil {
1039
                                return fmt.Errorf("unable to build channel "+
×
1040
                                        "info: %w", err)
×
1041
                        }
×
UNCOV
1042

×
1043
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1044
                        if err != nil {
×
1045
                                return fmt.Errorf("unable to extract channel "+
1046
                                        "policies: %w", err)
×
1047
                        }
×
UNCOV
1048

×
1049
                        p1, p2, err := getAndBuildChanPolicies(
×
1050
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1051
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1052
                        )
×
1053
                        if err != nil {
×
1054
                                return fmt.Errorf("unable to build channel "+
×
1055
                                        "policies: %w", err)
1056
                        }
×
UNCOV
1057

×
1058
                        edgesSeen[chanIDInt] = struct{}{}
×
1059
                        chanEdge := ChannelEdge{
×
1060
                                Info:    channel,
×
1061
                                Policy1: p1,
1062
                                Policy2: p2,
×
1063
                                Node1:   node1,
×
1064
                                Node2:   node2,
×
1065
                        }
×
1066
                        edges = append(edges, chanEdge)
×
1067
                        edgesToCache[chanIDInt] = chanEdge
×
UNCOV
1068
                }
×
UNCOV
1069

×
1070
                return nil
1071
        }, func() {
×
1072
                edgesSeen = make(map[uint64]struct{})
×
1073
                edgesToCache = make(map[uint64]ChannelEdge)
×
1074
                edges = nil
×
1075
        })
×
1076
        if err != nil {
×
1077
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1078
        }
×
UNCOV
1079

×
UNCOV
1080
        // Insert any edges loaded from disk into the cache.
×
1081
        for chanid, channel := range edgesToCache {
1082
                s.chanCache.insert(chanid, channel)
1083
        }
×
UNCOV
1084

×
1085
        if len(edges) > 0 {
×
1086
                log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
×
1087
                        float64(hits)/float64(len(edges)), hits, len(edges))
×
1088
        } else {
×
1089
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1090
                        "horizon (%s, %s)", startTime, endTime)
×
1091
        }
×
1092

1093
        return edges, nil
UNCOV
1094
}
×
UNCOV
1095

×
UNCOV
1096
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
×
1097
// data to the call-back.
UNCOV
1098
//
×
UNCOV
1099
// NOTE: The callback contents MUST not be modified.
×
UNCOV
1100
//
×
UNCOV
1101
// NOTE: part of the V1Store interface.
×
UNCOV
1102
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
×
1103
        chans map[uint64]*DirectedChannel) error) error {
×
1104

×
1105
        var ctx = context.TODO()
1106

×
1107
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1108
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
1109
                        nodePub route.Vertex) error {
1110

1111
                        features, err := getNodeFeatures(ctx, db, nodeID)
1112
                        if err != nil {
1113
                                return fmt.Errorf("unable to fetch "+
1114
                                        "node(id=%d) features: %w", nodeID, err)
1115
                        }
UNCOV
1116

×
1117
                        toNodeCallback := func() route.Vertex {
×
1118
                                return nodePub
×
1119
                        }
×
UNCOV
1120

×
1121
                        rows, err := db.ListChannelsByNodeID(
×
1122
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1123
                                        Version: int16(ProtocolV1),
×
1124
                                        NodeID1: nodeID,
×
1125
                                },
×
1126
                        )
×
1127
                        if err != nil {
×
1128
                                return fmt.Errorf("unable to fetch channels "+
×
1129
                                        "of node(id=%d): %w", nodeID, err)
1130
                        }
×
UNCOV
1131

×
1132
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1133
                        for _, row := range rows {
1134
                                node1, node2, err := buildNodeVertices(
×
1135
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1136
                                )
×
1137
                                if err != nil {
×
1138
                                        return err
×
1139
                                }
×
UNCOV
1140

×
1141
                                e, err := getAndBuildEdgeInfo(
×
1142
                                        ctx, db, s.cfg.ChainHash,
×
1143
                                        row.Channel.ID, row.Channel, node1,
×
1144
                                        node2,
1145
                                )
×
1146
                                if err != nil {
×
1147
                                        return fmt.Errorf("unable to build "+
×
1148
                                                "channel info: %w", err)
×
1149
                                }
×
UNCOV
1150

×
1151
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1152
                                        row,
×
1153
                                )
1154
                                if err != nil {
×
1155
                                        return fmt.Errorf("unable to "+
×
1156
                                                "extract channel "+
×
1157
                                                "policies: %w", err)
×
1158
                                }
×
UNCOV
1159

×
1160
                                p1, p2, err := getAndBuildChanPolicies(
×
1161
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1162
                                        node1, node2,
×
1163
                                )
1164
                                if err != nil {
×
1165
                                        return fmt.Errorf("unable to "+
×
1166
                                                "build channel policies: %w",
×
1167
                                                err)
×
1168
                                }
×
UNCOV
1169

×
UNCOV
1170
                                // Determine the outgoing and incoming policy
×
UNCOV
1171
                                // for this channel and node combo.
×
1172
                                outPolicy, inPolicy := p1, p2
1173
                                if p1 != nil && p1.ToNode == nodePub {
×
1174
                                        outPolicy, inPolicy = p2, p1
×
1175
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1176
                                        outPolicy, inPolicy = p2, p1
×
1177
                                }
×
UNCOV
1178

×
1179
                                var cachedInPolicy *models.CachedEdgePolicy
×
1180
                                if inPolicy != nil {
×
1181
                                        cachedInPolicy = models.NewCachedPolicy(
×
1182
                                                p2,
1183
                                        )
1184
                                        cachedInPolicy.ToNodePubKey =
1185
                                                toNodeCallback
×
1186
                                        cachedInPolicy.ToNodeFeatures =
×
1187
                                                features
×
1188
                                }
×
UNCOV
1189

×
1190
                                var inboundFee lnwire.Fee
×
1191
                                outPolicy.InboundFee.WhenSome(
1192
                                        func(fee lnwire.Fee) {
×
1193
                                                inboundFee = fee
×
1194
                                        },
×
UNCOV
1195
                                )
×
UNCOV
1196

×
1197
                                directedChannel := &DirectedChannel{
×
1198
                                        ChannelID: e.ChannelID,
×
1199
                                        IsNode1: nodePub ==
×
1200
                                                e.NodeKey1Bytes,
×
1201
                                        OtherNode:    e.NodeKey2Bytes,
×
1202
                                        Capacity:     e.Capacity,
1203
                                        OutPolicySet: p1 != nil,
×
1204
                                        InPolicy:     cachedInPolicy,
×
1205
                                        InboundFee:   inboundFee,
×
1206
                                }
×
1207

×
1208
                                if nodePub == e.NodeKey2Bytes {
1209
                                        directedChannel.OtherNode =
1210
                                                e.NodeKey1Bytes
×
1211
                                }
×
UNCOV
1212

×
1213
                                channels[e.ChannelID] = directedChannel
×
UNCOV
1214
                        }
×
UNCOV
1215

×
1216
                        return cb(nodePub, channels)
×
UNCOV
1217
                })
×
UNCOV
1218
        }, sqldb.NoOpReset)
×
UNCOV
1219
}
×
UNCOV
1220

×
UNCOV
1221
// ForEachChannel iterates through all the channel edges stored within the
×
UNCOV
1222
// graph and invokes the passed callback for each edge. The callback takes two
×
UNCOV
1223
// edges as since this is a directed graph, both the in/out edges are visited.
×
UNCOV
1224
// If the callback returns an error, then the transaction is aborted and the
×
1225
// iteration stops early.
UNCOV
1226
//
×
1227
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1228
// for that particular channel edge routing policy will be passed into the
UNCOV
1229
// callback.
×
1230
//
1231
// NOTE: part of the V1Store interface.
1232
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1233
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
1234

1235
        ctx := context.TODO()
1236

1237
        handleChannel := func(db SQLQueries,
1238
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
1239

1240
                node1, node2, err := buildNodeVertices(
1241
                        row.Node1Pubkey, row.Node2Pubkey,
1242
                )
1243
                if err != nil {
1244
                        return fmt.Errorf("unable to build node vertices: %w",
1245
                                err)
1246
                }
×
UNCOV
1247

×
1248
                edge, err := getAndBuildEdgeInfo(
×
1249
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1250
                        node1, node2,
×
1251
                )
×
1252
                if err != nil {
×
1253
                        return fmt.Errorf("unable to build channel info: %w",
×
1254
                                err)
×
1255
                }
×
UNCOV
1256

×
1257
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1258
                if err != nil {
×
1259
                        return fmt.Errorf("unable to extract channel "+
×
1260
                                "policies: %w", err)
1261
                }
×
UNCOV
1262

×
1263
                p1, p2, err := getAndBuildChanPolicies(
×
1264
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1265
                )
×
1266
                if err != nil {
×
1267
                        return fmt.Errorf("unable to build channel "+
×
1268
                                "policies: %w", err)
×
1269
                }
UNCOV
1270

×
1271
                err = cb(edge, p1, p2)
×
1272
                if err != nil {
×
1273
                        return fmt.Errorf("callback failed for channel "+
×
1274
                                "id=%d: %w", edge.ChannelID, err)
×
1275
                }
UNCOV
1276

×
1277
                return nil
×
UNCOV
1278
        }
×
UNCOV
1279

×
1280
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1281
                var lastID int64
×
1282
                for {
×
1283
                        //nolint:ll
1284
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1285
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1286
                                        Version: int16(ProtocolV1),
×
1287
                                        ID:      lastID,
×
1288
                                        Limit:   pageSize,
×
1289
                                },
1290
                        )
×
1291
                        if err != nil {
1292
                                return err
1293
                        }
×
UNCOV
1294

×
1295
                        if len(rows) == 0 {
×
1296
                                break
×
UNCOV
1297
                        }
×
UNCOV
1298

×
1299
                        for _, row := range rows {
×
1300
                                err := handleChannel(db, row)
×
1301
                                if err != nil {
×
1302
                                        return err
×
1303
                                }
×
UNCOV
1304

×
1305
                                lastID = row.Channel.ID
×
UNCOV
1306
                        }
×
1307
                }
UNCOV
1308

×
1309
                return nil
×
1310
        }, sqldb.NoOpReset)
1311
}
UNCOV
1312

×
UNCOV
1313
// FilterChannelRange returns the channel ID's of all known channels which were
×
UNCOV
1314
// mined in a block height within the passed range. The channel IDs are grouped
×
UNCOV
1315
// by their common block height. This method can be used to quickly share with a
×
UNCOV
1316
// peer the set of channels we know of within a particular range to catch them
×
1317
// up after a period of time offline. If withTimestamps is true then the
UNCOV
1318
// timestamp info of the latest received channel update messages of the channel
×
1319
// will be included in the response.
1320
//
1321
// NOTE: This is part of the V1Store interface.
UNCOV
1322
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
×
1323
        withTimestamps bool) ([]BlockChannelRange, error) {
1324

1325
        var (
1326
                ctx       = context.TODO()
1327
                startSCID = &lnwire.ShortChannelID{
1328
                        BlockHeight: startHeight,
1329
                }
1330
                endSCID = lnwire.ShortChannelID{
1331
                        BlockHeight: endHeight,
1332
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
1333
                        TxPosition:  math.MaxUint16,
1334
                }
1335
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
1336
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1337
        )
×
1338

×
1339
        // 1) get all channels where channelID is between start and end chan ID.
×
1340
        // 2) skip if not public (ie, no channel_proof)
×
1341
        // 3) collect that channel.
×
1342
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1343
        //    and add those timestamps to the collected channel.
×
1344
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1345
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1346
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1347
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1348
                                StartScid: chanIDStart[:],
×
1349
                                EndScid:   chanIDEnd[:],
×
1350
                        },
×
1351
                )
×
1352
                if err != nil {
×
1353
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1354
                                err)
×
1355
                }
×
UNCOV
1356

×
1357
                for _, dbChan := range dbChans {
×
1358
                        cid := lnwire.NewShortChanIDFromInt(
×
1359
                                byteOrder.Uint64(dbChan.Scid),
×
1360
                        )
×
1361
                        chanInfo := NewChannelUpdateInfo(
×
1362
                                cid, time.Time{}, time.Time{},
×
1363
                        )
×
1364

×
1365
                        if !withTimestamps {
×
1366
                                channelsPerBlock[cid.BlockHeight] = append(
×
1367
                                        channelsPerBlock[cid.BlockHeight],
×
1368
                                        chanInfo,
×
1369
                                )
1370

×
1371
                                continue
×
UNCOV
1372
                        }
×
UNCOV
1373

×
UNCOV
1374
                        //nolint:ll
×
1375
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1376
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1377
                                        Version:   int16(ProtocolV1),
×
1378
                                        ChannelID: dbChan.ID,
×
1379
                                        NodeID:    dbChan.NodeID1,
×
1380
                                },
×
1381
                        )
×
1382
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1383
                                return fmt.Errorf("unable to fetch node1 "+
×
1384
                                        "policy: %w", err)
×
1385
                        } else if err == nil {
1386
                                chanInfo.Node1UpdateTimestamp = time.Unix(
1387
                                        node1Policy.LastUpdate.Int64, 0,
1388
                                )
×
1389
                        }
×
UNCOV
1390

×
UNCOV
1391
                        //nolint:ll
×
1392
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1393
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1394
                                        Version:   int16(ProtocolV1),
×
1395
                                        ChannelID: dbChan.ID,
×
1396
                                        NodeID:    dbChan.NodeID2,
×
1397
                                },
×
1398
                        )
×
1399
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1400
                                return fmt.Errorf("unable to fetch node2 "+
×
1401
                                        "policy: %w", err)
×
1402
                        } else if err == nil {
×
1403
                                chanInfo.Node2UpdateTimestamp = time.Unix(
1404
                                        node2Policy.LastUpdate.Int64, 0,
1405
                                )
×
1406
                        }
×
UNCOV
1407

×
1408
                        channelsPerBlock[cid.BlockHeight] = append(
×
1409
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1410
                        )
×
UNCOV
1411
                }
×
UNCOV
1412

×
1413
                return nil
×
1414
        }, func() {
×
1415
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1416
        })
×
1417
        if err != nil {
×
1418
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1419
        }
×
1420

1421
        if len(channelsPerBlock) == 0 {
×
1422
                return nil, nil
×
1423
        }
×
1424

1425
        // Return the channel ranges in ascending block height order.
1426
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1427
        slices.Sort(blocks)
×
1428

×
1429
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1430
                return BlockChannelRange{
×
1431
                        Height:   block,
×
1432
                        Channels: channelsPerBlock[block],
×
1433
                }
1434
        }), nil
×
UNCOV
1435
}
×
UNCOV
1436

×
1437
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1438
// zombie. This method is used on an ad-hoc basis, when channels need to be
UNCOV
1439
// marked as zombies outside the normal pruning cycle.
×
UNCOV
1440
//
×
UNCOV
1441
// NOTE: part of the V1Store interface.
×
UNCOV
1442
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
×
1443
        pubKey1, pubKey2 [33]byte) error {
×
1444

×
1445
        ctx := context.TODO()
×
1446

×
1447
        s.cacheMu.Lock()
×
1448
        defer s.cacheMu.Unlock()
1449

1450
        chanIDB := channelIDToBytes(chanID)
1451

1452
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
1453
                return db.UpsertZombieChannel(
1454
                        ctx, sqlc.UpsertZombieChannelParams{
1455
                                Version:  int16(ProtocolV1),
1456
                                Scid:     chanIDB[:],
×
1457
                                NodeKey1: pubKey1[:],
×
1458
                                NodeKey2: pubKey2[:],
×
1459
                        },
×
1460
                )
×
1461
        }, sqldb.NoOpReset)
×
1462
        if err != nil {
×
1463
                return fmt.Errorf("unable to upsert zombie channel "+
×
1464
                        "(channel_id=%d): %w", chanID, err)
×
1465
        }
×
UNCOV
1466

×
1467
        s.rejectCache.remove(chanID)
×
1468
        s.chanCache.remove(chanID)
×
1469

×
1470
        return nil
×
UNCOV
1471
}
×
UNCOV
1472

×
UNCOV
1473
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
×
UNCOV
1474
//
×
UNCOV
1475
// NOTE: part of the V1Store interface.
×
1476
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1477
        s.cacheMu.Lock()
×
1478
        defer s.cacheMu.Unlock()
×
1479

1480
        var (
×
1481
                ctx     = context.TODO()
×
1482
                chanIDB = channelIDToBytes(chanID)
×
1483
        )
×
1484

1485
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
1486
                res, err := db.DeleteZombieChannel(
1487
                        ctx, sqlc.DeleteZombieChannelParams{
1488
                                Scid:    chanIDB[:],
1489
                                Version: int16(ProtocolV1),
×
1490
                        },
×
1491
                )
×
1492
                if err != nil {
×
1493
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1494
                                err)
×
1495
                }
×
UNCOV
1496

×
1497
                rows, err := res.RowsAffected()
×
1498
                if err != nil {
×
1499
                        return err
×
1500
                }
×
UNCOV
1501

×
1502
                if rows == 0 {
×
1503
                        return ErrZombieEdgeNotFound
×
1504
                } else if rows > 1 {
×
1505
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1506
                                "expected 1", rows)
×
1507
                }
×
UNCOV
1508

×
1509
                return nil
UNCOV
1510
        }, sqldb.NoOpReset)
×
1511
        if err != nil {
×
1512
                return fmt.Errorf("unable to mark edge live "+
×
1513
                        "(channel_id=%d): %w", chanID, err)
×
1514
        }
UNCOV
1515

×
1516
        s.rejectCache.remove(chanID)
×
1517
        s.chanCache.remove(chanID)
×
1518

×
1519
        return err
×
UNCOV
1520
}
×
1521

UNCOV
1522
// IsZombieEdge returns whether the edge is considered zombie. If it is a
×
1523
// zombie, then the two node public keys corresponding to this edge are also
UNCOV
1524
// returned.
×
UNCOV
1525
//
×
UNCOV
1526
// NOTE: part of the V1Store interface.
×
1527
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
×
1528
        var (
1529
                ctx              = context.TODO()
×
1530
                isZombie         bool
×
1531
                pubKey1, pubKey2 route.Vertex
×
1532
                chanIDB          = channelIDToBytes(chanID)
×
1533
        )
1534

1535
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1536
                zombie, err := db.GetZombieChannel(
1537
                        ctx, sqlc.GetZombieChannelParams{
1538
                                Scid:    chanIDB[:],
1539
                                Version: int16(ProtocolV1),
1540
                        },
×
1541
                )
×
1542
                if errors.Is(err, sql.ErrNoRows) {
×
1543
                        return nil
×
1544
                }
×
1545
                if err != nil {
×
1546
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1547
                                err)
×
1548
                }
×
UNCOV
1549

×
1550
                copy(pubKey1[:], zombie.NodeKey1)
×
1551
                copy(pubKey2[:], zombie.NodeKey2)
×
1552
                isZombie = true
×
1553

×
1554
                return nil
×
UNCOV
1555
        }, sqldb.NoOpReset)
×
1556
        if err != nil {
×
1557
                // TODO(elle): update the IsZombieEdge method to return an
×
1558
                // error.
×
1559
                return false, route.Vertex{}, route.Vertex{}
×
1560
        }
×
UNCOV
1561

×
1562
        return isZombie, pubKey1, pubKey2
UNCOV
1563
}
×
UNCOV
1564

×
UNCOV
1565
// NumZombies returns the current number of zombie channels in the graph.
×
UNCOV
1566
//
×
UNCOV
1567
// NOTE: part of the V1Store interface.
×
1568
func (s *SQLStore) NumZombies() (uint64, error) {
1569
        var (
×
1570
                ctx        = context.TODO()
×
1571
                numZombies uint64
×
1572
        )
×
1573
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1574
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
1575
                if err != nil {
×
1576
                        return fmt.Errorf("unable to count zombie channels: %w",
1577
                                err)
1578
                }
1579

1580
                numZombies = uint64(count)
1581

×
1582
                return nil
×
UNCOV
1583
        }, sqldb.NoOpReset)
×
1584
        if err != nil {
×
1585
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1586
        }
×
UNCOV
1587

×
1588
        return numZombies, nil
×
UNCOV
1589
}
×
UNCOV
1590

×
UNCOV
1591
// DeleteChannelEdges removes edges with the given channel IDs from the
×
1592
// database and marks them as zombies. This ensures that we're unable to re-add
UNCOV
1593
// it to our database once again. If an edge does not exist within the
×
UNCOV
1594
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
×
UNCOV
1595
// true, then when we mark these edges as zombies, we'll set up the keys such
×
1596
// that we require the node that failed to send the fresh update to be the one
UNCOV
1597
// that resurrects the channel from its zombie state. The markZombie bool
×
UNCOV
1598
// denotes whether to mark the channel as a zombie.
×
UNCOV
1599
//
×
1600
// NOTE: part of the V1Store interface.
UNCOV
1601
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
×
1602
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
1603

1604
        s.cacheMu.Lock()
1605
        defer s.cacheMu.Unlock()
1606

1607
        var (
1608
                ctx     = context.TODO()
1609
                deleted []*models.ChannelEdgeInfo
1610
        )
1611
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
1612
                for _, chanID := range chanIDs {
1613
                        chanIDB := channelIDToBytes(chanID)
1614

1615
                        row, err := db.GetChannelBySCIDWithPolicies(
×
1616
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1617
                                        Scid:    chanIDB[:],
×
1618
                                        Version: int16(ProtocolV1),
×
1619
                                },
×
1620
                        )
×
1621
                        if errors.Is(err, sql.ErrNoRows) {
×
1622
                                return ErrEdgeNotFound
×
1623
                        } else if err != nil {
×
1624
                                return fmt.Errorf("unable to fetch channel: %w",
×
1625
                                        err)
×
1626
                        }
×
UNCOV
1627

×
1628
                        node1, node2, err := buildNodeVertices(
×
1629
                                row.Node.PubKey, row.Node_2.PubKey,
×
1630
                        )
×
1631
                        if err != nil {
×
1632
                                return err
×
1633
                        }
×
UNCOV
1634

×
1635
                        info, err := getAndBuildEdgeInfo(
×
1636
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
1637
                                row.Channel, node1, node2,
×
1638
                        )
×
1639
                        if err != nil {
×
1640
                                return err
1641
                        }
×
UNCOV
1642

×
1643
                        err = db.DeleteChannel(ctx, row.Channel.ID)
×
1644
                        if err != nil {
×
1645
                                return fmt.Errorf("unable to delete "+
×
1646
                                        "channel: %w", err)
×
1647
                        }
UNCOV
1648

×
1649
                        deleted = append(deleted, info)
×
1650

×
1651
                        if !markZombie {
×
1652
                                continue
×
UNCOV
1653
                        }
×
UNCOV
1654

×
1655
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
1656
                                info.NodeKey2Bytes
×
1657
                        if strictZombiePruning {
×
1658
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1659
                                if row.Policy1LastUpdate.Valid {
×
1660
                                        e1Time := time.Unix(
×
1661
                                                row.Policy1LastUpdate.Int64, 0,
1662
                                        )
×
1663
                                        e1UpdateTime = &e1Time
×
1664
                                }
×
1665
                                if row.Policy2LastUpdate.Valid {
×
1666
                                        e2Time := time.Unix(
1667
                                                row.Policy2LastUpdate.Int64, 0,
1668
                                        )
×
1669
                                        e2UpdateTime = &e2Time
×
1670
                                }
×
UNCOV
1671

×
1672
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1673
                                        info, e1UpdateTime, e2UpdateTime,
×
1674
                                )
×
UNCOV
1675
                        }
×
UNCOV
1676

×
1677
                        err = db.UpsertZombieChannel(
×
1678
                                ctx, sqlc.UpsertZombieChannelParams{
×
1679
                                        Version:  int16(ProtocolV1),
×
1680
                                        Scid:     chanIDB[:],
×
1681
                                        NodeKey1: nodeKey1[:],
×
1682
                                        NodeKey2: nodeKey2[:],
×
1683
                                },
×
1684
                        )
1685
                        if err != nil {
×
1686
                                return fmt.Errorf("unable to mark channel as "+
×
1687
                                        "zombie: %w", err)
×
1688
                        }
1689
                }
UNCOV
1690

×
1691
                return nil
×
1692
        }, func() {
×
1693
                deleted = nil
×
1694
        })
×
1695
        if err != nil {
×
1696
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1697
                        err)
×
1698
        }
×
UNCOV
1699

×
1700
        for _, chanID := range chanIDs {
×
1701
                s.rejectCache.remove(chanID)
×
1702
                s.chanCache.remove(chanID)
1703
        }
UNCOV
1704

×
1705
        return deleted, nil
×
UNCOV
1706
}
×
UNCOV
1707

×
UNCOV
1708
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
×
UNCOV
1709
// channel identified by the channel ID. If the channel can't be found, then
×
UNCOV
1710
// ErrEdgeNotFound is returned. A struct which houses the general information
×
UNCOV
1711
// for the channel itself is returned as well as two structs that contain the
×
1712
// routing policies for the channel in either direction.
UNCOV
1713
//
×
UNCOV
1714
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
×
UNCOV
1715
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
×
UNCOV
1716
// the ChannelEdgeInfo will only include the public keys of each node.
×
1717
//
UNCOV
1718
// NOTE: part of the V1Store interface.
×
1719
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1720
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1721
        *models.ChannelEdgePolicy, error) {
1722

1723
        var (
1724
                ctx              = context.TODO()
1725
                edge             *models.ChannelEdgeInfo
1726
                policy1, policy2 *models.ChannelEdgePolicy
1727
        )
1728
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1729
                var chanIDB [8]byte
1730
                byteOrder.PutUint64(chanIDB[:], chanID)
1731

1732
                row, err := db.GetChannelBySCIDWithPolicies(
1733
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
1734
                                Scid:    chanIDB[:],
×
1735
                                Version: int16(ProtocolV1),
×
1736
                        },
×
1737
                )
×
1738
                if errors.Is(err, sql.ErrNoRows) {
×
1739
                        // First check if this edge is perhaps in the zombie
×
1740
                        // index.
×
1741
                        isZombie, err := db.IsZombieChannel(
×
1742
                                ctx, sqlc.IsZombieChannelParams{
×
1743
                                        Scid:    chanIDB[:],
×
1744
                                        Version: int16(ProtocolV1),
×
1745
                                },
×
1746
                        )
×
1747
                        if err != nil {
×
1748
                                return fmt.Errorf("unable to check if "+
×
1749
                                        "channel is zombie: %w", err)
×
1750
                        } else if isZombie {
×
1751
                                return ErrZombieEdge
×
1752
                        }
×
UNCOV
1753

×
1754
                        return ErrEdgeNotFound
×
1755
                } else if err != nil {
×
1756
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1757
                }
×
UNCOV
1758

×
1759
                node1, node2, err := buildNodeVertices(
×
1760
                        row.Node.PubKey, row.Node_2.PubKey,
×
1761
                )
×
1762
                if err != nil {
×
1763
                        return err
×
1764
                }
×
UNCOV
1765

×
1766
                edge, err = getAndBuildEdgeInfo(
1767
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1768
                        node1, node2,
×
1769
                )
×
1770
                if err != nil {
×
1771
                        return fmt.Errorf("unable to build channel info: %w",
1772
                                err)
×
1773
                }
×
UNCOV
1774

×
1775
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1776
                if err != nil {
×
1777
                        return fmt.Errorf("unable to extract channel "+
×
1778
                                "policies: %w", err)
1779
                }
×
UNCOV
1780

×
1781
                policy1, policy2, err = getAndBuildChanPolicies(
×
1782
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1783
                )
×
1784
                if err != nil {
×
1785
                        return fmt.Errorf("unable to build channel "+
×
1786
                                "policies: %w", err)
×
1787
                }
UNCOV
1788

×
1789
                return nil
×
UNCOV
1790
        }, sqldb.NoOpReset)
×
1791
        if err != nil {
×
1792
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1793
                        err)
1794
        }
×
UNCOV
1795

×
1796
        return edge, policy1, policy2, nil
×
UNCOV
1797
}
×
UNCOV
1798

×
UNCOV
1799
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
×
UNCOV
1800
// the channel identified by the funding outpoint. If the channel can't be
×
1801
// found, then ErrEdgeNotFound is returned. A struct which houses the general
UNCOV
1802
// information for the channel itself is returned as well as two structs that
×
1803
// contain the routing policies for the channel in either direction.
UNCOV
1804
//
×
UNCOV
1805
// NOTE: part of the V1Store interface.
×
UNCOV
1806
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
×
UNCOV
1807
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
×
1808
        *models.ChannelEdgePolicy, error) {
1809

×
1810
        var (
1811
                ctx              = context.TODO()
1812
                edge             *models.ChannelEdgeInfo
1813
                policy1, policy2 *models.ChannelEdgePolicy
1814
        )
1815
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1816
                row, err := db.GetChannelByOutpointWithPolicies(
1817
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
1818
                                Outpoint: op.String(),
1819
                                Version:  int16(ProtocolV1),
1820
                        },
1821
                )
×
1822
                if errors.Is(err, sql.ErrNoRows) {
×
1823
                        return ErrEdgeNotFound
×
1824
                } else if err != nil {
×
1825
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1826
                }
×
UNCOV
1827

×
1828
                node1, node2, err := buildNodeVertices(
×
1829
                        row.Node1Pubkey, row.Node2Pubkey,
×
1830
                )
×
1831
                if err != nil {
×
1832
                        return err
×
1833
                }
×
UNCOV
1834

×
1835
                edge, err = getAndBuildEdgeInfo(
×
1836
                        ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
×
1837
                        node1, node2,
×
1838
                )
×
1839
                if err != nil {
×
1840
                        return fmt.Errorf("unable to build channel info: %w",
1841
                                err)
×
1842
                }
×
UNCOV
1843

×
1844
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1845
                if err != nil {
×
1846
                        return fmt.Errorf("unable to extract channel "+
×
1847
                                "policies: %w", err)
1848
                }
×
UNCOV
1849

×
1850
                policy1, policy2, err = getAndBuildChanPolicies(
×
1851
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1852
                )
×
1853
                if err != nil {
×
1854
                        return fmt.Errorf("unable to build channel "+
×
1855
                                "policies: %w", err)
×
1856
                }
UNCOV
1857

×
1858
                return nil
×
UNCOV
1859
        }, sqldb.NoOpReset)
×
1860
        if err != nil {
×
1861
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1862
                        err)
1863
        }
×
UNCOV
1864

×
1865
        return edge, policy1, policy2, nil
×
UNCOV
1866
}
×
UNCOV
1867

×
UNCOV
1868
// HasChannelEdge returns true if the database knows of a channel edge with the
×
UNCOV
1869
// passed channel ID, and false otherwise. If an edge with that ID is found
×
1870
// within the graph, then two time stamps representing the last time the edge
UNCOV
1871
// was updated for both directed edges are returned along with the boolean. If
×
1872
// it is not found, then the zombie index is checked and its result is returned
UNCOV
1873
// as the second boolean.
×
UNCOV
1874
//
×
UNCOV
1875
// NOTE: part of the V1Store interface.
×
UNCOV
1876
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
×
1877
        bool, error) {
1878

×
1879
        ctx := context.TODO()
1880

1881
        var (
1882
                exists          bool
1883
                isZombie        bool
1884
                node1LastUpdate time.Time
1885
                node2LastUpdate time.Time
1886
        )
1887

1888
        // We'll query the cache with the shared lock held to allow multiple
1889
        // readers to access values in the cache concurrently if they exist.
1890
        s.cacheMu.RLock()
×
1891
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1892
                s.cacheMu.RUnlock()
×
1893
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1894
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1895
                exists, isZombie = entry.flags.unpack()
×
1896

×
1897
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1898
        }
×
1899
        s.cacheMu.RUnlock()
×
1900

×
1901
        s.cacheMu.Lock()
×
1902
        defer s.cacheMu.Unlock()
×
1903

×
1904
        // The item was not found with the shared lock, so we'll acquire the
×
1905
        // exclusive lock and check the cache again in case another method added
×
1906
        // the entry to the cache while no lock was held.
×
1907
        if entry, ok := s.rejectCache.get(chanID); ok {
×
1908
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
1909
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
1910
                exists, isZombie = entry.flags.unpack()
×
1911

×
1912
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1913
        }
×
UNCOV
1914

×
1915
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1916
                var chanIDB [8]byte
×
1917
                byteOrder.PutUint64(chanIDB[:], chanID)
×
1918

×
1919
                channel, err := db.GetChannelBySCID(
×
1920
                        ctx, sqlc.GetChannelBySCIDParams{
×
1921
                                Scid:    chanIDB[:],
×
1922
                                Version: int16(ProtocolV1),
×
1923
                        },
×
1924
                )
×
1925
                if errors.Is(err, sql.ErrNoRows) {
×
1926
                        // Check if it is a zombie channel.
×
1927
                        isZombie, err = db.IsZombieChannel(
1928
                                ctx, sqlc.IsZombieChannelParams{
×
1929
                                        Scid:    chanIDB[:],
×
1930
                                        Version: int16(ProtocolV1),
×
1931
                                },
×
1932
                        )
×
1933
                        if err != nil {
×
1934
                                return fmt.Errorf("could not check if channel "+
×
1935
                                        "is zombie: %w", err)
×
1936
                        }
×
UNCOV
1937

×
1938
                        return nil
×
1939
                } else if err != nil {
×
1940
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1941
                }
×
UNCOV
1942

×
1943
                exists = true
×
1944

×
1945
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
1946
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1947
                                Version:   int16(ProtocolV1),
×
1948
                                ChannelID: channel.ID,
×
1949
                                NodeID:    channel.NodeID1,
×
1950
                        },
1951
                )
×
1952
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1953
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1954
                                err)
×
1955
                } else if err == nil {
1956
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
1957
                }
×
UNCOV
1958

×
1959
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
1960
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1961
                                Version:   int16(ProtocolV1),
×
1962
                                ChannelID: channel.ID,
×
1963
                                NodeID:    channel.NodeID2,
×
1964
                        },
×
1965
                )
×
1966
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1967
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
1968
                                err)
×
1969
                } else if err == nil {
×
1970
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
1971
                }
UNCOV
1972

×
1973
                return nil
×
UNCOV
1974
        }, sqldb.NoOpReset)
×
1975
        if err != nil {
×
1976
                return time.Time{}, time.Time{}, false, false,
×
1977
                        fmt.Errorf("unable to fetch channel: %w", err)
×
1978
        }
×
UNCOV
1979

×
1980
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
1981
                upd1Time: node1LastUpdate.Unix(),
×
1982
                upd2Time: node2LastUpdate.Unix(),
×
1983
                flags:    packRejectFlags(exists, isZombie),
×
1984
        })
×
1985

1986
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
1987
}
UNCOV
1988

×
UNCOV
1989
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
×
UNCOV
1990
// passed channel point (outpoint). If the passed channel doesn't exist within
×
UNCOV
1991
// the database, then ErrEdgeNotFound is returned.
×
1992
//
UNCOV
1993
// NOTE: part of the V1Store interface.
×
1994
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
1995
        var (
×
1996
                ctx       = context.TODO()
×
1997
                channelID uint64
×
1998
        )
×
1999
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2000
                chanID, err := db.GetSCIDByOutpoint(
2001
                        ctx, sqlc.GetSCIDByOutpointParams{
2002
                                Outpoint: chanPoint.String(),
2003
                                Version:  int16(ProtocolV1),
2004
                        },
2005
                )
2006
                if errors.Is(err, sql.ErrNoRows) {
2007
                        return ErrEdgeNotFound
×
2008
                } else if err != nil {
×
2009
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2010
                                err)
×
2011
                }
×
UNCOV
2012

×
2013
                channelID = byteOrder.Uint64(chanID)
×
2014

×
2015
                return nil
×
UNCOV
2016
        }, sqldb.NoOpReset)
×
2017
        if err != nil {
×
2018
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2019
        }
×
UNCOV
2020

×
2021
        return channelID, nil
×
UNCOV
2022
}
×
UNCOV
2023

×
UNCOV
2024
// IsPublicNode is a helper method that determines whether the node with the
×
2025
// given public key is seen as a public node in the graph from the graph's
UNCOV
2026
// source node's point of view.
×
UNCOV
2027
//
×
UNCOV
2028
// NOTE: part of the V1Store interface.
×
2029
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
2030
        ctx := context.TODO()
×
2031

×
2032
        var isPublic bool
×
2033
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
2034
                var err error
×
2035
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
2036

2037
                return err
2038
        }, sqldb.NoOpReset)
2039
        if err != nil {
2040
                return false, fmt.Errorf("unable to check if node is "+
2041
                        "public: %w", err)
2042
        }
×
UNCOV
2043

×
2044
        return isPublic, nil
×
UNCOV
2045
}
×
UNCOV
2046

×
UNCOV
2047
// FetchChanInfos returns the set of channel edges that correspond to the passed
×
UNCOV
2048
// channel ID's. If an edge is the query is unknown to the database, it will
×
UNCOV
2049
// skipped and the result will contain only those edges that exist at the time
×
UNCOV
2050
// of the query. This can be used to respond to peer queries that are seeking to
×
UNCOV
2051
// fill in gaps in their view of the channel graph.
×
UNCOV
2052
//
×
UNCOV
2053
// NOTE: part of the V1Store interface.
×
2054
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2055
        var (
×
2056
                ctx   = context.TODO()
2057
                edges []ChannelEdge
×
2058
        )
2059
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
2060
                for _, chanID := range chanIDs {
2061
                        var chanIDB [8]byte
2062
                        byteOrder.PutUint64(chanIDB[:], chanID)
2063

2064
                        // TODO(elle): potentially optimize this by using
2065
                        //  sqlc.slice() once that works for both SQLite and
2066
                        //  Postgres.
2067
                        row, err := db.GetChannelBySCIDWithPolicies(
×
2068
                                ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2069
                                        Scid:    chanIDB[:],
×
2070
                                        Version: int16(ProtocolV1),
×
2071
                                },
×
2072
                        )
×
2073
                        if errors.Is(err, sql.ErrNoRows) {
×
2074
                                continue
×
2075
                        } else if err != nil {
×
2076
                                return fmt.Errorf("unable to fetch channel: %w",
×
2077
                                        err)
×
2078
                        }
×
UNCOV
2079

×
2080
                        node1, node2, err := buildNodes(
×
2081
                                ctx, db, row.Node, row.Node_2,
×
2082
                        )
×
2083
                        if err != nil {
×
2084
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2085
                                        err)
×
2086
                        }
×
UNCOV
2087

×
2088
                        edge, err := getAndBuildEdgeInfo(
×
2089
                                ctx, db, s.cfg.ChainHash, row.Channel.ID,
×
2090
                                row.Channel, node1.PubKeyBytes,
×
2091
                                node2.PubKeyBytes,
×
2092
                        )
2093
                        if err != nil {
×
2094
                                return fmt.Errorf("unable to build "+
×
2095
                                        "channel info: %w", err)
×
2096
                        }
×
UNCOV
2097

×
2098
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2099
                        if err != nil {
×
2100
                                return fmt.Errorf("unable to extract channel "+
2101
                                        "policies: %w", err)
×
2102
                        }
×
UNCOV
2103

×
2104
                        p1, p2, err := getAndBuildChanPolicies(
×
2105
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2106
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2107
                        )
×
2108
                        if err != nil {
×
2109
                                return fmt.Errorf("unable to build channel "+
×
2110
                                        "policies: %w", err)
2111
                        }
×
UNCOV
2112

×
2113
                        edges = append(edges, ChannelEdge{
×
2114
                                Info:    edge,
×
2115
                                Policy1: p1,
×
2116
                                Policy2: p2,
2117
                                Node1:   node1,
×
2118
                                Node2:   node2,
×
2119
                        })
×
UNCOV
2120
                }
×
UNCOV
2121

×
2122
                return nil
×
2123
        }, func() {
×
2124
                edges = nil
×
2125
        })
2126
        if err != nil {
×
2127
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2128
        }
×
UNCOV
2129

×
2130
        return edges, nil
×
UNCOV
2131
}
×
UNCOV
2132

×
2133
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2134
// ID's that we don't know and are not known zombies of the passed set. In other
UNCOV
2135
// words, we perform a set difference of our set of chan ID's and the ones
×
UNCOV
2136
// passed in. This method can be used by callers to determine the set of
×
UNCOV
2137
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
×
UNCOV
2138
// known zombies is also returned.
×
UNCOV
2139
//
×
UNCOV
2140
// NOTE: part of the V1Store interface.
×
UNCOV
2141
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
×
2142
        []ChannelUpdateInfo, error) {
2143

×
2144
        var (
2145
                ctx          = context.TODO()
2146
                newChanIDs   []uint64
2147
                knownZombies []ChannelUpdateInfo
2148
        )
2149
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
2150
                for _, chanInfo := range chansInfo {
2151
                        channelID := chanInfo.ShortChannelID.ToUint64()
2152
                        var chanIDB [8]byte
2153
                        byteOrder.PutUint64(chanIDB[:], channelID)
2154

2155
                        // TODO(elle): potentially optimize this by using
×
2156
                        //  sqlc.slice() once that works for both SQLite and
×
2157
                        //  Postgres.
×
2158
                        _, err := db.GetChannelBySCID(
×
2159
                                ctx, sqlc.GetChannelBySCIDParams{
×
2160
                                        Version: int16(ProtocolV1),
×
2161
                                        Scid:    chanIDB[:],
×
2162
                                },
×
2163
                        )
×
2164
                        if err == nil {
×
2165
                                continue
×
2166
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
2167
                                return fmt.Errorf("unable to fetch channel: %w",
×
2168
                                        err)
×
2169
                        }
×
UNCOV
2170

×
2171
                        isZombie, err := db.IsZombieChannel(
×
2172
                                ctx, sqlc.IsZombieChannelParams{
×
2173
                                        Scid:    chanIDB[:],
×
2174
                                        Version: int16(ProtocolV1),
×
2175
                                },
×
2176
                        )
×
2177
                        if err != nil {
×
2178
                                return fmt.Errorf("unable to fetch zombie "+
×
2179
                                        "channel: %w", err)
×
2180
                        }
×
UNCOV
2181

×
2182
                        if isZombie {
×
2183
                                knownZombies = append(knownZombies, chanInfo)
2184

×
2185
                                continue
×
UNCOV
2186
                        }
×
UNCOV
2187

×
2188
                        newChanIDs = append(newChanIDs, channelID)
×
UNCOV
2189
                }
×
UNCOV
2190

×
2191
                return nil
×
2192
        }, func() {
×
2193
                newChanIDs = nil
×
2194
                knownZombies = nil
2195
        })
×
2196
        if err != nil {
×
2197
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2198
        }
×
2199

2200
        return newChanIDs, knownZombies, nil
UNCOV
2201
}
×
2202

2203
// forEachNodeDirectedChannel iterates through all channels of a given
UNCOV
2204
// node, executing the passed callback on the directed edge representing the
×
UNCOV
2205
// channel and its incoming policy. If the node is not found, no error is
×
UNCOV
2206
// returned.
×
UNCOV
2207
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
×
2208
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2209

×
2210
        toNodeCallback := func() route.Vertex {
×
2211
                return nodePub
×
2212
        }
UNCOV
2213

×
2214
        dbID, err := db.GetNodeIDByPubKey(
2215
                ctx, sqlc.GetNodeIDByPubKeyParams{
2216
                        Version: int16(ProtocolV1),
2217
                        PubKey:  nodePub[:],
2218
                },
2219
        )
2220
        if errors.Is(err, sql.ErrNoRows) {
2221
                return nil
2222
        } else if err != nil {
2223
                return fmt.Errorf("unable to fetch node: %w", err)
2224
        }
UNCOV
2225

×
2226
        rows, err := db.ListChannelsByNodeID(
×
2227
                ctx, sqlc.ListChannelsByNodeIDParams{
×
2228
                        Version: int16(ProtocolV1),
×
2229
                        NodeID1: dbID,
×
2230
                },
×
2231
        )
×
2232
        if err != nil {
×
2233
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2234
        }
×
UNCOV
2235

×
UNCOV
2236
        // Exit early if there are no channels for this node so we don't
×
UNCOV
2237
        // do the unnecessary feature fetching.
×
2238
        if len(rows) == 0 {
×
2239
                return nil
×
2240
        }
UNCOV
2241

×
2242
        features, err := getNodeFeatures(ctx, db, dbID)
2243
        if err != nil {
2244
                return fmt.Errorf("unable to fetch node features: %w", err)
2245
        }
2246

2247
        for _, row := range rows {
2248
                node1, node2, err := buildNodeVertices(
2249
                        row.Node1Pubkey, row.Node2Pubkey,
2250
                )
2251
                if err != nil {
2252
                        return fmt.Errorf("unable to build node vertices: %w",
2253
                                err)
2254
                }
2255

2256
                edge := buildCacheableChannelInfo(row.Channel, node1, node2)
×
2257

×
2258
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2259
                if err != nil {
×
2260
                        return err
×
2261
                }
×
UNCOV
2262

×
2263
                var p1, p2 *models.CachedEdgePolicy
×
2264
                if dbPol1 != nil {
×
2265
                        policy1, err := buildChanPolicy(
×
2266
                                *dbPol1, edge.ChannelID, nil, node2, true,
×
2267
                        )
×
2268
                        if err != nil {
×
2269
                                return err
×
2270
                        }
×
UNCOV
2271

×
2272
                        p1 = models.NewCachedPolicy(policy1)
×
UNCOV
2273
                }
×
2274
                if dbPol2 != nil {
×
2275
                        policy2, err := buildChanPolicy(
×
2276
                                *dbPol2, edge.ChannelID, nil, node1, false,
×
2277
                        )
×
2278
                        if err != nil {
×
2279
                                return err
×
2280
                        }
×
UNCOV
2281

×
2282
                        p2 = models.NewCachedPolicy(policy2)
×
UNCOV
2283
                }
×
2284

UNCOV
2285
                // Determine the outgoing and incoming policy for this
×
UNCOV
2286
                // channel and node combo.
×
2287
                outPolicy, inPolicy := p1, p2
×
2288
                if p1 != nil && node2 == nodePub {
×
2289
                        outPolicy, inPolicy = p2, p1
×
2290
                } else if p2 != nil && node1 != nodePub {
×
2291
                        outPolicy, inPolicy = p2, p1
2292
                }
×
UNCOV
2293

×
2294
                var cachedInPolicy *models.CachedEdgePolicy
×
2295
                if inPolicy != nil {
×
2296
                        cachedInPolicy = inPolicy
×
2297
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
2298
                        cachedInPolicy.ToNodeFeatures = features
×
2299
                }
UNCOV
2300

×
2301
                directedChannel := &DirectedChannel{
×
2302
                        ChannelID:    edge.ChannelID,
×
2303
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
2304
                        OtherNode:    edge.NodeKey2Bytes,
×
2305
                        Capacity:     edge.Capacity,
2306
                        OutPolicySet: outPolicy != nil,
×
2307
                        InPolicy:     cachedInPolicy,
2308
                }
2309
                if outPolicy != nil {
×
2310
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
2311
                                directedChannel.InboundFee = fee
×
2312
                        })
×
UNCOV
2313
                }
×
UNCOV
2314

×
2315
                if nodePub == edge.NodeKey2Bytes {
×
2316
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
2317
                }
×
UNCOV
2318

×
2319
                if err := cb(directedChannel); err != nil {
2320
                        return err
2321
                }
UNCOV
2322
        }
×
UNCOV
2323

×
2324
        return nil
×
UNCOV
2325
}
×
UNCOV
2326

×
2327
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
UNCOV
2328
// and executes the provided callback for each node.
×
UNCOV
2329
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
×
2330
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
2331

×
2332
        var lastID int64
×
2333

×
2334
        for {
×
2335
                nodes, err := db.ListNodeIDsAndPubKeys(
×
2336
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
2337
                                Version: int16(ProtocolV1),
×
2338
                                ID:      lastID,
×
2339
                                Limit:   pageSize,
×
2340
                        },
×
2341
                )
2342
                if err != nil {
×
2343
                        return fmt.Errorf("unable to fetch nodes: %w", err)
2344
                }
2345

2346
                if len(nodes) == 0 {
2347
                        break
2348
                }
2349

2350
                for _, node := range nodes {
2351
                        var pub route.Vertex
×
2352
                        copy(pub[:], node.PubKey)
×
2353

×
2354
                        if err := cb(node.ID, pub); err != nil {
×
2355
                                return fmt.Errorf("forEachNodeCacheable "+
×
2356
                                        "callback failed for node(id=%d): %w",
×
2357
                                        node.ID, err)
×
2358
                        }
×
UNCOV
2359

×
2360
                        lastID = node.ID
×
UNCOV
2361
                }
×
UNCOV
2362
        }
×
UNCOV
2363

×
2364
        return nil
×
UNCOV
2365
}
×
2366

UNCOV
2367
// forEachNodeChannel iterates through all channels of a node, executing
×
UNCOV
2368
// the passed callback on each. The call-back is provided with the channel's
×
UNCOV
2369
// edge information, the outgoing policy and the incoming policy for the
×
UNCOV
2370
// channel and node combo.
×
2371
func forEachNodeChannel(ctx context.Context, db SQLQueries,
UNCOV
2372
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
×
UNCOV
2373
                *models.ChannelEdgePolicy,
×
2374
                *models.ChannelEdgePolicy) error) error {
×
2375

×
2376
        // Get all the V1 channels for this node.Add commentMore actions
×
2377
        rows, err := db.ListChannelsByNodeID(
×
2378
                ctx, sqlc.ListChannelsByNodeIDParams{
2379
                        Version: int16(ProtocolV1),
2380
                        NodeID1: id,
×
2381
                },
×
2382
        )
×
2383
        if err != nil {
×
2384
                return fmt.Errorf("unable to fetch channels: %w", err)
×
2385
        }
×
UNCOV
2386

×
UNCOV
2387
        // Call the call-back for each channel and its known policies.
×
2388
        for _, row := range rows {
×
2389
                node1, node2, err := buildNodeVertices(
×
2390
                        row.Node1Pubkey, row.Node2Pubkey,
×
2391
                )
×
2392
                if err != nil {
×
2393
                        return fmt.Errorf("unable to build node vertices: %w",
2394
                                err)
×
2395
                }
×
2396

2397
                edge, err := getAndBuildEdgeInfo(
2398
                        ctx, db, chain, row.Channel.ID, row.Channel, node1,
×
2399
                        node2,
×
2400
                )
×
2401
                if err != nil {
×
2402
                        return fmt.Errorf("unable to build channel info: %w",
×
2403
                                err)
2404
                }
×
2405

2406
                dbPol1, dbPol2, err := extractChannelPolicies(row)
2407
                if err != nil {
2408
                        return fmt.Errorf("unable to extract channel "+
×
2409
                                "policies: %w", err)
×
2410
                }
×
UNCOV
2411

×
2412
                p1, p2, err := getAndBuildChanPolicies(
×
2413
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2414
                )
×
2415
                if err != nil {
2416
                        return fmt.Errorf("unable to build channel "+
×
2417
                                "policies: %w", err)
2418
                }
2419

2420
                // Determine the outgoing and incoming policy for this
2421
                // channel and node combo.
2422
                p1ToNode := row.Channel.NodeID2
2423
                p2ToNode := row.Channel.NodeID1
2424
                outPolicy, inPolicy := p1, p2
2425
                if (p1 != nil && p1ToNode == id) ||
×
2426
                        (p2 != nil && p2ToNode != id) {
×
2427

×
2428
                        outPolicy, inPolicy = p2, p1
×
2429
                }
×
UNCOV
2430

×
2431
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
2432
                        return err
×
2433
                }
×
UNCOV
2434
        }
×
UNCOV
2435

×
2436
        return nil
×
UNCOV
2437
}
×
2438

UNCOV
2439
// updateChanEdgePolicy upserts the channel policy info we have stored for
×
UNCOV
2440
// a channel we already know of.
×
UNCOV
2441
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
×
UNCOV
2442
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
×
2443
        error) {
2444

×
2445
        var (
×
2446
                node1Pub, node2Pub route.Vertex
×
2447
                isNode1            bool
2448
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
2449
        )
2450

2451
        // Check that this edge policy refers to a channel that we already
2452
        // know of. We do this explicitly so that we can return the appropriate
2453
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
2454
        // abort the transaction which would abort the entire batch.
2455
        dbChan, err := tx.GetChannelAndNodesBySCID(
2456
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
2457
                        Scid:    chanIDB[:],
×
2458
                        Version: int16(ProtocolV1),
×
2459
                },
×
2460
        )
×
2461
        if errors.Is(err, sql.ErrNoRows) {
×
2462
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
2463
        } else if err != nil {
×
2464
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
2465
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
2466
        }
UNCOV
2467

×
2468
        copy(node1Pub[:], dbChan.Node1PubKey)
×
2469
        copy(node2Pub[:], dbChan.Node2PubKey)
×
2470

×
2471
        // Figure out which node this edge is from.
×
2472
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
2473
        nodeID := dbChan.NodeID1
×
2474
        if !isNode1 {
2475
                nodeID = dbChan.NodeID2
×
2476
        }
×
UNCOV
2477

×
2478
        var (
×
2479
                inboundBase sql.NullInt64
×
2480
                inboundRate sql.NullInt64
2481
        )
×
2482
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
2483
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
2484
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
2485
        })
2486

2487
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
2488
                Version:     int16(ProtocolV1),
2489
                ChannelID:   dbChan.ID,
2490
                NodeID:      nodeID,
2491
                Timelock:    int32(edge.TimeLockDelta),
2492
                FeePpm:      int64(edge.FeeProportionalMillionths),
2493
                BaseFeeMsat: int64(edge.FeeBaseMSat),
2494
                MinHtlcMsat: int64(edge.MinHTLC),
2495
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
2496
                Disabled: sql.NullBool{
2497
                        Valid: true,
×
2498
                        Bool:  edge.IsDisabled(),
×
2499
                },
×
2500
                MaxHtlcMsat: sql.NullInt64{
×
2501
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
2502
                        Int64: int64(edge.MaxHTLC),
×
2503
                },
×
2504
                InboundBaseFeeMsat:      inboundBase,
×
2505
                InboundFeeRateMilliMsat: inboundRate,
×
2506
                Signature:               edge.SigBytes,
×
2507
        })
×
2508
        if err != nil {
×
2509
                return node1Pub, node2Pub, isNode1,
×
2510
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
2511
        }
×
UNCOV
2512

×
UNCOV
2513
        // Convert the flat extra opaque data into a map of TLV types to
×
UNCOV
2514
        // values.
×
2515
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
2516
        if err != nil {
×
2517
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
2518
                        "marshal extra opaque data: %w", err)
×
2519
        }
×
UNCOV
2520

×
UNCOV
2521
        // Update the channel policy's extra signed fields.
×
2522
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
2523
        if err != nil {
×
2524
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
2525
                        "policy extra TLVs: %w", err)
×
2526
        }
×
UNCOV
2527

×
2528
        return node1Pub, node2Pub, isNode1, nil
×
UNCOV
2529
}
×
2530

UNCOV
2531
// getNodeByPubKey attempts to look up a target node by its public key.
×
UNCOV
2532
func getNodeByPubKey(ctx context.Context, db SQLQueries,
×
2533
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
2534

×
2535
        dbNode, err := db.GetNodeByPubKey(
×
2536
                ctx, sqlc.GetNodeByPubKeyParams{
×
2537
                        Version: int16(ProtocolV1),
×
2538
                        PubKey:  pubKey[:],
2539
                },
×
2540
        )
×
2541
        if errors.Is(err, sql.ErrNoRows) {
×
2542
                return 0, nil, ErrGraphNodeNotFound
×
2543
        } else if err != nil {
×
2544
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
2545
        }
×
2546

2547
        node, err := buildNode(ctx, db, &dbNode)
×
2548
        if err != nil {
×
2549
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
2550
        }
×
UNCOV
2551

×
2552
        return dbNode.ID, node, nil
UNCOV
2553
}
×
2554

2555
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
UNCOV
2556
// provided database channel row and the public keys of the two nodes
×
UNCOV
2557
// involved in the channel.
×
UNCOV
2558
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
×
2559
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
2560

×
2561
        return &models.CachedEdgeInfo{
×
2562
                ChannelID:     byteOrder.Uint64(dbChan.Scid),
×
2563
                NodeKey1Bytes: node1Pub,
×
2564
                NodeKey2Bytes: node2Pub,
×
2565
                Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
×
2566
        }
×
2567
}
×
UNCOV
2568

×
2569
// buildNode constructs a LightningNode instance from the given database node
UNCOV
2570
// record. The node's features, addresses and extra signed fields are also
×
UNCOV
2571
// fetched from the database and set on the node.
×
UNCOV
2572
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
×
2573
        *models.LightningNode, error) {
×
2574

2575
        if dbNode.Version != int16(ProtocolV1) {
×
2576
                return nil, fmt.Errorf("unsupported node version: %d",
2577
                        dbNode.Version)
2578
        }
2579

2580
        var pub [33]byte
2581
        copy(pub[:], dbNode.PubKey)
2582

2583
        node := &models.LightningNode{
×
2584
                PubKeyBytes: pub,
×
2585
                Features:    lnwire.EmptyFeatureVector(),
×
2586
                LastUpdate:  time.Unix(0, 0),
×
2587
        }
×
2588

2589
        if len(dbNode.Signature) == 0 {
×
2590
                return node, nil
×
2591
        }
×
UNCOV
2592

×
2593
        node.HaveNodeAnnouncement = true
×
2594
        node.AuthSigBytes = dbNode.Signature
×
2595
        node.Alias = dbNode.Alias.String
×
2596
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
2597

×
2598
        var err error
×
2599
        if dbNode.Color.Valid {
×
2600
                node.Color, err = DecodeHexColor(dbNode.Color.String)
2601
                if err != nil {
×
2602
                        return nil, fmt.Errorf("unable to decode color: %w",
×
2603
                                err)
×
2604
                }
×
UNCOV
2605
        }
×
UNCOV
2606

×
UNCOV
2607
        // Fetch the node's features.
×
2608
        node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
×
2609
        if err != nil {
×
2610
                return nil, fmt.Errorf("unable to fetch node(%d) "+
2611
                        "features: %w", dbNode.ID, err)
2612
        }
UNCOV
2613

×
UNCOV
2614
        // Fetch the node's addresses.
×
2615
        _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
×
2616
        if err != nil {
2617
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
2618
                        "addresses: %w", dbNode.ID, err)
×
2619
        }
×
UNCOV
2620

×
2621
        // Fetch the node's extra signed fields.
2622
        extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
×
2623
        if err != nil {
×
2624
                return nil, fmt.Errorf("unable to fetch node(%d) "+
×
2625
                        "extra signed fields: %w", dbNode.ID, err)
×
2626
        }
×
UNCOV
2627

×
2628
        recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
×
2629
        if err != nil {
×
2630
                return nil, fmt.Errorf("unable to serialize extra signed "+
2631
                        "fields: %w", err)
×
2632
        }
×
UNCOV
2633

×
2634
        if len(recs) != 0 {
×
2635
                node.ExtraOpaqueData = recs
×
2636
        }
×
2637

2638
        return node, nil
×
UNCOV
2639
}
×
UNCOV
2640

×
UNCOV
2641
// getNodeFeatures fetches the feature bits and constructs the feature vector
×
UNCOV
2642
// for a node with the given DB ID.
×
UNCOV
2643
func getNodeFeatures(ctx context.Context, db SQLQueries,
×
2644
        nodeID int64) (*lnwire.FeatureVector, error) {
×
2645

×
2646
        rows, err := db.GetNodeFeatures(ctx, nodeID)
2647
        if err != nil {
×
2648
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
2649
                        nodeID, err)
×
2650
        }
×
UNCOV
2651

×
2652
        features := lnwire.EmptyFeatureVector()
×
2653
        for _, feature := range rows {
×
2654
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
2655
        }
×
2656

2657
        return features, nil
×
2658
}
2659

2660
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
2661
// given DB ID.
UNCOV
2662
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
×
2663
        nodeID int64) (map[uint64][]byte, error) {
×
2664

×
2665
        fields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
2666
        if err != nil {
×
2667
                return nil, fmt.Errorf("unable to get node(%d) extra "+
×
2668
                        "signed fields: %w", nodeID, err)
2669
        }
×
UNCOV
2670

×
2671
        extraFields := make(map[uint64][]byte)
×
2672
        for _, field := range fields {
×
2673
                extraFields[uint64(field.Type)] = field.Value
×
2674
        }
×
2675

2676
        return extraFields, nil
×
UNCOV
2677
}
×
UNCOV
2678

×
UNCOV
2679
// upsertNode upserts the node record into the database. If the node already
×
UNCOV
2680
// exists, then the node's information is updated. If the node doesn't exist,
×
UNCOV
2681
// then a new node is created. The node's features, addresses and extra TLV
×
UNCOV
2682
// types are also updated. The node's DB ID is returned.
×
UNCOV
2683
func upsertNode(ctx context.Context, db SQLQueries,
×
2684
        node *models.LightningNode) (int64, error) {
×
2685

×
2686
        params := sqlc.UpsertNodeParams{
×
2687
                Version: int16(ProtocolV1),
×
2688
                PubKey:  node.PubKeyBytes[:],
2689
        }
2690

×
2691
        if node.HaveNodeAnnouncement {
×
2692
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
2693
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
2694
                params.Alias = sqldb.SQLStr(node.Alias)
×
2695
                params.Signature = node.AuthSigBytes
×
2696
        }
×
2697

2698
        nodeID, err := db.UpsertNode(ctx, params)
2699
        if err != nil {
×
2700
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
2701
                        err)
2702
        }
2703

2704
        // We can exit here if we don't have the announcement yet.
2705
        if !node.HaveNodeAnnouncement {
×
2706
                return nodeID, nil
×
2707
        }
×
UNCOV
2708

×
UNCOV
2709
        // Update the node's features.
×
2710
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
2711
        if err != nil {
×
2712
                return 0, fmt.Errorf("inserting node features: %w", err)
×
2713
        }
×
UNCOV
2714

×
UNCOV
2715
        // Update the node's addresses.
×
2716
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
2717
        if err != nil {
×
2718
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
2719
        }
×
2720

UNCOV
2721
        // Convert the flat extra opaque data into a map of TLV types to
×
UNCOV
2722
        // values.
×
2723
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
2724
        if err != nil {
2725
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
2726
                        err)
×
2727
        }
×
UNCOV
2728

×
UNCOV
2729
        // Update the node's extra signed fields.
×
2730
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
2731
        if err != nil {
×
2732
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
2733
        }
×
2734

2735
        return nodeID, nil
×
2736
}
2737

2738
// upsertNodeFeatures updates the node's features node_features table. This
UNCOV
2739
// includes deleting any feature bits no longer present and inserting any new
×
2740
// feature bits. If the feature bit does not yet exist in the features table,
2741
// then an entry is created in that table first.
2742
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
2743
        features *lnwire.FeatureVector) error {
2744

2745
        // Get any existing features for the node.
2746
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
2747
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
2748
                return err
2749
        }
×
UNCOV
2750

×
UNCOV
2751
        // Copy the nodes latest set of feature bits.
×
2752
        newFeatures := make(map[int32]struct{})
×
2753
        if features != nil {
×
2754
                for feature := range features.Features() {
×
2755
                        newFeatures[int32(feature)] = struct{}{}
×
2756
                }
×
UNCOV
2757
        }
×
UNCOV
2758

×
UNCOV
2759
        // For any current feature that already exists in the DB, remove it from
×
UNCOV
2760
        // the in-memory map. For any existing feature that does not exist in
×
2761
        // the in-memory map, delete it from the database.
2762
        for _, feature := range existingFeatures {
2763
                // The feature is still present, so there are no updates to be
×
2764
                // made.
×
2765
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
2766
                        delete(newFeatures, feature.FeatureBit)
×
2767
                        continue
×
UNCOV
2768
                }
×
UNCOV
2769

×
UNCOV
2770
                // The feature is no longer present, so we remove it from the
×
2771
                // database.
2772
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
2773
                        NodeID:     nodeID,
×
2774
                        FeatureBit: feature.FeatureBit,
×
2775
                })
×
2776
                if err != nil {
×
2777
                        return fmt.Errorf("unable to delete node(%d) "+
×
2778
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
2779
                                err)
×
2780
                }
UNCOV
2781
        }
×
UNCOV
2782

×
UNCOV
2783
        // Any remaining entries in newFeatures are new features that need to be
×
UNCOV
2784
        // added to the database for the first time.
×
2785
        for feature := range newFeatures {
×
2786
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
2787
                        NodeID:     nodeID,
×
2788
                        FeatureBit: feature,
×
2789
                })
×
2790
                if err != nil {
×
2791
                        return fmt.Errorf("unable to insert node(%d) "+
×
2792
                                "feature(%v): %w", nodeID, feature, err)
×
2793
                }
×
2794
        }
2795

2796
        return nil
UNCOV
2797
}
×
UNCOV
2798

×
UNCOV
2799
// fetchNodeFeatures fetches the features for a node with the given public key.
×
UNCOV
2800
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
×
2801
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
2802

×
2803
        rows, err := queries.GetNodeFeaturesByPubKey(
×
2804
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
2805
                        PubKey:  nodePub[:],
2806
                        Version: int16(ProtocolV1),
×
2807
                },
×
2808
        )
×
2809
        if err != nil {
2810
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
2811
                        nodePub, err)
×
2812
        }
2813

2814
        features := lnwire.EmptyFeatureVector()
2815
        for _, bit := range rows {
2816
                features.Set(lnwire.FeatureBit(bit))
2817
        }
UNCOV
2818

×
2819
        return features, nil
×
UNCOV
2820
}
×
UNCOV
2821

×
UNCOV
2822
// dbAddressType enumerates the kinds of address that are stored in the
// node_addresses table. The type recorded alongside an address determines how
// that address is serialised and deserialised.
type dbAddressType uint8

// The numeric values below are written to the database as-is, so they are
// spelled out explicitly rather than derived.
const (
	// addressTypeIPv4 marks a TCP address with an IPv4 IP.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address with an IPv6 IP.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeOpaque marks an opaque address. It takes the largest
	// int8 value.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
×
UNCOV
2834

×
UNCOV
2835
// upsertNodeAddresses updates the node's addresses in the database. This
×
UNCOV
2836
// includes deleting any existing addresses and inserting the new set of
×
UNCOV
2837
// addresses. The deletion is necessary since the ordering of the addresses may
×
UNCOV
2838
// change, and we need to ensure that the database reflects the latest set of
×
UNCOV
2839
// addresses so that at the time of reconstructing the node announcement, the
×
UNCOV
2840
// order is preserved and the signature over the message remains valid.
×
UNCOV
2841
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
×
2842
        addresses []net.Addr) error {
2843

×
2844
        // Delete any existing addresses for the node. This is required since
×
2845
        // even if the new set of addresses is the same, the ordering may have
×
2846
        // changed for a given address type.
×
2847
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
2848
        if err != nil {
×
2849
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
2850
                        nodeID, err)
×
2851
        }
×
2852

UNCOV
2853
        // Copy the nodes latest set of addresses.
×
2854
        newAddresses := map[dbAddressType][]string{
×
2855
                addressTypeIPv4:   {},
×
2856
                addressTypeIPv6:   {},
×
2857
                addressTypeTorV2:  {},
×
2858
                addressTypeTorV3:  {},
×
2859
                addressTypeOpaque: {},
×
2860
        }
×
2861
        addAddr := func(t dbAddressType, addr net.Addr) {
2862
                newAddresses[t] = append(newAddresses[t], addr.String())
×
2863
        }
×
UNCOV
2864

×
2865
        for _, address := range addresses {
×
2866
                switch addr := address.(type) {
×
2867
                case *net.TCPAddr:
×
2868
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
2869
                                addAddr(addressTypeIPv4, addr)
×
2870
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
2871
                                addAddr(addressTypeIPv6, addr)
×
2872
                        } else {
×
2873
                                return fmt.Errorf("unhandled IP address: %v",
×
2874
                                        addr)
×
2875
                        }
×
UNCOV
2876

×
2877
                case *tor.OnionAddr:
×
2878
                        switch len(addr.OnionService) {
×
2879
                        case tor.V2Len:
×
2880
                                addAddr(addressTypeTorV2, addr)
×
2881
                        case tor.V3Len:
×
2882
                                addAddr(addressTypeTorV3, addr)
×
2883
                        default:
×
2884
                                return fmt.Errorf("invalid length for a tor " +
×
2885
                                        "address")
×
UNCOV
2886
                        }
×
2887

2888
                case *lnwire.OpaqueAddrs:
2889
                        addAddr(addressTypeOpaque, addr)
UNCOV
2890

×
2891
                default:
×
2892
                        return fmt.Errorf("unhandled address type: %T", addr)
×
UNCOV
2893
                }
×
UNCOV
2894
        }
×
2895

2896
        // Any remaining entries in newAddresses are new addresses that need to
UNCOV
2897
        // be added to the database for the first time.
×
2898
        for addrType, addrList := range newAddresses {
×
2899
                for position, addr := range addrList {
×
2900
                        err := db.InsertNodeAddress(
×
2901
                                ctx, sqlc.InsertNodeAddressParams{
×
2902
                                        NodeID:   nodeID,
2903
                                        Type:     int16(addrType),
×
2904
                                        Address:  addr,
2905
                                        Position: int32(position),
2906
                                },
2907
                        )
2908
                        if err != nil {
×
2909
                                return fmt.Errorf("unable to insert "+
×
2910
                                        "node(%d) address(%v): %w", nodeID,
×
2911
                                        addr, err)
×
2912
                        }
×
UNCOV
2913
                }
×
UNCOV
2914
        }
×
UNCOV
2915

×
2916
        return nil
×
UNCOV
2917
}
×
UNCOV
2918

×
UNCOV
2919
// getNodeAddresses fetches the addresses for a node with the given public key.
×
UNCOV
2920
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
×
2921
        []net.Addr, error) {
2922

×
2923
        // GetNodeAddressesByPubKey ensures that the addresses for a given type
×
2924
        // are returned in the same order as they were inserted.
×
2925
        rows, err := db.GetNodeAddressesByPubKey(
×
2926
                ctx, sqlc.GetNodeAddressesByPubKeyParams{
2927
                        Version: int16(ProtocolV1),
×
2928
                        PubKey:  nodePub,
2929
                },
2930
        )
2931
        if err != nil {
2932
                return false, nil, err
2933
        }
UNCOV
2934

×
UNCOV
2935
        // GetNodeAddressesByPubKey uses a left join so there should always be
×
UNCOV
2936
        // at least one row returned if the node exists even if it has no
×
UNCOV
2937
        // addresses.
×
2938
        if len(rows) == 0 {
×
2939
                return false, nil, nil
×
2940
        }
×
UNCOV
2941

×
2942
        addresses := make([]net.Addr, 0, len(rows))
×
2943
        for _, addr := range rows {
2944
                if !(addr.Type.Valid && addr.Address.Valid) {
2945
                        continue
2946
                }
2947

2948
                address := addr.Address.String
×
2949

×
2950
                switch dbAddressType(addr.Type.Int16) {
×
2951
                case addressTypeIPv4:
×
2952
                        tcp, err := net.ResolveTCPAddr("tcp4", address)
×
2953
                        if err != nil {
×
2954
                                return false, nil, nil
2955
                        }
×
2956
                        tcp.IP = tcp.IP.To4()
×
2957

×
2958
                        addresses = append(addresses, tcp)
×
UNCOV
2959

×
2960
                case addressTypeIPv6:
×
2961
                        tcp, err := net.ResolveTCPAddr("tcp6", address)
×
2962
                        if err != nil {
×
2963
                                return false, nil, nil
×
2964
                        }
×
2965
                        addresses = append(addresses, tcp)
×
UNCOV
2966

×
2967
                case addressTypeTorV3, addressTypeTorV2:
2968
                        service, portStr, err := net.SplitHostPort(address)
×
2969
                        if err != nil {
×
2970
                                return false, nil, fmt.Errorf("unable to "+
×
2971
                                        "split tor v3 address: %v",
×
2972
                                        addr.Address)
×
2973
                        }
×
UNCOV
2974

×
2975
                        port, err := strconv.Atoi(portStr)
×
2976
                        if err != nil {
×
2977
                                return false, nil, err
×
2978
                        }
×
UNCOV
2979

×
2980
                        addresses = append(addresses, &tor.OnionAddr{
2981
                                OnionService: service,
2982
                                Port:         port,
2983
                        })
×
UNCOV
2984

×
2985
                case addressTypeOpaque:
×
2986
                        opaque, err := hex.DecodeString(address)
×
2987
                        if err != nil {
×
2988
                                return false, nil, fmt.Errorf("unable to "+
2989
                                        "decode opaque address: %v", addr)
2990
                        }
×
UNCOV
2991

×
2992
                        addresses = append(addresses, &lnwire.OpaqueAddrs{
×
2993
                                Payload: opaque,
×
2994
                        })
×
2995

2996
                default:
2997
                        return false, nil, fmt.Errorf("unknown address "+
×
2998
                                "type: %v", addr.Type)
×
UNCOV
2999
                }
×
UNCOV
3000
        }
×
UNCOV
3001

×
3002
        return true, addresses, nil
UNCOV
3003
}
×
UNCOV
3004

×
UNCOV
3005
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
×
UNCOV
3006
// database. This includes updating any existing types, inserting any new types,
×
UNCOV
3007
// and deleting any types that are no longer present.
×
3008
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3009
        nodeID int64, extraFields map[uint64][]byte) error {
×
3010

×
3011
        // Get any existing extra signed fields for the node.
×
3012
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
3013
        if err != nil {
×
3014
                return err
3015
        }
3016

3017
        // Make a lookup map of the existing field types so that we can use it
3018
        // to keep track of any fields we should delete.
3019
        m := make(map[uint64]bool)
×
3020
        for _, field := range existingFields {
×
3021
                m[uint64(field.Type)] = true
×
3022
        }
×
UNCOV
3023

×
UNCOV
3024
        // For all the new fields, we'll upsert them and remove them from the
×
UNCOV
3025
        // map of existing fields.
×
3026
        for tlvType, value := range extraFields {
3027
                err = db.UpsertNodeExtraType(
×
3028
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3029
                                NodeID: nodeID,
×
3030
                                Type:   int64(tlvType),
×
3031
                                Value:  value,
3032
                        },
×
3033
                )
3034
                if err != nil {
3035
                        return fmt.Errorf("unable to upsert node(%d) extra "+
3036
                                "signed field(%v): %w", nodeID, tlvType, err)
3037
                }
UNCOV
3038

×
UNCOV
3039
                // Remove the field from the map of existing fields if it was
×
UNCOV
3040
                // present.
×
3041
                delete(m, tlvType)
×
UNCOV
3042
        }
×
UNCOV
3043

×
UNCOV
3044
        // For all the fields that are left in the map of existing fields, we'll
×
3045
        // delete them as they are no longer present in the new set of fields.
3046
        for tlvType := range m {
×
3047
                err = db.DeleteExtraNodeType(
×
3048
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3049
                                NodeID: nodeID,
×
3050
                                Type:   int64(tlvType),
3051
                        },
×
3052
                )
3053
                if err != nil {
3054
                        return fmt.Errorf("unable to delete node(%d) extra "+
3055
                                "signed field(%v): %w", nodeID, tlvType, err)
3056
                }
3057
        }
3058

3059
        return nil
×
UNCOV
3060
}
×
UNCOV
3061

×
UNCOV
3062
// srcNodeInfo holds the information about the source node of the graph.
×
UNCOV
3063
type srcNodeInfo struct {
×
UNCOV
3064
        // id is the DB level ID of the source node entry in the "nodes" table.
×
UNCOV
3065
        id int64
×
UNCOV
3066

×
UNCOV
3067
        // pub is the public key of the source node.
×
UNCOV
3068
        pub route.Vertex
×
UNCOV
3069
}
×
UNCOV
3070

×
UNCOV
3071
// getSourceNode returns the DB node ID and pub key of the source node for the
×
3072
// specified protocol version.
UNCOV
3073
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
×
3074
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3075

×
3076
        s.srcNodeMu.Lock()
×
3077
        defer s.srcNodeMu.Unlock()
×
3078

3079
        // If we already have the source node ID and pub key cached, then
3080
        // return them.
×
3081
        if info, ok := s.srcNodes[version]; ok {
×
3082
                return info.id, info.pub, nil
×
3083
        }
3084

3085
        var pubKey route.Vertex
×
3086

×
3087
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3088
        if err != nil {
×
3089
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
3090
                        err)
3091
        }
×
UNCOV
3092

×
3093
        if len(nodes) == 0 {
×
3094
                return 0, pubKey, ErrSourceNodeNotSet
×
3095
        } else if len(nodes) > 1 {
3096
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
3097
                        "protocol %s found", version)
3098
        }
×
UNCOV
3099

×
3100
        copy(pubKey[:], nodes[0].PubKey)
×
3101

×
3102
        s.srcNodes[version] = &srcNodeInfo{
×
3103
                id:  nodes[0].NodeID,
3104
                pub: pubKey,
3105
        }
×
3106

×
3107
        return nodes[0].NodeID, pubKey, nil
×
UNCOV
3108
}
×
3109

UNCOV
3110
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
×
3111
// This then produces a map from TLV type to value. If the input is not a
3112
// valid TLV stream, then an error is returned.
3113
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
3114
        r := bytes.NewReader(data)
3115

3116
        tlvStream, err := tlv.NewStream()
3117
        if err != nil {
3118
                return nil, err
×
3119
        }
×
UNCOV
3120

×
UNCOV
3121
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
×
UNCOV
3122
        // pass it into the P2P decoding variant.
×
3123
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3124
        if err != nil {
×
3125
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
3126
        }
3127
        if len(parsedTypes) == 0 {
×
3128
                return nil, nil
×
3129
        }
×
UNCOV
3130

×
3131
        records := make(map[uint64][]byte)
×
3132
        for k, v := range parsedTypes {
3133
                records[uint64(k)] = v
3134
        }
3135

3136
        return records, nil
UNCOV
3137
}
×
UNCOV
3138

×
UNCOV
3139
// insertChannel inserts a new channel record into the database.
×
UNCOV
3140
func insertChannel(ctx context.Context, db SQLQueries,
×
3141
        edge *models.ChannelEdgeInfo) error {
×
3142

×
3143
        chanIDB := channelIDToBytes(edge.ChannelID)
3144

3145
        // Make sure that the channel doesn't already exist. We do this
3146
        // explicitly instead of relying on catching a unique constraint error
3147
        // because relying on SQL to throw that error would abort the entire
×
3148
        // batch of transactions.
×
3149
        _, err := db.GetChannelBySCID(
×
3150
                ctx, sqlc.GetChannelBySCIDParams{
×
3151
                        Scid:    chanIDB[:],
×
3152
                        Version: int16(ProtocolV1),
×
3153
                },
×
3154
        )
×
3155
        if err == nil {
×
3156
                return ErrEdgeAlreadyExist
3157
        } else if !errors.Is(err, sql.ErrNoRows) {
3158
                return fmt.Errorf("unable to fetch channel: %w", err)
3159
        }
UNCOV
3160

×
UNCOV
3161
        // Make sure that at least a "shell" entry for each node is present in
×
UNCOV
3162
        // the nodes table.
×
3163
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3164
        if err != nil {
×
3165
                return fmt.Errorf("unable to create shell node: %w", err)
×
3166
        }
×
UNCOV
3167

×
3168
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3169
        if err != nil {
3170
                return fmt.Errorf("unable to create shell node: %w", err)
3171
        }
×
3172

3173
        var capacity sql.NullInt64
3174
        if edge.Capacity != 0 {
3175
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
3176
        }
×
UNCOV
3177

×
3178
        createParams := sqlc.CreateChannelParams{
×
3179
                Version:     int16(ProtocolV1),
×
3180
                Scid:        chanIDB[:],
×
3181
                NodeID1:     node1DBID,
×
3182
                NodeID2:     node2DBID,
×
3183
                Outpoint:    edge.ChannelPoint.String(),
×
3184
                Capacity:    capacity,
×
3185
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3186
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3187
        }
×
3188

3189
        if edge.AuthProof != nil {
×
3190
                proof := edge.AuthProof
×
3191

×
3192
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3193
                createParams.Node2Signature = proof.NodeSig2Bytes
3194
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3195
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
3196
        }
3197

3198
        // Insert the new channel record.
3199
        dbChanID, err := db.CreateChannel(ctx, createParams)
3200
        if err != nil {
3201
                return err
3202
        }
3203

3204
        // Insert any channel features.
3205
        if len(edge.Features) != 0 {
3206
                chanFeatures := lnwire.NewRawFeatureVector()
3207
                err := chanFeatures.Decode(bytes.NewReader(edge.Features))
3208
                if err != nil {
3209
                        return err
3210
                }
3211

3212
                fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
3213
                for feature := range fv.Features() {
3214
                        err = db.InsertChannelFeature(
3215
                                ctx, sqlc.InsertChannelFeatureParams{
3216
                                        ChannelID:  dbChanID,
3217
                                        FeatureBit: int32(feature),
×
3218
                                },
×
3219
                        )
×
3220
                        if err != nil {
×
3221
                                return fmt.Errorf("unable to insert "+
×
3222
                                        "channel(%d) feature(%v): %w", dbChanID,
×
3223
                                        feature, err)
×
3224
                        }
×
UNCOV
3225
                }
×
UNCOV
3226
        }
×
3227

3228
        // Finally, insert any extra TLV fields in the channel announcement.
3229
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3230
        if err != nil {
×
3231
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3232
                        err)
×
3233
        }
×
UNCOV
3234

×
3235
        for tlvType, value := range extra {
×
3236
                err := db.CreateChannelExtraType(
×
3237
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3238
                                ChannelID: dbChanID,
×
3239
                                Type:      int64(tlvType),
3240
                                Value:     value,
×
3241
                        },
×
3242
                )
×
3243
                if err != nil {
×
3244
                        return fmt.Errorf("unable to upsert channel(%d) extra "+
×
3245
                                "signed field(%v): %w", edge.ChannelID,
×
3246
                                tlvType, err)
×
3247
                }
×
UNCOV
3248
        }
×
UNCOV
3249

×
3250
        return nil
×
3251
}
UNCOV
3252

×
UNCOV
3253
// maybeCreateShellNode checks if a shell node entry exists for the
×
UNCOV
3254
// given public key. If it does not exist, then a new shell node entry is
×
UNCOV
3255
// created. The ID of the node is returned. A shell node only has a protocol
×
UNCOV
3256
// version and public key persisted.
×
UNCOV
3257
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
×
3258
        pubKey route.Vertex) (int64, error) {
×
3259

×
3260
        dbNode, err := db.GetNodeByPubKey(
×
3261
                ctx, sqlc.GetNodeByPubKeyParams{
3262
                        PubKey:  pubKey[:],
3263
                        Version: int16(ProtocolV1),
×
3264
                },
×
3265
        )
3266
        // The node exists. Return the ID.
×
3267
        if err == nil {
×
3268
                return dbNode.ID, nil
3269
        } else if !errors.Is(err, sql.ErrNoRows) {
3270
                return 0, err
3271
        }
3272

UNCOV
3273
        // Otherwise, the node does not exist, so we create a shell entry for
×
UNCOV
3274
        // it.
×
3275
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
3276
                Version: int16(ProtocolV1),
×
3277
                PubKey:  pubKey[:],
×
3278
        })
×
3279
        if err != nil {
×
3280
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
3281
        }
×
UNCOV
3282

×
3283
        return id, nil
×
UNCOV
3284
}
×
UNCOV
3285

×
UNCOV
3286
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
×
UNCOV
3287
// the database. This includes deleting any existing types and then inserting
×
3288
// the new types.
3289
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
3290
        chanPolicyID int64, extraFields map[uint64][]byte) error {
3291

×
3292
        // Delete all existing extra signed fields for the channel policy.
3293
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
3294
        if err != nil {
3295
                return fmt.Errorf("unable to delete "+
3296
                        "existing policy extra signed fields for policy %d: %w",
×
3297
                        chanPolicyID, err)
×
3298
        }
×
UNCOV
3299

×
UNCOV
3300
        // Insert all new extra signed fields for the channel policy.
×
3301
        for tlvType, value := range extraFields {
×
3302
                err = db.InsertChanPolicyExtraType(
×
3303
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
3304
                                ChannelPolicyID: chanPolicyID,
×
3305
                                Type:            int64(tlvType),
×
3306
                                Value:           value,
×
3307
                        },
×
3308
                )
×
3309
                if err != nil {
3310
                        return fmt.Errorf("unable to insert "+
3311
                                "channel_policy(%d) extra signed field(%v): %w",
3312
                                chanPolicyID, tlvType, err)
3313
                }
×
UNCOV
3314
        }
×
UNCOV
3315

×
3316
        return nil
UNCOV
3317
}
×
UNCOV
3318

×
UNCOV
3319
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
×
UNCOV
3320
// provided dbChanRow and also fetches any other required information
×
3321
// to construct the edge info.
3322
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
UNCOV
3323
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.Channel, node1,
×
3324
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
3325

×
3326
        fv, extras, err := getChanFeaturesAndExtras(
×
3327
                ctx, db, dbChanID,
×
3328
        )
×
3329
        if err != nil {
×
3330
                return nil, err
×
3331
        }
×
UNCOV
3332

×
3333
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
3334
        if err != nil {
3335
                return nil, err
×
3336
        }
×
UNCOV
3337

×
3338
        var featureBuf bytes.Buffer
×
3339
        if err := fv.Encode(&featureBuf); err != nil {
×
3340
                return nil, fmt.Errorf("unable to encode features: %w", err)
×
3341
        }
UNCOV
3342

×
3343
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3344
        if err != nil {
×
3345
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3346
                        "fields: %w", err)
×
3347
        }
×
3348
        if recs == nil {
×
3349
                recs = make([]byte, 0)
3350
        }
×
UNCOV
3351

×
3352
        var btcKey1, btcKey2 route.Vertex
×
3353
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
3354
        copy(btcKey2[:], dbChan.BitcoinKey2)
3355

×
3356
        channel := &models.ChannelEdgeInfo{
×
3357
                ChainHash:        chain,
×
3358
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
3359
                NodeKey1Bytes:    node1,
3360
                NodeKey2Bytes:    node2,
×
3361
                BitcoinKey1Bytes: btcKey1,
×
3362
                BitcoinKey2Bytes: btcKey2,
×
3363
                ChannelPoint:     *op,
×
3364
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
3365
                Features:         featureBuf.Bytes(),
×
3366
                ExtraOpaqueData:  recs,
3367
        }
×
3368

×
3369
        // We always set all the signatures at the same time, so we can
×
3370
        // safely check if one signature is present to determine if we have the
3371
        // rest of the signatures for the auth proof.
×
3372
        if len(dbChan.Bitcoin1Signature) > 0 {
×
3373
                channel.AuthProof = &models.ChannelAuthProof{
×
3374
                        NodeSig1Bytes:    dbChan.Node1Signature,
3375
                        NodeSig2Bytes:    dbChan.Node2Signature,
3376
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
3377
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
3378
                }
3379
        }
3380

3381
        return channel, nil
3382
}
3383

UNCOV
3384
// buildNodeVertices is a helper that converts raw node public keys
×
UNCOV
3385
// into route.Vertex instances.
×
UNCOV
3386
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
×
3387
        route.Vertex, error) {
×
3388

×
3389
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
3390
        if err != nil {
×
3391
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
3392
                        "create vertex from node1 pubkey: %w", err)
3393
        }
UNCOV
3394

×
3395
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
3396
        if err != nil {
×
3397
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
3398
                        "create vertex from node2 pubkey: %w", err)
3399
        }
3400

3401
        return node1Vertex, node2Vertex, nil
×
UNCOV
3402
}
×
UNCOV
3403

×
UNCOV
3404
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
×
UNCOV
3405
// for a channel with the given ID.
×
UNCOV
3406
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
×
3407
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
3408

×
3409
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
3410
        if err != nil {
×
3411
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
3412
                        "features and extras: %w", err)
×
3413
        }
3414

3415
        var (
3416
                fv     = lnwire.EmptyFeatureVector()
×
3417
                extras = make(map[uint64][]byte)
3418
        )
3419
        for _, row := range rows {
3420
                if row.IsFeature {
3421
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
3422

×
3423
                        continue
×
UNCOV
3424
                }
×
UNCOV
3425

×
3426
                tlvType, ok := row.ExtraKey.(int64)
×
3427
                if !ok {
×
3428
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
3429
                                "TLV type: %T", row.ExtraKey)
×
3430
                }
×
UNCOV
3431

×
3432
                valueBytes, ok := row.Value.([]byte)
3433
                if !ok {
3434
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
3435
                                "Value: %T", row.Value)
3436
                }
3437

3438
                extras[uint64(tlvType)] = valueBytes
3439
        }
3440

3441
        return fv, extras, nil
3442
}
3443

3444
// getAndBuildChanPolicies uses the given sqlc.ChannelPolicy and also retrieves
3445
// all the extra info required to build the complete models.ChannelEdgePolicy
3446
// types. It returns two policies, which may be nil if the provided
3447
// sqlc.ChannelPolicy records are nil.
3448
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
UNCOV
3449
        dbPol1, dbPol2 *sqlc.ChannelPolicy, channelID uint64, node1,
×
UNCOV
3450
        node2 route.Vertex) (*models.ChannelEdgePolicy,
×
3451
        *models.ChannelEdgePolicy, error) {
×
3452

×
3453
        if dbPol1 == nil && dbPol2 == nil {
×
3454
                return nil, nil, nil
×
3455
        }
×
UNCOV
3456

×
3457
        var (
×
3458
                policy1ID int64
×
3459
                policy2ID int64
3460
        )
×
3461
        if dbPol1 != nil {
×
3462
                policy1ID = dbPol1.ID
×
3463
        }
×
3464
        if dbPol2 != nil {
×
3465
                policy2ID = dbPol2.ID
×
3466
        }
×
3467
        rows, err := db.GetChannelPolicyExtraTypes(
3468
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
3469
                        ID:   policy1ID,
×
3470
                        ID_2: policy2ID,
×
3471
                },
×
3472
        )
×
3473
        if err != nil {
×
3474
                return nil, nil, err
3475
        }
×
UNCOV
3476

×
3477
        var (
×
3478
                dbPol1Extras = make(map[uint64][]byte)
×
3479
                dbPol2Extras = make(map[uint64][]byte)
×
3480
        )
×
3481
        for _, row := range rows {
×
3482
                switch row.PolicyID {
×
3483
                case policy1ID:
3484
                        dbPol1Extras[uint64(row.Type)] = row.Value
3485
                case policy2ID:
3486
                        dbPol2Extras[uint64(row.Type)] = row.Value
3487
                default:
3488
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
3489
                                "in row: %v", row.PolicyID, row)
×
UNCOV
3490
                }
×
UNCOV
3491
        }
×
UNCOV
3492

×
3493
        var pol1, pol2 *models.ChannelEdgePolicy
×
3494
        if dbPol1 != nil {
×
3495
                pol1, err = buildChanPolicy(
3496
                        *dbPol1, channelID, dbPol1Extras, node2, true,
3497
                )
3498
                if err != nil {
×
3499
                        return nil, nil, err
×
3500
                }
×
UNCOV
3501
        }
×
3502
        if dbPol2 != nil {
×
3503
                pol2, err = buildChanPolicy(
×
3504
                        *dbPol2, channelID, dbPol2Extras, node1, false,
×
3505
                )
3506
                if err != nil {
×
3507
                        return nil, nil, err
×
3508
                }
×
UNCOV
3509
        }
×
3510

3511
        return pol1, pol2, nil
×
3512
}
3513

3514
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
3515
// provided sqlc.ChannelPolicy and other required information.
UNCOV
3516
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
×
UNCOV
3517
        extras map[uint64][]byte, toNode route.Vertex,
×
3518
        isNode1 bool) (*models.ChannelEdgePolicy, error) {
×
3519

×
3520
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
3521
        if err != nil {
×
3522
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
3523
                        "fields: %w", err)
×
3524
        }
×
UNCOV
3525

×
3526
        var msgFlags lnwire.ChanUpdateMsgFlags
×
3527
        if dbPolicy.MaxHtlcMsat.Valid {
×
3528
                msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
×
3529
        }
×
UNCOV
3530

×
3531
        var chanFlags lnwire.ChanUpdateChanFlags
×
3532
        if !isNode1 {
×
3533
                chanFlags |= lnwire.ChanUpdateDirection
×
3534
        }
×
3535
        if dbPolicy.Disabled.Bool {
3536
                chanFlags |= lnwire.ChanUpdateDisabled
3537
        }
UNCOV
3538

×
3539
        var inboundFee fn.Option[lnwire.Fee]
×
3540
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
3541
                dbPolicy.InboundBaseFeeMsat.Valid {
×
3542

3543
                inboundFee = fn.Some(lnwire.Fee{
×
3544
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
3545
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
3546
                })
×
3547
        }
UNCOV
3548

×
3549
        return &models.ChannelEdgePolicy{
×
3550
                SigBytes:  dbPolicy.Signature,
×
3551
                ChannelID: channelID,
×
3552
                LastUpdate: time.Unix(
3553
                        dbPolicy.LastUpdate.Int64, 0,
×
3554
                ),
×
3555
                MessageFlags:  msgFlags,
×
3556
                ChannelFlags:  chanFlags,
×
3557
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
3558
                MinHTLC: lnwire.MilliSatoshi(
×
3559
                        dbPolicy.MinHtlcMsat,
×
3560
                ),
×
3561
                MaxHTLC: lnwire.MilliSatoshi(
×
3562
                        dbPolicy.MaxHtlcMsat.Int64,
×
3563
                ),
×
3564
                FeeBaseMSat: lnwire.MilliSatoshi(
×
3565
                        dbPolicy.BaseFeeMsat,
×
3566
                ),
×
3567
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
3568
                ToNode:                    toNode,
×
3569
                InboundFee:                inboundFee,
×
3570
                ExtraOpaqueData:           recs,
×
3571
        }, nil
×
3572
}
3573

UNCOV
3574
// buildNodes builds the models.LightningNode instances for the
×
UNCOV
3575
// given row which is expected to be a sqlc type that contains node information.
×
UNCOV
3576
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
×
UNCOV
3577
        dbNode2 sqlc.Node) (*models.LightningNode, *models.LightningNode,
×
3578
        error) {
3579

3580
        node1, err := buildNode(ctx, db, &dbNode1)
×
3581
        if err != nil {
×
3582
                return nil, nil, err
×
3583
        }
×
UNCOV
3584

×
3585
        node2, err := buildNode(ctx, db, &dbNode2)
×
3586
        if err != nil {
3587
                return nil, nil, err
×
3588
        }
×
UNCOV
3589

×
3590
        return node1, node2, nil
×
UNCOV
3591
}
×
UNCOV
3592

×
UNCOV
3593
// extractChannelPolicies extracts the sqlc.ChannelPolicy records from the give
×
UNCOV
3594
// row which is expected to be a sqlc type that contains channel policy
×
UNCOV
3595
// information. It returns two policies, which may be nil if the policy
×
UNCOV
3596
// information is not present in the row.
×
UNCOV
3597
//
×
UNCOV
3598
//nolint:ll,dupl,funlen
×
UNCOV
3599
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
×
3600
        error) {
3601

3602
        var policy1, policy2 *sqlc.ChannelPolicy
3603
        switch r := row.(type) {
3604
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
3605
                if r.Policy1ID.Valid {
×
3606
                        policy1 = &sqlc.ChannelPolicy{
×
3607
                                ID:                      r.Policy1ID.Int64,
×
3608
                                Version:                 r.Policy1Version.Int16,
×
3609
                                ChannelID:               r.Channel.ID,
3610
                                NodeID:                  r.Policy1NodeID.Int64,
×
3611
                                Timelock:                r.Policy1Timelock.Int32,
×
3612
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
3613
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
3614
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
3615
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
3616
                                LastUpdate:              r.Policy1LastUpdate,
×
3617
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
3618
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
3619
                                Disabled:                r.Policy1Disabled,
×
3620
                                Signature:               r.Policy1Signature,
×
3621
                        }
×
3622
                }
×
3623
                if r.Policy2ID.Valid {
3624
                        policy2 = &sqlc.ChannelPolicy{
3625
                                ID:                      r.Policy2ID.Int64,
×
3626
                                Version:                 r.Policy2Version.Int16,
3627
                                ChannelID:               r.Channel.ID,
3628
                                NodeID:                  r.Policy2NodeID.Int64,
3629
                                Timelock:                r.Policy2Timelock.Int32,
3630
                                FeePpm:                  r.Policy2FeePpm.Int64,
3631
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
3632
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
3633
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
3634
                                LastUpdate:              r.Policy2LastUpdate,
×
3635
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
3636
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
3637
                                Disabled:                r.Policy2Disabled,
×
3638
                                Signature:               r.Policy2Signature,
×
3639
                        }
×
3640
                }
×
UNCOV
3641

×
3642
                return policy1, policy2, nil
×
UNCOV
3643

×
3644
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
3645
                if r.Policy1ID.Valid {
×
3646
                        policy1 = &sqlc.ChannelPolicy{
×
3647
                                ID:                      r.Policy1ID.Int64,
3648
                                Version:                 r.Policy1Version.Int16,
3649
                                ChannelID:               r.Channel.ID,
3650
                                NodeID:                  r.Policy1NodeID.Int64,
×
3651
                                Timelock:                r.Policy1Timelock.Int32,
×
3652
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
3653
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
3654
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
3655
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
3656
                                LastUpdate:              r.Policy1LastUpdate,
×
3657
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
3658
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
3659
                                Disabled:                r.Policy1Disabled,
3660
                                Signature:               r.Policy1Signature,
3661
                        }
3662
                }
3663
                if r.Policy2ID.Valid {
3664
                        policy2 = &sqlc.ChannelPolicy{
3665
                                ID:                      r.Policy2ID.Int64,
×
3666
                                Version:                 r.Policy2Version.Int16,
×
3667
                                ChannelID:               r.Channel.ID,
×
3668
                                NodeID:                  r.Policy2NodeID.Int64,
×
3669
                                Timelock:                r.Policy2Timelock.Int32,
×
3670
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
3671
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
3672
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
3673
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
3674
                                LastUpdate:              r.Policy2LastUpdate,
3675
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
3676
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
3677
                                Disabled:                r.Policy2Disabled,
×
3678
                                Signature:               r.Policy2Signature,
×
3679
                        }
×
3680
                }
×
UNCOV
3681

×
3682
                return policy1, policy2, nil
×
UNCOV
3683

×
3684
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
3685
                if r.Policy1ID.Valid {
×
3686
                        policy1 = &sqlc.ChannelPolicy{
×
3687
                                ID:                      r.Policy1ID.Int64,
×
3688
                                Version:                 r.Policy1Version.Int16,
×
3689
                                ChannelID:               r.Channel.ID,
3690
                                NodeID:                  r.Policy1NodeID.Int64,
3691
                                Timelock:                r.Policy1Timelock.Int32,
×
3692
                                FeePpm:                  r.Policy1FeePpm.Int64,
3693
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
3694
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
3695
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
3696
                                LastUpdate:              r.Policy1LastUpdate,
3697
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
3698
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
3699
                                Disabled:                r.Policy1Disabled,
×
3700
                                Signature:               r.Policy1Signature,
×
3701
                        }
×
3702
                }
×
3703
                if r.Policy2ID.Valid {
×
3704
                        policy2 = &sqlc.ChannelPolicy{
×
3705
                                ID:                      r.Policy2ID.Int64,
3706
                                Version:                 r.Policy2Version.Int16,
×
3707
                                ChannelID:               r.Channel.ID,
×
3708
                                NodeID:                  r.Policy2NodeID.Int64,
×
3709
                                Timelock:                r.Policy2Timelock.Int32,
×
3710
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
3711
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
3712
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
3713
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
3714
                                LastUpdate:              r.Policy2LastUpdate,
×
3715
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
3716
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
3717
                                Disabled:                r.Policy2Disabled,
3718
                                Signature:               r.Policy2Signature,
×
3719
                        }
×
3720
                }
×
UNCOV
3721

×
3722
                return policy1, policy2, nil
UNCOV
3723

×
3724
        case sqlc.ListChannelsByNodeIDRow:
×
3725
                if r.Policy1ID.Valid {
×
3726
                        policy1 = &sqlc.ChannelPolicy{
×
3727
                                ID:                      r.Policy1ID.Int64,
×
3728
                                Version:                 r.Policy1Version.Int16,
×
3729
                                ChannelID:               r.Channel.ID,
×
3730
                                NodeID:                  r.Policy1NodeID.Int64,
×
3731
                                Timelock:                r.Policy1Timelock.Int32,
3732
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
3733
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
3734
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
3735
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
3736
                                LastUpdate:              r.Policy1LastUpdate,
×
3737
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
3738
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
3739
                                Disabled:                r.Policy1Disabled,
×
3740
                                Signature:               r.Policy1Signature,
×
3741
                        }
×
3742
                }
×
3743
                if r.Policy2ID.Valid {
×
3744
                        policy2 = &sqlc.ChannelPolicy{
×
3745
                                ID:                      r.Policy2ID.Int64,
×
3746
                                Version:                 r.Policy2Version.Int16,
×
3747
                                ChannelID:               r.Channel.ID,
×
3748
                                NodeID:                  r.Policy2NodeID.Int64,
×
3749
                                Timelock:                r.Policy2Timelock.Int32,
×
3750
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
3751
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
3752
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
3753
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
3754
                                LastUpdate:              r.Policy2LastUpdate,
×
3755
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
3756
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
3757
                                Disabled:                r.Policy2Disabled,
×
3758
                                Signature:               r.Policy2Signature,
×
3759
                        }
×
3760
                }
UNCOV
3761

×
3762
                return policy1, policy2, nil
3763

3764
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
3765
                if r.Policy1ID.Valid {
3766
                        policy1 = &sqlc.ChannelPolicy{
3767
                                ID:                      r.Policy1ID.Int64,
×
3768
                                Version:                 r.Policy1Version.Int16,
×
3769
                                ChannelID:               r.Channel.ID,
×
3770
                                NodeID:                  r.Policy1NodeID.Int64,
×
3771
                                Timelock:                r.Policy1Timelock.Int32,
×
3772
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
3773
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
3774
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
3775
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
3776
                                LastUpdate:              r.Policy1LastUpdate,
×
3777
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
3778
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
3779
                                Disabled:                r.Policy1Disabled,
×
3780
                                Signature:               r.Policy1Signature,
3781
                        }
×
3782
                }
3783
                if r.Policy2ID.Valid {
3784
                        policy2 = &sqlc.ChannelPolicy{
3785
                                ID:                      r.Policy2ID.Int64,
3786
                                Version:                 r.Policy2Version.Int16,
3787
                                ChannelID:               r.Channel.ID,
×
3788
                                NodeID:                  r.Policy2NodeID.Int64,
×
3789
                                Timelock:                r.Policy2Timelock.Int32,
×
3790
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
3791
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
3792
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
3793
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
3794
                                LastUpdate:              r.Policy2LastUpdate,
3795
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
3796
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
3797
                                Disabled:                r.Policy2Disabled,
×
3798
                                Signature:               r.Policy2Signature,
×
3799
                        }
×
3800
                }
×
UNCOV
3801

×
3802
                return policy1, policy2, nil
×
3803
        default:
×
3804
                return nil, nil, fmt.Errorf("unexpected row type in "+
3805
                        "extractChannelPolicies: %T", r)
UNCOV
3806
        }
×
UNCOV
3807
}
×
UNCOV
3808

×
UNCOV
3809
// channelIDToBytes converts a channel ID (SCID) to a byte array
×
UNCOV
3810
// representation.
×
3811
func channelIDToBytes(channelID uint64) [8]byte {
3812
        var chanIDB [8]byte
×
3813
        byteOrder.PutUint64(chanIDB[:], channelID)
×
3814

×
3815
        return chanIDB
×
3816
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc