• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 16640523368

31 Jul 2025 05:17AM UTC coverage: 67.074% (-0.1%) from 67.181%
16640523368

push

github

web-flow
Merge pull request #10115 from ellemouton/graphPerf3

[2] graph/db: batch-fetch node data

3 of 336 new or added lines in 3 files covered. (0.89%)

116 existing lines in 20 files now uncovered.

135509 of 202030 relevant lines covered (67.07%)

21670.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        "maps"
11
        "math"
12
        "net"
13
        "slices"
14
        "strconv"
15
        "sync"
16
        "time"
17

18
        "github.com/btcsuite/btcd/btcec/v2"
19
        "github.com/btcsuite/btcd/btcutil"
20
        "github.com/btcsuite/btcd/chaincfg/chainhash"
21
        "github.com/btcsuite/btcd/wire"
22
        "github.com/lightningnetwork/lnd/aliasmgr"
23
        "github.com/lightningnetwork/lnd/batch"
24
        "github.com/lightningnetwork/lnd/fn/v2"
25
        "github.com/lightningnetwork/lnd/graph/db/models"
26
        "github.com/lightningnetwork/lnd/lnwire"
27
        "github.com/lightningnetwork/lnd/routing/route"
28
        "github.com/lightningnetwork/lnd/sqldb"
29
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
30
        "github.com/lightningnetwork/lnd/tlv"
31
        "github.com/lightningnetwork/lnd/tor"
32
)
33

34
// pageSize is the limit for the number of records that can be returned
35
// in a paginated query. This can be tuned after some benchmarks.
36
const pageSize = 2000
37

38
// ProtocolVersion is an enum that defines the gossip protocol version of a
39
// message.
40
type ProtocolVersion uint8
41

42
const (
43
        // ProtocolV1 is the gossip protocol version defined in BOLT #7.
44
        ProtocolV1 ProtocolVersion = 1
45
)
46

47
// String returns a string representation of the protocol version.
48
func (v ProtocolVersion) String() string {
×
49
        return fmt.Sprintf("V%d", v)
×
50
}
×
51

52
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
53
// execute queries against the SQL graph tables.
54
//
55
//nolint:ll,interfacebloat
56
type SQLQueries interface {
57
        /*
58
                Node queries.
59
        */
60
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
61
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
62
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
63
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
64
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
65
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
66
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
67
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
68
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
69
        DeleteNode(ctx context.Context, id int64) error
70

71
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
72
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
73
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
74
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
75

76
        InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
77
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
78
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
79
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
80

81
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
82
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
83
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
84
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
85
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
86

87
        /*
88
                Source node queries.
89
        */
90
        AddSourceNode(ctx context.Context, nodeID int64) error
91
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
92

93
        /*
94
                Channel queries.
95
        */
96
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
97
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
98
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
99
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
100
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
101
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
102
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
103
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
104
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
105
        GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
106
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
107
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
108
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
109
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
110
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
111
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
112
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
113
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
114
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
115
        DeleteChannels(ctx context.Context, ids []int64) error
116

117
        CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
118
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
119

120
        /*
121
                Channel Policy table queries.
122
        */
123
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
124
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
125
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
126

127
        InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
128
        GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
129
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
130

131
        /*
132
                Zombie index queries.
133
        */
134
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
135
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
136
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
137
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
138
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
139

140
        /*
141
                Prune log table queries.
142
        */
143
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
144
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
145
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
146
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
147

148
        /*
149
                Closed SCID table queries.
150
        */
151
        InsertClosedChannel(ctx context.Context, scid []byte) error
152
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
153
}
154

155
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
156
// database operations.
157
type BatchedSQLQueries interface {
158
        SQLQueries
159
        sqldb.BatchedTx[SQLQueries]
160
}
161

162
// SQLStore is an implementation of the V1Store interface that uses a SQL
163
// database as the backend.
164
type SQLStore struct {
165
        cfg *SQLStoreConfig
166
        db  BatchedSQLQueries
167

168
        // cacheMu guards all caches (rejectCache and chanCache). If
169
        // this mutex will be acquired at the same time as the DB mutex then
170
        // the cacheMu MUST be acquired first to prevent deadlock.
171
        cacheMu     sync.RWMutex
172
        rejectCache *rejectCache
173
        chanCache   *channelCache
174

175
        chanScheduler batch.Scheduler[SQLQueries]
176
        nodeScheduler batch.Scheduler[SQLQueries]
177

178
        srcNodes  map[ProtocolVersion]*srcNodeInfo
179
        srcNodeMu sync.Mutex
180
}
181

182
// A compile-time assertion to ensure that SQLStore implements the V1Store
183
// interface.
184
var _ V1Store = (*SQLStore)(nil)
185

186
// SQLStoreConfig holds the configuration for the SQLStore.
187
type SQLStoreConfig struct {
188
        // ChainHash is the genesis hash for the chain that all the gossip
189
        // messages in this store are aimed at.
190
        ChainHash chainhash.Hash
191

192
        // PaginationCfg is the configuration for paginated queries.
193
        PaginationCfg *sqldb.PagedQueryConfig
194
}
195

196
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
197
// storage backend.
198
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
199
        options ...StoreOptionModifier) (*SQLStore, error) {
×
200

×
201
        opts := DefaultOptions()
×
202
        for _, o := range options {
×
203
                o(opts)
×
204
        }
×
205

206
        if opts.NoMigration {
×
207
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
208
                        "supported for SQL stores")
×
209
        }
×
210

211
        s := &SQLStore{
×
212
                cfg:         cfg,
×
213
                db:          db,
×
214
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
215
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
216
                srcNodes:    make(map[ProtocolVersion]*srcNodeInfo),
×
217
        }
×
218

×
219
        s.chanScheduler = batch.NewTimeScheduler(
×
220
                db, &s.cacheMu, opts.BatchCommitInterval,
×
221
        )
×
222
        s.nodeScheduler = batch.NewTimeScheduler(
×
223
                db, nil, opts.BatchCommitInterval,
×
224
        )
×
225

×
226
        return s, nil
×
227
}
228

229
// AddLightningNode adds a vertex/node to the graph database. If the node is not
230
// in the database from before, this will add a new, unconnected one to the
231
// graph. If it is present from before, this will update that node's
232
// information.
233
//
234
// NOTE: part of the V1Store interface.
235
func (s *SQLStore) AddLightningNode(ctx context.Context,
236
        node *models.LightningNode, opts ...batch.SchedulerOption) error {
×
237

×
238
        r := &batch.Request[SQLQueries]{
×
239
                Opts: batch.NewSchedulerOptions(opts...),
×
240
                Do: func(queries SQLQueries) error {
×
241
                        _, err := upsertNode(ctx, queries, node)
×
242
                        return err
×
243
                },
×
244
        }
245

246
        return s.nodeScheduler.Execute(ctx, r)
×
247
}
248

249
// FetchLightningNode loads the node whose identity public key is pubKey,
// using a read transaction. Per this method's contract, a missing node is
// reported via ErrGraphNodeNotFound (wrapped in the returned error).
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(ctx context.Context,
	pubKey route.Vertex) (*models.LightningNode, error) {

	var result *models.LightningNode
	txErr := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		_, fetched, err := getNodeByPubKey(ctx, db, pubKey)
		result = fetched

		return err
	}, sqldb.NoOpReset)
	if txErr != nil {
		return nil, fmt.Errorf("unable to fetch node: %w", txErr)
	}

	return result, nil
}
270

271
// HasLightningNode reports whether a V1 node with the given identity public
// key is stored. When it is, the node's last-update time (if one is recorded)
// and true are returned. Otherwise the zero time.Time and false are returned.
// A missing node is not treated as an error.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(ctx context.Context,
	pubKey [33]byte) (time.Time, bool, error) {

	var (
		found   bool
		updated time.Time
	)
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(ctx, sqlc.GetNodeByPubKeyParams{
			Version: int16(ProtocolV1),
			PubKey:  pubKey[:],
		})
		switch {
		case errors.Is(err, sql.ErrNoRows):
			// Unknown node: leave found == false.
			return nil

		case err != nil:
			return fmt.Errorf("unable to fetch node: %w", err)
		}

		found = true

		// LastUpdate is nullable; only translate it when it is set.
		if dbNode.LastUpdate.Valid {
			updated = time.Unix(dbNode.LastUpdate.Int64, 0)
		}

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return time.Time{}, false,
			fmt.Errorf("unable to fetch node: %w", err)
	}

	return updated, found, nil
}
313

314
// AddrsForNode returns every address the graph DB knows for the node
// identified by nodePub. The boolean result reports whether the node is known
// to the graph DB at all. An unknown node yields (false, nil, nil). Any other
// failure while looking the node up or loading its addresses is returned as
// an error.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(ctx context.Context,
	nodePub *btcec.PublicKey) (bool, []net.Addr, error) {

	var (
		addresses []net.Addr
		known     bool
	)

	// The transaction body writes to the outer variables. Clear them
	// whenever ExecTx resets, so a retried attempt cannot inherit state
	// from a failed one.
	reset := func() {
		addresses = nil
		known = false
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// First, check if the node exists and get its DB ID if it
		// does.
		dbID, err := db.GetNodeIDByPubKey(
			ctx, sqlc.GetNodeIDByPubKeyParams{
				Version: int16(ProtocolV1),
				PubKey:  nodePub.SerializeCompressed(),
			},
		)
		switch {
		// An unknown node is not an error; known stays false.
		case errors.Is(err, sql.ErrNoRows):
			return nil

		// Any other failure must be surfaced. Continuing here would
		// use a zero-value dbID and wrongly report the node as known.
		case err != nil:
			return fmt.Errorf("unable to fetch node ID: %w", err)
		}

		known = true

		addresses, err = getNodeAddresses(ctx, db, dbID)
		if err != nil {
			return fmt.Errorf("unable to fetch node addresses: %w",
				err)
		}

		return nil
	}, reset)
	if err != nil {
		return false, nil, fmt.Errorf("unable to get addresses for "+
			"node(%x): %w", nodePub.SerializeCompressed(), err)
	}

	return known, addresses, nil
}
356

357
// DeleteLightningNode starts a new database transaction to remove a vertex/node
358
// from the database according to the node's public key.
359
//
360
// NOTE: part of the V1Store interface.
361
func (s *SQLStore) DeleteLightningNode(ctx context.Context,
362
        pubKey route.Vertex) error {
×
363

×
364
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
365
                res, err := db.DeleteNodeByPubKey(
×
366
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
367
                                Version: int16(ProtocolV1),
×
368
                                PubKey:  pubKey[:],
×
369
                        },
×
370
                )
×
371
                if err != nil {
×
372
                        return err
×
373
                }
×
374

375
                rows, err := res.RowsAffected()
×
376
                if err != nil {
×
377
                        return err
×
378
                }
×
379

380
                if rows == 0 {
×
381
                        return ErrGraphNodeNotFound
×
382
                } else if rows > 1 {
×
383
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
384
                }
×
385

386
                return err
×
387
        }, sqldb.NoOpReset)
388
        if err != nil {
×
389
                return fmt.Errorf("unable to delete node: %w", err)
×
390
        }
×
391

392
        return nil
×
393
}
394

395
// FetchNodeFeatures returns the features of the given node. If no features are
396
// known for the node, an empty feature vector is returned.
397
//
398
// NOTE: this is part of the graphdb.NodeTraverser interface.
399
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
400
        *lnwire.FeatureVector, error) {
×
401

×
402
        ctx := context.TODO()
×
403

×
404
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
405
}
×
406

407
// DisabledChannelIDs returns the IDs of disabled V1 channels. A channel counts
// as disabled when both of its associated ChannelEdgePolicies have their
// disabled bit set.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
	ctx := context.TODO()

	var chanIDs []uint64
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		rawSCIDs, err := db.GetV1DisabledSCIDs(ctx)
		if err != nil {
			return fmt.Errorf("unable to fetch disabled "+
				"channels: %w", err)
		}

		// Decode each raw SCID into its integer channel ID using the
		// package byte order.
		chanIDs = fn.Map(rawSCIDs, byteOrder.Uint64)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch disabled channels: %w",
			err)
	}

	return chanIDs, nil
}
435

436
// LookupAlias returns the alias advertised by the V1 node with the given
// public key. If the node is unknown or has no alias stored, the returned
// error wraps ErrNodeAliasNotFound.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(ctx context.Context,
	pub *btcec.PublicKey) (string, error) {

	var alias string
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNode, err := db.GetNodeByPubKey(ctx, sqlc.GetNodeByPubKeyParams{
			Version: int16(ProtocolV1),
			PubKey:  pub.SerializeCompressed(),
		})
		switch {
		case errors.Is(err, sql.ErrNoRows):
			return ErrNodeAliasNotFound

		case err != nil:
			return fmt.Errorf("unable to fetch node: %w", err)

		// The alias column is nullable; a NULL means no alias.
		case !dbNode.Alias.Valid:
			return ErrNodeAliasNotFound
		}

		alias = dbNode.Alias.String

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return "", fmt.Errorf("unable to look up alias: %w", err)
	}

	return alias, nil
}
470

471
// SourceNode returns the graph's V1 source node, i.e. the node treated as the
// centre of a star-graph. Path finding may start from it to explore the
// reachability of other nodes.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode(ctx context.Context) (*models.LightningNode,
	error) {

	var source *models.LightningNode
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		// Resolve which public key is registered as the V1 source...
		_, srcPub, err := s.getSourceNode(ctx, db, ProtocolV1)
		if err != nil {
			return fmt.Errorf("unable to fetch V1 source node: %w",
				err)
		}

		// ...then load the full node record for that key.
		_, loaded, err := getNodeByPubKey(ctx, db, srcPub)
		source = loaded

		return err
	}, sqldb.NoOpReset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch source node: %w", err)
	}

	return source, nil
}
498

499
// SetSourceNode sets the source node within the graph database. The source
500
// node is to be used as the center of a star-graph within path finding
501
// algorithms.
502
//
503
// NOTE: part of the V1Store interface.
504
func (s *SQLStore) SetSourceNode(ctx context.Context,
505
        node *models.LightningNode) error {
×
506

×
507
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
508
                id, err := upsertNode(ctx, db, node)
×
509
                if err != nil {
×
510
                        return fmt.Errorf("unable to upsert source node: %w",
×
511
                                err)
×
512
                }
×
513

514
                // Make sure that if a source node for this version is already
515
                // set, then the ID is the same as the one we are about to set.
516
                dbSourceNodeID, _, err := s.getSourceNode(ctx, db, ProtocolV1)
×
517
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
518
                        return fmt.Errorf("unable to fetch source node: %w",
×
519
                                err)
×
520
                } else if err == nil {
×
521
                        if dbSourceNodeID != id {
×
522
                                return fmt.Errorf("v1 source node already "+
×
523
                                        "set to a different node: %d vs %d",
×
524
                                        dbSourceNodeID, id)
×
525
                        }
×
526

527
                        return nil
×
528
                }
529

530
                return db.AddSourceNode(ctx, id)
×
531
        }, sqldb.NoOpReset)
532
}
533

534
// NodeUpdatesInHorizon returns all known lightning nodes whose update
// timestamp falls within the passed range. Two peers can use it to quickly
// determine whether they hold the same set of up-to-date node announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
	endTime time.Time) ([]models.LightningNode, error) {

	ctx := context.TODO()

	var nodes []models.LightningNode

	// The transaction body appends to nodes. ExecTx takes a reset callback,
	// presumably invoked before re-running the body. Clearing nodes there
	// keeps a failed partial attempt from leaving duplicate entries in the
	// result of a retry.
	reset := func() {
		nodes = nil
	}

	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		dbNodes, err := db.GetNodesByLastUpdateRange(
			ctx, sqlc.GetNodesByLastUpdateRangeParams{
				StartTime: sqldb.SQLInt64(startTime.Unix()),
				EndTime:   sqldb.SQLInt64(endTime.Unix()),
			},
		)
		if err != nil {
			return fmt.Errorf("unable to fetch nodes: %w", err)
		}

		// Build the full node models for the matching rows (the helper
		// is given the pagination config) and collect them.
		err = forEachNodeInBatch(
			ctx, s.cfg.PaginationCfg, db, dbNodes,
			func(_ int64, node *models.LightningNode) error {
				nodes = append(nodes, *node)

				return nil
			},
		)
		if err != nil {
			return fmt.Errorf("unable to build nodes: %w", err)
		}

		return nil
	}, reset)
	if err != nil {
		return nil, fmt.Errorf("unable to fetch nodes: %w", err)
	}

	return nodes, nil
}
577

578
// AddChannelEdge stores a new undirected, blank edge between the two nodes
// referenced by edge. Only the channel's static attributes are persisted: the
// channel ID, the keys involved in creating it, and its supported feature set.
// The chanPoint and chanID uniquely identify the edge in the database.
// If the edge already exists, ErrEdgeAlreadyExist is returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(ctx context.Context,
	edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {

	// duplicate records whether the insert hit ErrEdgeAlreadyExist. That
	// error is hidden from the batch so the batch can still succeed, and
	// is surfaced to this caller from OnCommit instead.
	var duplicate bool

	req := &batch.Request[SQLQueries]{
		Opts: batch.NewSchedulerOptions(opts...),
		Reset: func() {
			duplicate = false
		},
		Do: func(tx SQLQueries) error {
			_, err := insertChannel(ctx, tx, edge)
			if !errors.Is(err, ErrEdgeAlreadyExist) {
				return err
			}

			duplicate = true

			return nil
		},
		OnCommit: func(err error) error {
			if err != nil {
				return err
			}
			if duplicate {
				return ErrEdgeAlreadyExist
			}

			// The edge is now stored; drop any cached entries for
			// its channel ID.
			s.rejectCache.remove(edge.ChannelID)
			s.chanCache.remove(edge.ChannelID)

			return nil
		},
	}

	return s.chanScheduler.Execute(ctx, req)
}
623

624
// HighestChanID returns the highest known V1 channel ID in the graph, i.e. the
// newest channel from the chain's point of view. Peers can use it to quickly
// check whether their graphs are in sync. Zero is returned when no such
// channel is stored.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
	var highest uint64
	err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
		rawSCID, err := db.HighestSCID(ctx, int16(ProtocolV1))
		switch {
		case errors.Is(err, sql.ErrNoRows):
			// No V1 channel row exists; the result stays zero.
			return nil

		case err != nil:
			return fmt.Errorf("unable to fetch highest chan ID: %w",
				err)
		}

		highest = byteOrder.Uint64(rawSCID)

		return nil
	}, sqldb.NoOpReset)
	if err != nil {
		return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
	}

	return highest, nil
}
650

651
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
652
// within the database for the referenced channel. The `flags` attribute within
653
// the ChannelEdgePolicy determines which of the directed edges are being
654
// updated. If the flag is 1, then the first node's information is being
655
// updated, otherwise it's the second node's information. The node ordering is
656
// determined by the lexicographical ordering of the identity public keys of the
657
// nodes on either side of the channel.
658
//
659
// NOTE: part of the V1Store interface.
660
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
661
        edge *models.ChannelEdgePolicy,
662
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
663

×
664
        var (
×
665
                isUpdate1    bool
×
666
                edgeNotFound bool
×
667
                from, to     route.Vertex
×
668
        )
×
669

×
670
        r := &batch.Request[SQLQueries]{
×
671
                Opts: batch.NewSchedulerOptions(opts...),
×
672
                Reset: func() {
×
673
                        isUpdate1 = false
×
674
                        edgeNotFound = false
×
675
                },
×
676
                Do: func(tx SQLQueries) error {
×
677
                        var err error
×
678
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
679
                                ctx, tx, edge,
×
680
                        )
×
681
                        if err != nil {
×
682
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
683
                        }
×
684

685
                        // Silence ErrEdgeNotFound so that the batch can
686
                        // succeed, but propagate the error via local state.
687
                        if errors.Is(err, ErrEdgeNotFound) {
×
688
                                edgeNotFound = true
×
689
                                return nil
×
690
                        }
×
691

692
                        return err
×
693
                },
694
                OnCommit: func(err error) error {
×
695
                        switch {
×
696
                        case err != nil:
×
697
                                return err
×
698
                        case edgeNotFound:
×
699
                                return ErrEdgeNotFound
×
700
                        default:
×
701
                                s.updateEdgeCache(edge, isUpdate1)
×
702
                                return nil
×
703
                        }
704
                },
705
        }
706

707
        err := s.chanScheduler.Execute(ctx, r)
×
708

×
709
        return from, to, err
×
710
}
711

712
// updateEdgeCache updates our reject and channel caches with the new
713
// edge policy information.
714
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
715
        isUpdate1 bool) {
×
716

×
717
        // If an entry for this channel is found in reject cache, we'll modify
×
718
        // the entry with the updated timestamp for the direction that was just
×
719
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
720
        // during the next query for this edge.
×
721
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
722
                if isUpdate1 {
×
723
                        entry.upd1Time = e.LastUpdate.Unix()
×
724
                } else {
×
725
                        entry.upd2Time = e.LastUpdate.Unix()
×
726
                }
×
727
                s.rejectCache.insert(e.ChannelID, entry)
×
728
        }
729

730
        // If an entry for this channel is found in channel cache, we'll modify
731
        // the entry with the updated policy for the direction that was just
732
        // written. If the edge doesn't exist, we'll defer loading the info and
733
        // policies and lazily read from disk during the next query.
734
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
735
                if isUpdate1 {
×
736
                        channel.Policy1 = e
×
737
                } else {
×
738
                        channel.Policy2 = e
×
739
                }
×
740
                s.chanCache.insert(e.ChannelID, channel)
×
741
        }
742
}
743

744
// ForEachSourceNodeChannel iterates through all channels of the source node,
745
// executing the passed callback on each. The call-back is provided with the
746
// channel's outpoint, whether we have a policy for the channel and the channel
747
// peer's node information.
748
//
749
// NOTE: part of the V1Store interface.
750
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
751
        cb func(chanPoint wire.OutPoint, havePolicy bool,
752
                otherNode *models.LightningNode) error, reset func()) error {
×
753

×
754
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
755
                nodeID, nodePub, err := s.getSourceNode(ctx, db, ProtocolV1)
×
756
                if err != nil {
×
757
                        return fmt.Errorf("unable to fetch source node: %w",
×
758
                                err)
×
759
                }
×
760

761
                return forEachNodeChannel(
×
762
                        ctx, db, s.cfg.ChainHash, nodeID,
×
763
                        func(info *models.ChannelEdgeInfo,
×
764
                                outPolicy *models.ChannelEdgePolicy,
×
765
                                _ *models.ChannelEdgePolicy) error {
×
766

×
767
                                // Fetch the other node.
×
768
                                var (
×
769
                                        otherNodePub [33]byte
×
770
                                        node1        = info.NodeKey1Bytes
×
771
                                        node2        = info.NodeKey2Bytes
×
772
                                )
×
773
                                switch {
×
774
                                case bytes.Equal(node1[:], nodePub[:]):
×
775
                                        otherNodePub = node2
×
776
                                case bytes.Equal(node2[:], nodePub[:]):
×
777
                                        otherNodePub = node1
×
778
                                default:
×
779
                                        return fmt.Errorf("node not " +
×
780
                                                "participating in this channel")
×
781
                                }
782

783
                                _, otherNode, err := getNodeByPubKey(
×
784
                                        ctx, db, otherNodePub,
×
785
                                )
×
786
                                if err != nil {
×
787
                                        return fmt.Errorf("unable to fetch "+
×
788
                                                "other node(%x): %w",
×
789
                                                otherNodePub, err)
×
790
                                }
×
791

792
                                return cb(
×
793
                                        info.ChannelPoint, outPolicy != nil,
×
794
                                        otherNode,
×
795
                                )
×
796
                        },
797
                )
798
        }, reset)
799
}
800

801
// ForEachNode iterates through all the stored vertices/nodes in the graph,
802
// executing the passed callback with each node encountered. If the callback
803
// returns an error, then the transaction is aborted and the iteration stops
804
// early. Any operations performed on the NodeTx passed to the call-back are
805
// executed under the same read transaction and so, methods on the NodeTx object
806
// _MUST_ only be called from within the call-back.
807
//
808
// NOTE: part of the V1Store interface.
809
func (s *SQLStore) ForEachNode(ctx context.Context,
810
        cb func(tx NodeRTx) error, reset func()) error {
×
811

×
NEW
812
        var lastID int64
×
813

×
NEW
814
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
815
                nodeCB := func(dbID int64, node *models.LightningNode) error {
×
NEW
816
                        err := cb(newSQLGraphNodeTx(
×
NEW
817
                                db, s.cfg.ChainHash, dbID, node,
×
NEW
818
                        ))
×
NEW
819
                        if err != nil {
×
NEW
820
                                return fmt.Errorf("callback failed for "+
×
NEW
821
                                        "node(id=%d): %w", dbID, err)
×
NEW
822
                        }
×
NEW
823
                        lastID = dbID
×
824

×
NEW
825
                        return nil
×
826
                }
827

828
                for {
×
829
                        nodes, err := db.ListNodesPaginated(
×
830
                                ctx, sqlc.ListNodesPaginatedParams{
×
831
                                        Version: int16(ProtocolV1),
×
832
                                        ID:      lastID,
×
833
                                        Limit:   pageSize,
×
834
                                },
×
835
                        )
×
836
                        if err != nil {
×
837
                                return fmt.Errorf("unable to fetch nodes: %w",
×
838
                                        err)
×
839
                        }
×
840

841
                        if len(nodes) == 0 {
×
842
                                break
×
843
                        }
844

NEW
845
                        err = forEachNodeInBatch(
×
NEW
846
                                ctx, s.cfg.PaginationCfg, db, nodes, nodeCB,
×
NEW
847
                        )
×
NEW
848
                        if err != nil {
×
NEW
849
                                return fmt.Errorf("unable to iterate over "+
×
NEW
850
                                        "nodes: %w", err)
×
UNCOV
851
                        }
×
852
                }
853

854
                return nil
×
855
        }, reset)
856
}
857

858
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
859
// SQLStore and a SQL transaction.
860
type sqlGraphNodeTx struct {
861
        db    SQLQueries
862
        id    int64
863
        node  *models.LightningNode
864
        chain chainhash.Hash
865
}
866

867
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
868
// interface.
869
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
870

871
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
872
        id int64, node *models.LightningNode) *sqlGraphNodeTx {
×
873

×
874
        return &sqlGraphNodeTx{
×
875
                db:    db,
×
876
                chain: chain,
×
877
                id:    id,
×
878
                node:  node,
×
879
        }
×
880
}
×
881

882
// Node returns the raw information of the node.
883
//
884
// NOTE: This is a part of the NodeRTx interface.
885
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
×
886
        return s.node
×
887
}
×
888

889
// ForEachChannel can be used to iterate over the node's channels under the same
890
// transaction used to fetch the node.
891
//
892
// NOTE: This is a part of the NodeRTx interface.
893
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
894
        *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
×
895

×
896
        ctx := context.TODO()
×
897

×
898
        return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
×
899
}
×
900

901
// FetchNode fetches the node with the given pub key under the same transaction
902
// used to fetch the current node. The returned node is also a NodeRTx and any
903
// operations on that NodeRTx will also be done under the same transaction.
904
//
905
// NOTE: This is a part of the NodeRTx interface.
906
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
×
907
        ctx := context.TODO()
×
908

×
909
        id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
×
910
        if err != nil {
×
911
                return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
×
912
                        nodePub, err)
×
913
        }
×
914

915
        return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
×
916
}
917

918
// ForEachNodeDirectedChannel iterates through all channels of a given node,
919
// executing the passed callback on the directed edge representing the channel
920
// and its incoming policy. If the callback returns an error, then the iteration
921
// is halted with the error propagated back up to the caller.
922
//
923
// Unknown policies are passed into the callback as nil values.
924
//
925
// NOTE: this is part of the graphdb.NodeTraverser interface.
926
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
927
        cb func(channel *DirectedChannel) error, reset func()) error {
×
928

×
929
        var ctx = context.TODO()
×
930

×
931
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
932
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
933
        }, reset)
×
934
}
935

936
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
937
// graph, executing the passed callback with each node encountered. If the
938
// callback returns an error, then the transaction is aborted and the iteration
939
// stops early.
940
//
941
// NOTE: This is a part of the V1Store interface.
942
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
943
        cb func(route.Vertex, *lnwire.FeatureVector) error,
944
        reset func()) error {
×
945

×
946
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
947
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
948
                        nodePub route.Vertex) error {
×
949

×
950
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
951
                        if err != nil {
×
952
                                return fmt.Errorf("unable to fetch node "+
×
953
                                        "features: %w", err)
×
954
                        }
×
955

956
                        return cb(nodePub, features)
×
957
                })
958
        }, reset)
959
        if err != nil {
×
960
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
961
        }
×
962

963
        return nil
×
964
}
965

966
// ForEachNodeChannel iterates through all channels of the given node,
967
// executing the passed callback with an edge info structure and the policies
968
// of each end of the channel. The first edge policy is the outgoing edge *to*
969
// the connecting node, while the second is the incoming edge *from* the
970
// connecting node. If the callback returns an error, then the iteration is
971
// halted with the error propagated back up to the caller.
972
//
973
// Unknown policies are passed into the callback as nil values.
974
//
975
// NOTE: part of the V1Store interface.
976
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
977
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
978
                *models.ChannelEdgePolicy) error, reset func()) error {
×
979

×
980
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
981
                dbNode, err := db.GetNodeByPubKey(
×
982
                        ctx, sqlc.GetNodeByPubKeyParams{
×
983
                                Version: int16(ProtocolV1),
×
984
                                PubKey:  nodePub[:],
×
985
                        },
×
986
                )
×
987
                if errors.Is(err, sql.ErrNoRows) {
×
988
                        return nil
×
989
                } else if err != nil {
×
990
                        return fmt.Errorf("unable to fetch node: %w", err)
×
991
                }
×
992

993
                return forEachNodeChannel(
×
994
                        ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
×
995
                )
×
996
        }, reset)
997
}
998

999
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1000
// one edge that has an update timestamp within the specified horizon.
1001
//
1002
// NOTE: This is part of the V1Store interface.
1003
func (s *SQLStore) ChanUpdatesInHorizon(startTime,
1004
        endTime time.Time) ([]ChannelEdge, error) {
×
1005

×
1006
        s.cacheMu.Lock()
×
1007
        defer s.cacheMu.Unlock()
×
1008

×
1009
        var (
×
1010
                ctx = context.TODO()
×
1011
                // To ensure we don't return duplicate ChannelEdges, we'll use
×
1012
                // an additional map to keep track of the edges already seen to
×
1013
                // prevent re-adding it.
×
1014
                edgesSeen    = make(map[uint64]struct{})
×
1015
                edgesToCache = make(map[uint64]ChannelEdge)
×
1016
                edges        []ChannelEdge
×
1017
                hits         int
×
1018
        )
×
1019
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1020
                rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1021
                        ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1022
                                Version:   int16(ProtocolV1),
×
1023
                                StartTime: sqldb.SQLInt64(startTime.Unix()),
×
1024
                                EndTime:   sqldb.SQLInt64(endTime.Unix()),
×
1025
                        },
×
1026
                )
×
1027
                if err != nil {
×
1028
                        return err
×
1029
                }
×
1030

1031
                for _, row := range rows {
×
1032
                        // If we've already retrieved the info and policies for
×
1033
                        // this edge, then we can skip it as we don't need to do
×
1034
                        // so again.
×
1035
                        chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid)
×
1036
                        if _, ok := edgesSeen[chanIDInt]; ok {
×
1037
                                continue
×
1038
                        }
1039

1040
                        if channel, ok := s.chanCache.get(chanIDInt); ok {
×
1041
                                hits++
×
1042
                                edgesSeen[chanIDInt] = struct{}{}
×
1043
                                edges = append(edges, channel)
×
1044

×
1045
                                continue
×
1046
                        }
1047

1048
                        node1, node2, err := buildNodes(
×
1049
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
1050
                        )
×
1051
                        if err != nil {
×
1052
                                return err
×
1053
                        }
×
1054

1055
                        channel, err := getAndBuildEdgeInfo(
×
1056
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1057
                                row.GraphChannel, node1.PubKeyBytes,
×
1058
                                node2.PubKeyBytes,
×
1059
                        )
×
1060
                        if err != nil {
×
1061
                                return fmt.Errorf("unable to build channel "+
×
1062
                                        "info: %w", err)
×
1063
                        }
×
1064

1065
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1066
                        if err != nil {
×
1067
                                return fmt.Errorf("unable to extract channel "+
×
1068
                                        "policies: %w", err)
×
1069
                        }
×
1070

1071
                        p1, p2, err := getAndBuildChanPolicies(
×
1072
                                ctx, db, dbPol1, dbPol2, channel.ChannelID,
×
1073
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
1074
                        )
×
1075
                        if err != nil {
×
1076
                                return fmt.Errorf("unable to build channel "+
×
1077
                                        "policies: %w", err)
×
1078
                        }
×
1079

1080
                        edgesSeen[chanIDInt] = struct{}{}
×
1081
                        chanEdge := ChannelEdge{
×
1082
                                Info:    channel,
×
1083
                                Policy1: p1,
×
1084
                                Policy2: p2,
×
1085
                                Node1:   node1,
×
1086
                                Node2:   node2,
×
1087
                        }
×
1088
                        edges = append(edges, chanEdge)
×
1089
                        edgesToCache[chanIDInt] = chanEdge
×
1090
                }
1091

1092
                return nil
×
1093
        }, func() {
×
1094
                edgesSeen = make(map[uint64]struct{})
×
1095
                edgesToCache = make(map[uint64]ChannelEdge)
×
1096
                edges = nil
×
1097
        })
×
1098
        if err != nil {
×
1099
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
1100
        }
×
1101

1102
        // Insert any edges loaded from disk into the cache.
1103
        for chanid, channel := range edgesToCache {
×
1104
                s.chanCache.insert(chanid, channel)
×
1105
        }
×
1106

1107
        if len(edges) > 0 {
×
1108
                log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
×
1109
                        float64(hits)*100/float64(len(edges)), hits, len(edges))
×
1110
        } else {
×
1111
                log.Debugf("ChanUpdatesInHorizon returned no edges in "+
×
1112
                        "horizon (%s, %s)", startTime, endTime)
×
1113
        }
×
1114

1115
        return edges, nil
×
1116
}
1117

1118
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1119
// data to the call-back.
1120
//
1121
// NOTE: The callback contents MUST not be modified.
1122
//
1123
// NOTE: part of the V1Store interface.
1124
func (s *SQLStore) ForEachNodeCached(ctx context.Context,
1125
        cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
1126
        reset func()) error {
×
1127

×
1128
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1129
                return forEachNodeCacheable(ctx, db, func(nodeID int64,
×
1130
                        nodePub route.Vertex) error {
×
1131

×
1132
                        features, err := getNodeFeatures(ctx, db, nodeID)
×
1133
                        if err != nil {
×
1134
                                return fmt.Errorf("unable to fetch "+
×
1135
                                        "node(id=%d) features: %w", nodeID, err)
×
1136
                        }
×
1137

1138
                        toNodeCallback := func() route.Vertex {
×
1139
                                return nodePub
×
1140
                        }
×
1141

1142
                        rows, err := db.ListChannelsByNodeID(
×
1143
                                ctx, sqlc.ListChannelsByNodeIDParams{
×
1144
                                        Version: int16(ProtocolV1),
×
1145
                                        NodeID1: nodeID,
×
1146
                                },
×
1147
                        )
×
1148
                        if err != nil {
×
1149
                                return fmt.Errorf("unable to fetch channels "+
×
1150
                                        "of node(id=%d): %w", nodeID, err)
×
1151
                        }
×
1152

1153
                        channels := make(map[uint64]*DirectedChannel, len(rows))
×
1154
                        for _, row := range rows {
×
1155
                                node1, node2, err := buildNodeVertices(
×
1156
                                        row.Node1Pubkey, row.Node2Pubkey,
×
1157
                                )
×
1158
                                if err != nil {
×
1159
                                        return err
×
1160
                                }
×
1161

1162
                                e, err := getAndBuildEdgeInfo(
×
1163
                                        ctx, db, s.cfg.ChainHash,
×
1164
                                        row.GraphChannel.ID, row.GraphChannel,
×
1165
                                        node1, node2,
×
1166
                                )
×
1167
                                if err != nil {
×
1168
                                        return fmt.Errorf("unable to build "+
×
1169
                                                "channel info: %w", err)
×
1170
                                }
×
1171

1172
                                dbPol1, dbPol2, err := extractChannelPolicies(
×
1173
                                        row,
×
1174
                                )
×
1175
                                if err != nil {
×
1176
                                        return fmt.Errorf("unable to "+
×
1177
                                                "extract channel "+
×
1178
                                                "policies: %w", err)
×
1179
                                }
×
1180

1181
                                p1, p2, err := getAndBuildChanPolicies(
×
1182
                                        ctx, db, dbPol1, dbPol2, e.ChannelID,
×
1183
                                        node1, node2,
×
1184
                                )
×
1185
                                if err != nil {
×
1186
                                        return fmt.Errorf("unable to "+
×
1187
                                                "build channel policies: %w",
×
1188
                                                err)
×
1189
                                }
×
1190

1191
                                // Determine the outgoing and incoming policy
1192
                                // for this channel and node combo.
1193
                                outPolicy, inPolicy := p1, p2
×
1194
                                if p1 != nil && p1.ToNode == nodePub {
×
1195
                                        outPolicy, inPolicy = p2, p1
×
1196
                                } else if p2 != nil && p2.ToNode != nodePub {
×
1197
                                        outPolicy, inPolicy = p2, p1
×
1198
                                }
×
1199

1200
                                var cachedInPolicy *models.CachedEdgePolicy
×
1201
                                if inPolicy != nil {
×
1202
                                        cachedInPolicy = models.NewCachedPolicy(
×
1203
                                                inPolicy,
×
1204
                                        )
×
1205
                                        cachedInPolicy.ToNodePubKey =
×
1206
                                                toNodeCallback
×
1207
                                        cachedInPolicy.ToNodeFeatures =
×
1208
                                                features
×
1209
                                }
×
1210

1211
                                var inboundFee lnwire.Fee
×
1212
                                if outPolicy != nil {
×
1213
                                        outPolicy.InboundFee.WhenSome(
×
1214
                                                func(fee lnwire.Fee) {
×
1215
                                                        inboundFee = fee
×
1216
                                                },
×
1217
                                        )
1218
                                }
1219

1220
                                directedChannel := &DirectedChannel{
×
1221
                                        ChannelID: e.ChannelID,
×
1222
                                        IsNode1: nodePub ==
×
1223
                                                e.NodeKey1Bytes,
×
1224
                                        OtherNode:    e.NodeKey2Bytes,
×
1225
                                        Capacity:     e.Capacity,
×
1226
                                        OutPolicySet: outPolicy != nil,
×
1227
                                        InPolicy:     cachedInPolicy,
×
1228
                                        InboundFee:   inboundFee,
×
1229
                                }
×
1230

×
1231
                                if nodePub == e.NodeKey2Bytes {
×
1232
                                        directedChannel.OtherNode =
×
1233
                                                e.NodeKey1Bytes
×
1234
                                }
×
1235

1236
                                channels[e.ChannelID] = directedChannel
×
1237
                        }
1238

1239
                        return cb(nodePub, channels)
×
1240
                })
1241
        }, reset)
1242
}
1243

1244
// ForEachChannelCacheable iterates through all the channel edges stored
1245
// within the graph and invokes the passed callback for each edge. The
1246
// callback takes two edges as since this is a directed graph, both the
1247
// in/out edges are visited. If the callback returns an error, then the
1248
// transaction is aborted and the iteration stops early.
1249
//
1250
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1251
// pointer for that particular channel edge routing policy will be
1252
// passed into the callback.
1253
//
1254
// NOTE: this method is like ForEachChannel but fetches only the data
1255
// required for the graph cache.
1256
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1257
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1258
        reset func()) error {
×
1259

×
1260
        ctx := context.TODO()
×
1261

×
1262
        handleChannel := func(
×
1263
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1264

×
1265
                node1, node2, err := buildNodeVertices(
×
1266
                        row.Node1Pubkey, row.Node2Pubkey,
×
1267
                )
×
1268
                if err != nil {
×
1269
                        return err
×
1270
                }
×
1271

1272
                edge := buildCacheableChannelInfo(
×
1273
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1274
                )
×
1275

×
1276
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1277
                if err != nil {
×
1278
                        return err
×
1279
                }
×
1280

1281
                var pol1, pol2 *models.CachedEdgePolicy
×
1282
                if dbPol1 != nil {
×
1283
                        policy1, err := buildChanPolicy(
×
1284
                                *dbPol1, edge.ChannelID, nil, node2,
×
1285
                        )
×
1286
                        if err != nil {
×
1287
                                return err
×
1288
                        }
×
1289

1290
                        pol1 = models.NewCachedPolicy(policy1)
×
1291
                }
1292
                if dbPol2 != nil {
×
1293
                        policy2, err := buildChanPolicy(
×
1294
                                *dbPol2, edge.ChannelID, nil, node1,
×
1295
                        )
×
1296
                        if err != nil {
×
1297
                                return err
×
1298
                        }
×
1299

1300
                        pol2 = models.NewCachedPolicy(policy2)
×
1301
                }
1302

1303
                if err := cb(edge, pol1, pol2); err != nil {
×
1304
                        return err
×
1305
                }
×
1306

1307
                return nil
×
1308
        }
1309

1310
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1311
                lastID := int64(-1)
×
1312
                for {
×
1313
                        //nolint:ll
×
1314
                        rows, err := db.ListChannelsWithPoliciesForCachePaginated(
×
1315
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1316
                                        Version: int16(ProtocolV1),
×
1317
                                        ID:      lastID,
×
1318
                                        Limit:   pageSize,
×
1319
                                },
×
1320
                        )
×
1321
                        if err != nil {
×
1322
                                return err
×
1323
                        }
×
1324

1325
                        if len(rows) == 0 {
×
1326
                                break
×
1327
                        }
1328

1329
                        for _, row := range rows {
×
1330
                                err := handleChannel(row)
×
1331
                                if err != nil {
×
1332
                                        return err
×
1333
                                }
×
1334

1335
                                lastID = row.ID
×
1336
                        }
1337
                }
1338

1339
                return nil
×
1340
        }, reset)
1341
}
1342

1343
// ForEachChannel iterates through all the channel edges stored within the
1344
// graph and invokes the passed callback for each edge. The callback takes two
1345
// edges as since this is a directed graph, both the in/out edges are visited.
1346
// If the callback returns an error, then the transaction is aborted and the
1347
// iteration stops early.
1348
//
1349
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1350
// for that particular channel edge routing policy will be passed into the
1351
// callback.
1352
//
1353
// NOTE: part of the V1Store interface.
1354
func (s *SQLStore) ForEachChannel(ctx context.Context,
1355
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1356
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1357

×
1358
        handleChannel := func(db SQLQueries,
×
1359
                row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
×
1360

×
1361
                node1, node2, err := buildNodeVertices(
×
1362
                        row.Node1Pubkey, row.Node2Pubkey,
×
1363
                )
×
1364
                if err != nil {
×
1365
                        return fmt.Errorf("unable to build node vertices: %w",
×
1366
                                err)
×
1367
                }
×
1368

1369
                edge, err := getAndBuildEdgeInfo(
×
1370
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1371
                        row.GraphChannel, node1, node2,
×
1372
                )
×
1373
                if err != nil {
×
1374
                        return fmt.Errorf("unable to build channel info: %w",
×
1375
                                err)
×
1376
                }
×
1377

1378
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1379
                if err != nil {
×
1380
                        return fmt.Errorf("unable to extract channel "+
×
1381
                                "policies: %w", err)
×
1382
                }
×
1383

1384
                p1, p2, err := getAndBuildChanPolicies(
×
1385
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1386
                )
×
1387
                if err != nil {
×
1388
                        return fmt.Errorf("unable to build channel "+
×
1389
                                "policies: %w", err)
×
1390
                }
×
1391

1392
                err = cb(edge, p1, p2)
×
1393
                if err != nil {
×
1394
                        return fmt.Errorf("callback failed for channel "+
×
1395
                                "id=%d: %w", edge.ChannelID, err)
×
1396
                }
×
1397

1398
                return nil
×
1399
        }
1400

1401
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1402
                lastID := int64(-1)
×
1403
                for {
×
1404
                        //nolint:ll
×
1405
                        rows, err := db.ListChannelsWithPoliciesPaginated(
×
1406
                                ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
1407
                                        Version: int16(ProtocolV1),
×
1408
                                        ID:      lastID,
×
1409
                                        Limit:   pageSize,
×
1410
                                },
×
1411
                        )
×
1412
                        if err != nil {
×
1413
                                return err
×
1414
                        }
×
1415

1416
                        if len(rows) == 0 {
×
1417
                                break
×
1418
                        }
1419

1420
                        for _, row := range rows {
×
1421
                                err := handleChannel(db, row)
×
1422
                                if err != nil {
×
1423
                                        return err
×
1424
                                }
×
1425

1426
                                lastID = row.GraphChannel.ID
×
1427
                        }
1428
                }
1429

1430
                return nil
×
1431
        }, reset)
1432
}
1433

1434
// FilterChannelRange returns the channel ID's of all known channels which were
1435
// mined in a block height within the passed range. The channel IDs are grouped
1436
// by their common block height. This method can be used to quickly share with a
1437
// peer the set of channels we know of within a particular range to catch them
1438
// up after a period of time offline. If withTimestamps is true then the
1439
// timestamp info of the latest received channel update messages of the channel
1440
// will be included in the response.
1441
//
1442
// NOTE: This is part of the V1Store interface.
1443
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1444
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1445

×
1446
        var (
×
1447
                ctx       = context.TODO()
×
1448
                startSCID = &lnwire.ShortChannelID{
×
1449
                        BlockHeight: startHeight,
×
1450
                }
×
1451
                endSCID = lnwire.ShortChannelID{
×
1452
                        BlockHeight: endHeight,
×
1453
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1454
                        TxPosition:  math.MaxUint16,
×
1455
                }
×
1456
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1457
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1458
        )
×
1459

×
1460
        // 1) get all channels where channelID is between start and end chan ID.
×
1461
        // 2) skip if not public (ie, no channel_proof)
×
1462
        // 3) collect that channel.
×
1463
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1464
        //    and add those timestamps to the collected channel.
×
1465
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1466
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1467
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1468
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1469
                                StartScid: chanIDStart,
×
1470
                                EndScid:   chanIDEnd,
×
1471
                        },
×
1472
                )
×
1473
                if err != nil {
×
1474
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1475
                                err)
×
1476
                }
×
1477

1478
                for _, dbChan := range dbChans {
×
1479
                        cid := lnwire.NewShortChanIDFromInt(
×
1480
                                byteOrder.Uint64(dbChan.Scid),
×
1481
                        )
×
1482
                        chanInfo := NewChannelUpdateInfo(
×
1483
                                cid, time.Time{}, time.Time{},
×
1484
                        )
×
1485

×
1486
                        if !withTimestamps {
×
1487
                                channelsPerBlock[cid.BlockHeight] = append(
×
1488
                                        channelsPerBlock[cid.BlockHeight],
×
1489
                                        chanInfo,
×
1490
                                )
×
1491

×
1492
                                continue
×
1493
                        }
1494

1495
                        //nolint:ll
1496
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1497
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1498
                                        Version:   int16(ProtocolV1),
×
1499
                                        ChannelID: dbChan.ID,
×
1500
                                        NodeID:    dbChan.NodeID1,
×
1501
                                },
×
1502
                        )
×
1503
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1504
                                return fmt.Errorf("unable to fetch node1 "+
×
1505
                                        "policy: %w", err)
×
1506
                        } else if err == nil {
×
1507
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1508
                                        node1Policy.LastUpdate.Int64, 0,
×
1509
                                )
×
1510
                        }
×
1511

1512
                        //nolint:ll
1513
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1514
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1515
                                        Version:   int16(ProtocolV1),
×
1516
                                        ChannelID: dbChan.ID,
×
1517
                                        NodeID:    dbChan.NodeID2,
×
1518
                                },
×
1519
                        )
×
1520
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1521
                                return fmt.Errorf("unable to fetch node2 "+
×
1522
                                        "policy: %w", err)
×
1523
                        } else if err == nil {
×
1524
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1525
                                        node2Policy.LastUpdate.Int64, 0,
×
1526
                                )
×
1527
                        }
×
1528

1529
                        channelsPerBlock[cid.BlockHeight] = append(
×
1530
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1531
                        )
×
1532
                }
1533

1534
                return nil
×
1535
        }, func() {
×
1536
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1537
        })
×
1538
        if err != nil {
×
1539
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1540
        }
×
1541

1542
        if len(channelsPerBlock) == 0 {
×
1543
                return nil, nil
×
1544
        }
×
1545

1546
        // Return the channel ranges in ascending block height order.
1547
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1548
        slices.Sort(blocks)
×
1549

×
1550
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1551
                return BlockChannelRange{
×
1552
                        Height:   block,
×
1553
                        Channels: channelsPerBlock[block],
×
1554
                }
×
1555
        }), nil
×
1556
}
1557

1558
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1559
// zombie. This method is used on an ad-hoc basis, when channels need to be
1560
// marked as zombies outside the normal pruning cycle.
1561
//
1562
// NOTE: part of the V1Store interface.
1563
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1564
        pubKey1, pubKey2 [33]byte) error {
×
1565

×
1566
        ctx := context.TODO()
×
1567

×
1568
        s.cacheMu.Lock()
×
1569
        defer s.cacheMu.Unlock()
×
1570

×
1571
        chanIDB := channelIDToBytes(chanID)
×
1572

×
1573
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1574
                return db.UpsertZombieChannel(
×
1575
                        ctx, sqlc.UpsertZombieChannelParams{
×
1576
                                Version:  int16(ProtocolV1),
×
1577
                                Scid:     chanIDB,
×
1578
                                NodeKey1: pubKey1[:],
×
1579
                                NodeKey2: pubKey2[:],
×
1580
                        },
×
1581
                )
×
1582
        }, sqldb.NoOpReset)
×
1583
        if err != nil {
×
1584
                return fmt.Errorf("unable to upsert zombie channel "+
×
1585
                        "(channel_id=%d): %w", chanID, err)
×
1586
        }
×
1587

1588
        s.rejectCache.remove(chanID)
×
1589
        s.chanCache.remove(chanID)
×
1590

×
1591
        return nil
×
1592
}
1593

1594
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1595
//
1596
// NOTE: part of the V1Store interface.
1597
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1598
        s.cacheMu.Lock()
×
1599
        defer s.cacheMu.Unlock()
×
1600

×
1601
        var (
×
1602
                ctx     = context.TODO()
×
1603
                chanIDB = channelIDToBytes(chanID)
×
1604
        )
×
1605

×
1606
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1607
                res, err := db.DeleteZombieChannel(
×
1608
                        ctx, sqlc.DeleteZombieChannelParams{
×
1609
                                Scid:    chanIDB,
×
1610
                                Version: int16(ProtocolV1),
×
1611
                        },
×
1612
                )
×
1613
                if err != nil {
×
1614
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1615
                                err)
×
1616
                }
×
1617

1618
                rows, err := res.RowsAffected()
×
1619
                if err != nil {
×
1620
                        return err
×
1621
                }
×
1622

1623
                if rows == 0 {
×
1624
                        return ErrZombieEdgeNotFound
×
1625
                } else if rows > 1 {
×
1626
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1627
                                "expected 1", rows)
×
1628
                }
×
1629

1630
                return nil
×
1631
        }, sqldb.NoOpReset)
1632
        if err != nil {
×
1633
                return fmt.Errorf("unable to mark edge live "+
×
1634
                        "(channel_id=%d): %w", chanID, err)
×
1635
        }
×
1636

1637
        s.rejectCache.remove(chanID)
×
1638
        s.chanCache.remove(chanID)
×
1639

×
1640
        return err
×
1641
}
1642

1643
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1644
// zombie, then the two node public keys corresponding to this edge are also
1645
// returned.
1646
//
1647
// NOTE: part of the V1Store interface.
1648
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1649
        error) {
×
1650

×
1651
        var (
×
1652
                ctx              = context.TODO()
×
1653
                isZombie         bool
×
1654
                pubKey1, pubKey2 route.Vertex
×
1655
                chanIDB          = channelIDToBytes(chanID)
×
1656
        )
×
1657

×
1658
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1659
                zombie, err := db.GetZombieChannel(
×
1660
                        ctx, sqlc.GetZombieChannelParams{
×
1661
                                Scid:    chanIDB,
×
1662
                                Version: int16(ProtocolV1),
×
1663
                        },
×
1664
                )
×
1665
                if errors.Is(err, sql.ErrNoRows) {
×
1666
                        return nil
×
1667
                }
×
1668
                if err != nil {
×
1669
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1670
                                err)
×
1671
                }
×
1672

1673
                copy(pubKey1[:], zombie.NodeKey1)
×
1674
                copy(pubKey2[:], zombie.NodeKey2)
×
1675
                isZombie = true
×
1676

×
1677
                return nil
×
1678
        }, sqldb.NoOpReset)
1679
        if err != nil {
×
1680
                return false, route.Vertex{}, route.Vertex{},
×
1681
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1682
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1683
        }
×
1684

1685
        return isZombie, pubKey1, pubKey2, nil
×
1686
}
1687

1688
// NumZombies returns the current number of zombie channels in the graph.
1689
//
1690
// NOTE: part of the V1Store interface.
1691
func (s *SQLStore) NumZombies() (uint64, error) {
×
1692
        var (
×
1693
                ctx        = context.TODO()
×
1694
                numZombies uint64
×
1695
        )
×
1696
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1697
                count, err := db.CountZombieChannels(ctx, int16(ProtocolV1))
×
1698
                if err != nil {
×
1699
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1700
                                err)
×
1701
                }
×
1702

1703
                numZombies = uint64(count)
×
1704

×
1705
                return nil
×
1706
        }, sqldb.NoOpReset)
1707
        if err != nil {
×
1708
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1709
        }
×
1710

1711
        return numZombies, nil
×
1712
}
1713

1714
// DeleteChannelEdges removes edges with the given channel IDs from the
1715
// database and marks them as zombies. This ensures that we're unable to re-add
1716
// it to our database once again. If an edge does not exist within the
1717
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1718
// true, then when we mark these edges as zombies, we'll set up the keys such
1719
// that we require the node that failed to send the fresh update to be the one
1720
// that resurrects the channel from its zombie state. The markZombie bool
1721
// denotes whether to mark the channel as a zombie.
1722
//
1723
// NOTE: part of the V1Store interface.
1724
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1725
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1726

×
1727
        s.cacheMu.Lock()
×
1728
        defer s.cacheMu.Unlock()
×
1729

×
1730
        // Keep track of which channels we end up finding so that we can
×
1731
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1732
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1733
        for _, chanID := range chanIDs {
×
1734
                chanLookup[chanID] = struct{}{}
×
1735
        }
×
1736

1737
        var (
×
1738
                ctx     = context.TODO()
×
1739
                deleted []*models.ChannelEdgeInfo
×
1740
        )
×
1741
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1742
                chanIDsToDelete := make([]int64, 0, len(chanIDs))
×
1743
                chanCallBack := func(ctx context.Context,
×
1744
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1745

×
1746
                        // Deleting the entry from the map indicates that we
×
1747
                        // have found the channel.
×
1748
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1749
                        delete(chanLookup, scid)
×
1750

×
1751
                        node1, node2, err := buildNodeVertices(
×
1752
                                row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1753
                        )
×
1754
                        if err != nil {
×
1755
                                return err
×
1756
                        }
×
1757

1758
                        info, err := getAndBuildEdgeInfo(
×
1759
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1760
                                row.GraphChannel, node1, node2,
×
1761
                        )
×
1762
                        if err != nil {
×
1763
                                return err
×
1764
                        }
×
1765

1766
                        deleted = append(deleted, info)
×
1767
                        chanIDsToDelete = append(
×
1768
                                chanIDsToDelete, row.GraphChannel.ID,
×
1769
                        )
×
1770

×
1771
                        if !markZombie {
×
1772
                                return nil
×
1773
                        }
×
1774

1775
                        nodeKey1, nodeKey2 := info.NodeKey1Bytes,
×
1776
                                info.NodeKey2Bytes
×
1777
                        if strictZombiePruning {
×
1778
                                var e1UpdateTime, e2UpdateTime *time.Time
×
1779
                                if row.Policy1LastUpdate.Valid {
×
1780
                                        e1Time := time.Unix(
×
1781
                                                row.Policy1LastUpdate.Int64, 0,
×
1782
                                        )
×
1783
                                        e1UpdateTime = &e1Time
×
1784
                                }
×
1785
                                if row.Policy2LastUpdate.Valid {
×
1786
                                        e2Time := time.Unix(
×
1787
                                                row.Policy2LastUpdate.Int64, 0,
×
1788
                                        )
×
1789
                                        e2UpdateTime = &e2Time
×
1790
                                }
×
1791

1792
                                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
1793
                                        info, e1UpdateTime, e2UpdateTime,
×
1794
                                )
×
1795
                        }
1796

1797
                        err = db.UpsertZombieChannel(
×
1798
                                ctx, sqlc.UpsertZombieChannelParams{
×
1799
                                        Version:  int16(ProtocolV1),
×
1800
                                        Scid:     channelIDToBytes(scid),
×
1801
                                        NodeKey1: nodeKey1[:],
×
1802
                                        NodeKey2: nodeKey2[:],
×
1803
                                },
×
1804
                        )
×
1805
                        if err != nil {
×
1806
                                return fmt.Errorf("unable to mark channel as "+
×
1807
                                        "zombie: %w", err)
×
1808
                        }
×
1809

1810
                        return nil
×
1811
                }
1812

1813
                err := s.forEachChanWithPoliciesInSCIDList(
×
1814
                        ctx, db, chanCallBack, chanIDs,
×
1815
                )
×
1816
                if err != nil {
×
1817
                        return err
×
1818
                }
×
1819

1820
                if len(chanLookup) > 0 {
×
1821
                        return ErrEdgeNotFound
×
1822
                }
×
1823

1824
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1825
        }, func() {
×
1826
                deleted = nil
×
1827

×
1828
                // Re-fill the lookup map.
×
1829
                for _, chanID := range chanIDs {
×
1830
                        chanLookup[chanID] = struct{}{}
×
1831
                }
×
1832
        })
1833
        if err != nil {
×
1834
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1835
                        err)
×
1836
        }
×
1837

1838
        for _, chanID := range chanIDs {
×
1839
                s.rejectCache.remove(chanID)
×
1840
                s.chanCache.remove(chanID)
×
1841
        }
×
1842

1843
        return deleted, nil
×
1844
}
1845

1846
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1847
// channel identified by the channel ID. If the channel can't be found, then
1848
// ErrEdgeNotFound is returned. A struct which houses the general information
1849
// for the channel itself is returned as well as two structs that contain the
1850
// routing policies for the channel in either direction.
1851
//
1852
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1853
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1854
// the ChannelEdgeInfo will only include the public keys of each node.
1855
//
1856
// NOTE: part of the V1Store interface.
1857
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1858
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1859
        *models.ChannelEdgePolicy, error) {
×
1860

×
1861
        var (
×
1862
                ctx              = context.TODO()
×
1863
                edge             *models.ChannelEdgeInfo
×
1864
                policy1, policy2 *models.ChannelEdgePolicy
×
1865
                chanIDB          = channelIDToBytes(chanID)
×
1866
        )
×
1867
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1868
                row, err := db.GetChannelBySCIDWithPolicies(
×
1869
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1870
                                Scid:    chanIDB,
×
1871
                                Version: int16(ProtocolV1),
×
1872
                        },
×
1873
                )
×
1874
                if errors.Is(err, sql.ErrNoRows) {
×
1875
                        // First check if this edge is perhaps in the zombie
×
1876
                        // index.
×
1877
                        zombie, err := db.GetZombieChannel(
×
1878
                                ctx, sqlc.GetZombieChannelParams{
×
1879
                                        Scid:    chanIDB,
×
1880
                                        Version: int16(ProtocolV1),
×
1881
                                },
×
1882
                        )
×
1883
                        if errors.Is(err, sql.ErrNoRows) {
×
1884
                                return ErrEdgeNotFound
×
1885
                        } else if err != nil {
×
1886
                                return fmt.Errorf("unable to check if "+
×
1887
                                        "channel is zombie: %w", err)
×
1888
                        }
×
1889

1890
                        // At this point, we know the channel is a zombie, so
1891
                        // we'll return an error indicating this, and we will
1892
                        // populate the edge info with the public keys of each
1893
                        // party as this is the only information we have about
1894
                        // it.
1895
                        edge = &models.ChannelEdgeInfo{}
×
1896
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1897
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1898

×
1899
                        return ErrZombieEdge
×
1900
                } else if err != nil {
×
1901
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1902
                }
×
1903

1904
                node1, node2, err := buildNodeVertices(
×
1905
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
1906
                )
×
1907
                if err != nil {
×
1908
                        return err
×
1909
                }
×
1910

1911
                edge, err = getAndBuildEdgeInfo(
×
1912
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1913
                        row.GraphChannel, node1, node2,
×
1914
                )
×
1915
                if err != nil {
×
1916
                        return fmt.Errorf("unable to build channel info: %w",
×
1917
                                err)
×
1918
                }
×
1919

1920
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1921
                if err != nil {
×
1922
                        return fmt.Errorf("unable to extract channel "+
×
1923
                                "policies: %w", err)
×
1924
                }
×
1925

1926
                policy1, policy2, err = getAndBuildChanPolicies(
×
1927
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1928
                )
×
1929
                if err != nil {
×
1930
                        return fmt.Errorf("unable to build channel "+
×
1931
                                "policies: %w", err)
×
1932
                }
×
1933

1934
                return nil
×
1935
        }, sqldb.NoOpReset)
1936
        if err != nil {
×
1937
                // If we are returning the ErrZombieEdge, then we also need to
×
1938
                // return the edge info as the method comment indicates that
×
1939
                // this will be populated when the edge is a zombie.
×
1940
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
1941
                        err)
×
1942
        }
×
1943

1944
        return edge, policy1, policy2, nil
×
1945
}
1946

1947
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1948
// the channel identified by the funding outpoint. If the channel can't be
1949
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1950
// information for the channel itself is returned as well as two structs that
1951
// contain the routing policies for the channel in either direction.
1952
//
1953
// NOTE: part of the V1Store interface.
1954
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1955
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1956
        *models.ChannelEdgePolicy, error) {
×
1957

×
1958
        var (
×
1959
                ctx              = context.TODO()
×
1960
                edge             *models.ChannelEdgeInfo
×
1961
                policy1, policy2 *models.ChannelEdgePolicy
×
1962
        )
×
1963
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1964
                row, err := db.GetChannelByOutpointWithPolicies(
×
1965
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
1966
                                Outpoint: op.String(),
×
1967
                                Version:  int16(ProtocolV1),
×
1968
                        },
×
1969
                )
×
1970
                if errors.Is(err, sql.ErrNoRows) {
×
1971
                        return ErrEdgeNotFound
×
1972
                } else if err != nil {
×
1973
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
1974
                }
×
1975

1976
                node1, node2, err := buildNodeVertices(
×
1977
                        row.Node1Pubkey, row.Node2Pubkey,
×
1978
                )
×
1979
                if err != nil {
×
1980
                        return err
×
1981
                }
×
1982

1983
                edge, err = getAndBuildEdgeInfo(
×
1984
                        ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
1985
                        row.GraphChannel, node1, node2,
×
1986
                )
×
1987
                if err != nil {
×
1988
                        return fmt.Errorf("unable to build channel info: %w",
×
1989
                                err)
×
1990
                }
×
1991

1992
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1993
                if err != nil {
×
1994
                        return fmt.Errorf("unable to extract channel "+
×
1995
                                "policies: %w", err)
×
1996
                }
×
1997

1998
                policy1, policy2, err = getAndBuildChanPolicies(
×
1999
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
2000
                )
×
2001
                if err != nil {
×
2002
                        return fmt.Errorf("unable to build channel "+
×
2003
                                "policies: %w", err)
×
2004
                }
×
2005

2006
                return nil
×
2007
        }, sqldb.NoOpReset)
2008
        if err != nil {
×
2009
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2010
                        err)
×
2011
        }
×
2012

2013
        return edge, policy1, policy2, nil
×
2014
}
2015

2016
// HasChannelEdge returns true if the database knows of a channel edge with the
2017
// passed channel ID, and false otherwise. If an edge with that ID is found
2018
// within the graph, then two time stamps representing the last time the edge
2019
// was updated for both directed edges are returned along with the boolean. If
2020
// it is not found, then the zombie index is checked and its result is returned
2021
// as the second boolean.
2022
//
2023
// NOTE: part of the V1Store interface.
2024
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2025
        bool, error) {
×
2026

×
2027
        ctx := context.TODO()
×
2028

×
2029
        var (
×
2030
                exists          bool
×
2031
                isZombie        bool
×
2032
                node1LastUpdate time.Time
×
2033
                node2LastUpdate time.Time
×
2034
        )
×
2035

×
2036
        // We'll query the cache with the shared lock held to allow multiple
×
2037
        // readers to access values in the cache concurrently if they exist.
×
2038
        s.cacheMu.RLock()
×
2039
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2040
                s.cacheMu.RUnlock()
×
2041
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2042
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2043
                exists, isZombie = entry.flags.unpack()
×
2044

×
2045
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2046
        }
×
2047
        s.cacheMu.RUnlock()
×
2048

×
2049
        s.cacheMu.Lock()
×
2050
        defer s.cacheMu.Unlock()
×
2051

×
2052
        // The item was not found with the shared lock, so we'll acquire the
×
2053
        // exclusive lock and check the cache again in case another method added
×
2054
        // the entry to the cache while no lock was held.
×
2055
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2056
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2057
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2058
                exists, isZombie = entry.flags.unpack()
×
2059

×
2060
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2061
        }
×
2062

2063
        chanIDB := channelIDToBytes(chanID)
×
2064
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2065
                channel, err := db.GetChannelBySCID(
×
2066
                        ctx, sqlc.GetChannelBySCIDParams{
×
2067
                                Scid:    chanIDB,
×
2068
                                Version: int16(ProtocolV1),
×
2069
                        },
×
2070
                )
×
2071
                if errors.Is(err, sql.ErrNoRows) {
×
2072
                        // Check if it is a zombie channel.
×
2073
                        isZombie, err = db.IsZombieChannel(
×
2074
                                ctx, sqlc.IsZombieChannelParams{
×
2075
                                        Scid:    chanIDB,
×
2076
                                        Version: int16(ProtocolV1),
×
2077
                                },
×
2078
                        )
×
2079
                        if err != nil {
×
2080
                                return fmt.Errorf("could not check if channel "+
×
2081
                                        "is zombie: %w", err)
×
2082
                        }
×
2083

2084
                        return nil
×
2085
                } else if err != nil {
×
2086
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2087
                }
×
2088

2089
                exists = true
×
2090

×
2091
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2092
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2093
                                Version:   int16(ProtocolV1),
×
2094
                                ChannelID: channel.ID,
×
2095
                                NodeID:    channel.NodeID1,
×
2096
                        },
×
2097
                )
×
2098
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2099
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2100
                                err)
×
2101
                } else if err == nil {
×
2102
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2103
                }
×
2104

2105
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2106
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2107
                                Version:   int16(ProtocolV1),
×
2108
                                ChannelID: channel.ID,
×
2109
                                NodeID:    channel.NodeID2,
×
2110
                        },
×
2111
                )
×
2112
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2113
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2114
                                err)
×
2115
                } else if err == nil {
×
2116
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2117
                }
×
2118

2119
                return nil
×
2120
        }, sqldb.NoOpReset)
2121
        if err != nil {
×
2122
                return time.Time{}, time.Time{}, false, false,
×
2123
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2124
        }
×
2125

2126
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2127
                upd1Time: node1LastUpdate.Unix(),
×
2128
                upd2Time: node2LastUpdate.Unix(),
×
2129
                flags:    packRejectFlags(exists, isZombie),
×
2130
        })
×
2131

×
2132
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2133
}
2134

2135
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2136
// passed channel point (outpoint). If the passed channel doesn't exist within
2137
// the database, then ErrEdgeNotFound is returned.
2138
//
2139
// NOTE: part of the V1Store interface.
2140
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2141
        var (
×
2142
                ctx       = context.TODO()
×
2143
                channelID uint64
×
2144
        )
×
2145
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2146
                chanID, err := db.GetSCIDByOutpoint(
×
2147
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2148
                                Outpoint: chanPoint.String(),
×
2149
                                Version:  int16(ProtocolV1),
×
2150
                        },
×
2151
                )
×
2152
                if errors.Is(err, sql.ErrNoRows) {
×
2153
                        return ErrEdgeNotFound
×
2154
                } else if err != nil {
×
2155
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2156
                                err)
×
2157
                }
×
2158

2159
                channelID = byteOrder.Uint64(chanID)
×
2160

×
2161
                return nil
×
2162
        }, sqldb.NoOpReset)
2163
        if err != nil {
×
2164
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2165
        }
×
2166

2167
        return channelID, nil
×
2168
}
2169

2170
// IsPublicNode is a helper method that determines whether the node with the
2171
// given public key is seen as a public node in the graph from the graph's
2172
// source node's point of view.
2173
//
2174
// NOTE: part of the V1Store interface.
2175
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2176
        ctx := context.TODO()
×
2177

×
2178
        var isPublic bool
×
2179
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2180
                var err error
×
2181
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2182

×
2183
                return err
×
2184
        }, sqldb.NoOpReset)
×
2185
        if err != nil {
×
2186
                return false, fmt.Errorf("unable to check if node is "+
×
2187
                        "public: %w", err)
×
2188
        }
×
2189

2190
        return isPublic, nil
×
2191
}
2192

2193
// FetchChanInfos returns the set of channel edges that correspond to the passed
2194
// channel ID's. If an edge is the query is unknown to the database, it will
2195
// skipped and the result will contain only those edges that exist at the time
2196
// of the query. This can be used to respond to peer queries that are seeking to
2197
// fill in gaps in their view of the channel graph.
2198
//
2199
// NOTE: part of the V1Store interface.
2200
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2201
        var (
×
2202
                ctx   = context.TODO()
×
2203
                edges = make(map[uint64]ChannelEdge)
×
2204
        )
×
2205
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2206
                chanCallBack := func(ctx context.Context,
×
2207
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2208

×
2209
                        node1, node2, err := buildNodes(
×
2210
                                ctx, db, row.GraphNode, row.GraphNode_2,
×
2211
                        )
×
2212
                        if err != nil {
×
2213
                                return fmt.Errorf("unable to fetch nodes: %w",
×
2214
                                        err)
×
2215
                        }
×
2216

2217
                        edge, err := getAndBuildEdgeInfo(
×
2218
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2219
                                row.GraphChannel, node1.PubKeyBytes,
×
2220
                                node2.PubKeyBytes,
×
2221
                        )
×
2222
                        if err != nil {
×
2223
                                return fmt.Errorf("unable to build "+
×
2224
                                        "channel info: %w", err)
×
2225
                        }
×
2226

2227
                        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2228
                        if err != nil {
×
2229
                                return fmt.Errorf("unable to extract channel "+
×
2230
                                        "policies: %w", err)
×
2231
                        }
×
2232

2233
                        p1, p2, err := getAndBuildChanPolicies(
×
2234
                                ctx, db, dbPol1, dbPol2, edge.ChannelID,
×
2235
                                node1.PubKeyBytes, node2.PubKeyBytes,
×
2236
                        )
×
2237
                        if err != nil {
×
2238
                                return fmt.Errorf("unable to build channel "+
×
2239
                                        "policies: %w", err)
×
2240
                        }
×
2241

2242
                        edges[edge.ChannelID] = ChannelEdge{
×
2243
                                Info:    edge,
×
2244
                                Policy1: p1,
×
2245
                                Policy2: p2,
×
2246
                                Node1:   node1,
×
2247
                                Node2:   node2,
×
2248
                        }
×
2249

×
2250
                        return nil
×
2251
                }
2252

2253
                return s.forEachChanWithPoliciesInSCIDList(
×
2254
                        ctx, db, chanCallBack, chanIDs,
×
2255
                )
×
2256
        }, func() {
×
2257
                clear(edges)
×
2258
        })
×
2259
        if err != nil {
×
2260
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2261
        }
×
2262

2263
        res := make([]ChannelEdge, 0, len(edges))
×
2264
        for _, chanID := range chanIDs {
×
2265
                edge, ok := edges[chanID]
×
2266
                if !ok {
×
2267
                        continue
×
2268
                }
2269

2270
                res = append(res, edge)
×
2271
        }
2272

2273
        return res, nil
×
2274
}
2275

2276
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2277
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2278
// channels in a paginated manner.
2279
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2280
        db SQLQueries, cb func(ctx context.Context,
2281
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2282
        chanIDs []uint64) error {
×
2283

×
2284
        queryWrapper := func(ctx context.Context,
×
2285
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2286
                error) {
×
2287

×
2288
                return db.GetChannelsBySCIDWithPolicies(
×
2289
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2290
                                Version: int16(ProtocolV1),
×
2291
                                Scids:   scids,
×
2292
                        },
×
2293
                )
×
2294
        }
×
2295

2296
        return sqldb.ExecutePagedQuery(
×
2297
                ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes,
×
2298
                queryWrapper, cb,
×
2299
        )
×
2300
}
2301

2302
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2303
// ID's that we don't know and are not known zombies of the passed set. In other
2304
// words, we perform a set difference of our set of chan ID's and the ones
2305
// passed in. This method can be used by callers to determine the set of
2306
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2307
// known zombies is also returned.
2308
//
2309
// NOTE: part of the V1Store interface.
2310
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2311
        []ChannelUpdateInfo, error) {
×
2312

×
2313
        var (
×
2314
                ctx          = context.TODO()
×
2315
                newChanIDs   []uint64
×
2316
                knownZombies []ChannelUpdateInfo
×
2317
                infoLookup   = make(
×
2318
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2319
                )
×
2320
        )
×
2321

×
2322
        // We first build a lookup map of the channel ID's to the
×
2323
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2324
        // already know about.
×
2325
        for _, chanInfo := range chansInfo {
×
2326
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2327
        }
×
2328

2329
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2330
                // The call-back function deletes known channels from
×
2331
                // infoLookup, so that we can later check which channels are
×
2332
                // zombies by only looking at the remaining channels in the set.
×
2333
                cb := func(ctx context.Context,
×
2334
                        channel sqlc.GraphChannel) error {
×
2335

×
2336
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2337

×
2338
                        return nil
×
2339
                }
×
2340

2341
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2342
                if err != nil {
×
2343
                        return fmt.Errorf("unable to iterate through "+
×
2344
                                "channels: %w", err)
×
2345
                }
×
2346

2347
                // We want to ensure that we deal with the channels in the
2348
                // same order that they were passed in, so we iterate over the
2349
                // original chansInfo slice and then check if that channel is
2350
                // still in the infoLookup map.
2351
                for _, chanInfo := range chansInfo {
×
2352
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2353
                        if _, ok := infoLookup[channelID]; !ok {
×
2354
                                continue
×
2355
                        }
2356

2357
                        isZombie, err := db.IsZombieChannel(
×
2358
                                ctx, sqlc.IsZombieChannelParams{
×
2359
                                        Scid:    channelIDToBytes(channelID),
×
2360
                                        Version: int16(ProtocolV1),
×
2361
                                },
×
2362
                        )
×
2363
                        if err != nil {
×
2364
                                return fmt.Errorf("unable to fetch zombie "+
×
2365
                                        "channel: %w", err)
×
2366
                        }
×
2367

2368
                        if isZombie {
×
2369
                                knownZombies = append(knownZombies, chanInfo)
×
2370

×
2371
                                continue
×
2372
                        }
2373

2374
                        newChanIDs = append(newChanIDs, channelID)
×
2375
                }
2376

2377
                return nil
×
2378
        }, func() {
×
2379
                newChanIDs = nil
×
2380
                knownZombies = nil
×
2381
                // Rebuild the infoLookup map in case of a rollback.
×
2382
                for _, chanInfo := range chansInfo {
×
2383
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2384
                        infoLookup[scid] = chanInfo
×
2385
                }
×
2386
        })
2387
        if err != nil {
×
2388
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2389
        }
×
2390

2391
        return newChanIDs, knownZombies, nil
×
2392
}
2393

2394
// forEachChanInSCIDList is a helper method that executes a paged query
2395
// against the database to fetch all channels that match the passed
2396
// ChannelUpdateInfo slice. The callback function is called for each channel
2397
// that is found.
2398
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2399
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2400
        chansInfo []ChannelUpdateInfo) error {
×
2401

×
2402
        queryWrapper := func(ctx context.Context,
×
2403
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2404

×
2405
                return db.GetChannelsBySCIDs(
×
2406
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2407
                                Version: int16(ProtocolV1),
×
2408
                                Scids:   scids,
×
2409
                        },
×
2410
                )
×
2411
        }
×
2412

2413
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2414
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2415

×
2416
                return channelIDToBytes(channelID)
×
2417
        }
×
2418

2419
        return sqldb.ExecutePagedQuery(
×
2420
                ctx, s.cfg.PaginationCfg, chansInfo, chanIDConverter,
×
2421
                queryWrapper, cb,
×
2422
        )
×
2423
}
2424

2425
// PruneGraphNodes is a garbage collection method which attempts to prune out
2426
// any nodes from the channel graph that are currently unconnected. This ensure
2427
// that we only maintain a graph of reachable nodes. In the event that a pruned
2428
// node gains more channels, it will be re-added back to the graph.
2429
//
2430
// NOTE: this prunes nodes across protocol versions. It will never prune the
2431
// source nodes.
2432
//
2433
// NOTE: part of the V1Store interface.
2434
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2435
        var ctx = context.TODO()
×
2436

×
2437
        var prunedNodes []route.Vertex
×
2438
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2439
                var err error
×
2440
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2441

×
2442
                return err
×
2443
        }, func() {
×
2444
                prunedNodes = nil
×
2445
        })
×
2446
        if err != nil {
×
2447
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2448
        }
×
2449

2450
        return prunedNodes, nil
×
2451
}
2452

2453
// PruneGraph prunes newly closed channels from the channel graph in response
2454
// to a new block being solved on the network. Any transactions which spend the
2455
// funding output of any known channels within he graph will be deleted.
2456
// Additionally, the "prune tip", or the last block which has been used to
2457
// prune the graph is stored so callers can ensure the graph is fully in sync
2458
// with the current UTXO state. A slice of channels that have been closed by
2459
// the target block along with any pruned nodes are returned if the function
2460
// succeeds without error.
2461
//
2462
// NOTE: part of the V1Store interface.
2463
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2464
        blockHash *chainhash.Hash, blockHeight uint32) (
2465
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2466

×
2467
        ctx := context.TODO()
×
2468

×
2469
        s.cacheMu.Lock()
×
2470
        defer s.cacheMu.Unlock()
×
2471

×
2472
        var (
×
2473
                closedChans []*models.ChannelEdgeInfo
×
2474
                prunedNodes []route.Vertex
×
2475
        )
×
2476
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2477
                var chansToDelete []int64
×
2478

×
2479
                // Define the callback function for processing each channel.
×
2480
                channelCallback := func(ctx context.Context,
×
2481
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2482

×
2483
                        node1, node2, err := buildNodeVertices(
×
2484
                                row.Node1Pubkey, row.Node2Pubkey,
×
2485
                        )
×
2486
                        if err != nil {
×
2487
                                return err
×
2488
                        }
×
2489

2490
                        info, err := getAndBuildEdgeInfo(
×
2491
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2492
                                row.GraphChannel, node1, node2,
×
2493
                        )
×
2494
                        if err != nil {
×
2495
                                return err
×
2496
                        }
×
2497

2498
                        closedChans = append(closedChans, info)
×
2499
                        chansToDelete = append(
×
2500
                                chansToDelete, row.GraphChannel.ID,
×
2501
                        )
×
2502

×
2503
                        return nil
×
2504
                }
2505

2506
                err := s.forEachChanInOutpoints(
×
2507
                        ctx, db, spentOutputs, channelCallback,
×
2508
                )
×
2509
                if err != nil {
×
2510
                        return fmt.Errorf("unable to fetch channels by "+
×
2511
                                "outpoints: %w", err)
×
2512
                }
×
2513

2514
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2515
                if err != nil {
×
2516
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2517
                }
×
2518

2519
                err = db.UpsertPruneLogEntry(
×
2520
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2521
                                BlockHash:   blockHash[:],
×
2522
                                BlockHeight: int64(blockHeight),
×
2523
                        },
×
2524
                )
×
2525
                if err != nil {
×
2526
                        return fmt.Errorf("unable to insert prune log "+
×
2527
                                "entry: %w", err)
×
2528
                }
×
2529

2530
                // Now that we've pruned some channels, we'll also prune any
2531
                // nodes that no longer have any channels.
2532
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2533
                if err != nil {
×
2534
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2535
                                err)
×
2536
                }
×
2537

2538
                return nil
×
2539
        }, func() {
×
2540
                prunedNodes = nil
×
2541
                closedChans = nil
×
2542
        })
×
2543
        if err != nil {
×
2544
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2545
        }
×
2546

2547
        for _, channel := range closedChans {
×
2548
                s.rejectCache.remove(channel.ChannelID)
×
2549
                s.chanCache.remove(channel.ChannelID)
×
2550
        }
×
2551

2552
        return closedChans, prunedNodes, nil
×
2553
}
2554

2555
// forEachChanInOutpoints is a helper function that executes a paginated
2556
// query to fetch channels by their outpoints and applies the given call-back
2557
// to each.
2558
//
2559
// NOTE: this fetches channels for all protocol versions.
2560
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2561
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2562
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2563

×
2564
        // Create a wrapper that uses the transaction's db instance to execute
×
2565
        // the query.
×
2566
        queryWrapper := func(ctx context.Context,
×
2567
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2568
                error) {
×
2569

×
2570
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2571
        }
×
2572

2573
        // Define the conversion function from Outpoint to string.
2574
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2575
                return outpoint.String()
×
2576
        }
×
2577

2578
        return sqldb.ExecutePagedQuery(
×
2579
                ctx, s.cfg.PaginationCfg, outpoints, outpointToString,
×
2580
                queryWrapper, cb,
×
2581
        )
×
2582
}
2583

2584
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2585
        dbIDs []int64) error {
×
2586

×
2587
        // Create a wrapper that uses the transaction's db instance to execute
×
2588
        // the query.
×
2589
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2590
                return nil, db.DeleteChannels(ctx, ids)
×
2591
        }
×
2592

2593
        idConverter := func(id int64) int64 {
×
2594
                return id
×
2595
        }
×
2596

2597
        return sqldb.ExecutePagedQuery(
×
2598
                ctx, s.cfg.PaginationCfg, dbIDs, idConverter,
×
2599
                queryWrapper, func(ctx context.Context, _ any) error {
×
2600
                        return nil
×
2601
                },
×
2602
        )
2603
}
2604

2605
// ChannelView returns the verifiable edge information for each active channel
2606
// within the known channel graph. The set of UTXOs (along with their scripts)
2607
// returned are the ones that need to be watched on chain to detect channel
2608
// closes on the resident blockchain.
2609
//
2610
// NOTE: part of the V1Store interface.
2611
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2612
        var (
×
2613
                ctx        = context.TODO()
×
2614
                edgePoints []EdgePoint
×
2615
        )
×
2616

×
2617
        handleChannel := func(db SQLQueries,
×
2618
                channel sqlc.ListChannelsPaginatedRow) error {
×
2619

×
2620
                pkScript, err := genMultiSigP2WSH(
×
2621
                        channel.BitcoinKey1, channel.BitcoinKey2,
×
2622
                )
×
2623
                if err != nil {
×
2624
                        return err
×
2625
                }
×
2626

2627
                op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2628
                if err != nil {
×
2629
                        return err
×
2630
                }
×
2631

2632
                edgePoints = append(edgePoints, EdgePoint{
×
2633
                        FundingPkScript: pkScript,
×
2634
                        OutPoint:        *op,
×
2635
                })
×
2636

×
2637
                return nil
×
2638
        }
2639

2640
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2641
                lastID := int64(-1)
×
2642
                for {
×
2643
                        rows, err := db.ListChannelsPaginated(
×
2644
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2645
                                        Version: int16(ProtocolV1),
×
2646
                                        ID:      lastID,
×
2647
                                        Limit:   pageSize,
×
2648
                                },
×
2649
                        )
×
2650
                        if err != nil {
×
2651
                                return err
×
2652
                        }
×
2653

2654
                        if len(rows) == 0 {
×
2655
                                break
×
2656
                        }
2657

2658
                        for _, row := range rows {
×
2659
                                err := handleChannel(db, row)
×
2660
                                if err != nil {
×
2661
                                        return err
×
2662
                                }
×
2663

2664
                                lastID = row.ID
×
2665
                        }
2666
                }
2667

2668
                return nil
×
2669
        }, func() {
×
2670
                edgePoints = nil
×
2671
        })
×
2672
        if err != nil {
×
2673
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2674
        }
×
2675

2676
        return edgePoints, nil
×
2677
}
2678

2679
// PruneTip returns the block height and hash of the latest block that has been
2680
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2681
// to tell if the graph is currently in sync with the current best known UTXO
2682
// state.
2683
//
2684
// NOTE: part of the V1Store interface.
2685
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2686
        var (
×
2687
                ctx       = context.TODO()
×
2688
                tipHash   chainhash.Hash
×
2689
                tipHeight uint32
×
2690
        )
×
2691
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2692
                pruneTip, err := db.GetPruneTip(ctx)
×
2693
                if errors.Is(err, sql.ErrNoRows) {
×
2694
                        return ErrGraphNeverPruned
×
2695
                } else if err != nil {
×
2696
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2697
                }
×
2698

2699
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2700
                tipHeight = uint32(pruneTip.BlockHeight)
×
2701

×
2702
                return nil
×
2703
        }, sqldb.NoOpReset)
2704
        if err != nil {
×
2705
                return nil, 0, err
×
2706
        }
×
2707

2708
        return &tipHash, tipHeight, nil
×
2709
}
2710

2711
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2712
//
2713
// NOTE: this prunes nodes across protocol versions. It will never prune the
2714
// source nodes.
2715
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2716
        db SQLQueries) ([]route.Vertex, error) {
×
2717

×
2718
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2719
        if err != nil {
×
2720
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2721
                        "nodes: %w", err)
×
2722
        }
×
2723

2724
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2725
        for i, nodeKey := range nodeKeys {
×
2726
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2727
                if err != nil {
×
2728
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2729
                                "from bytes: %w", err)
×
2730
                }
×
2731

2732
                prunedNodes[i] = pub
×
2733
        }
2734

2735
        return prunedNodes, nil
×
2736
}
2737

2738
// DisconnectBlockAtHeight is used to indicate that the block specified
2739
// by the passed height has been disconnected from the main chain. This
2740
// will "rewind" the graph back to the height below, deleting channels
2741
// that are no longer confirmed from the graph. The prune log will be
2742
// set to the last prune height valid for the remaining chain.
2743
// Channels that were removed from the graph resulting from the
2744
// disconnected block are returned.
2745
//
2746
// NOTE: part of the V1Store interface.
2747
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2748
        []*models.ChannelEdgeInfo, error) {
×
2749

×
2750
        ctx := context.TODO()
×
2751

×
2752
        var (
×
2753
                // Every channel having a ShortChannelID starting at 'height'
×
2754
                // will no longer be confirmed.
×
2755
                startShortChanID = lnwire.ShortChannelID{
×
2756
                        BlockHeight: height,
×
2757
                }
×
2758

×
2759
                // Delete everything after this height from the db up until the
×
2760
                // SCID alias range.
×
2761
                endShortChanID = aliasmgr.StartingAlias
×
2762

×
2763
                removedChans []*models.ChannelEdgeInfo
×
2764

×
2765
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2766
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2767
        )
×
2768

×
2769
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2770
                rows, err := db.GetChannelsBySCIDRange(
×
2771
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2772
                                StartScid: chanIDStart,
×
2773
                                EndScid:   chanIDEnd,
×
2774
                        },
×
2775
                )
×
2776
                if err != nil {
×
2777
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2778
                }
×
2779

2780
                chanIDsToDelete := make([]int64, len(rows))
×
2781
                for i, row := range rows {
×
2782
                        node1, node2, err := buildNodeVertices(
×
2783
                                row.Node1PubKey, row.Node2PubKey,
×
2784
                        )
×
2785
                        if err != nil {
×
2786
                                return err
×
2787
                        }
×
2788

2789
                        channel, err := getAndBuildEdgeInfo(
×
2790
                                ctx, db, s.cfg.ChainHash, row.GraphChannel.ID,
×
2791
                                row.GraphChannel, node1, node2,
×
2792
                        )
×
2793
                        if err != nil {
×
2794
                                return err
×
2795
                        }
×
2796

2797
                        chanIDsToDelete[i] = row.GraphChannel.ID
×
2798
                        removedChans = append(removedChans, channel)
×
2799
                }
2800

2801
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2802
                if err != nil {
×
2803
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2804
                }
×
2805

2806
                return db.DeletePruneLogEntriesInRange(
×
2807
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2808
                                StartHeight: int64(height),
×
2809
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2810
                        },
×
2811
                )
×
2812
        }, func() {
×
2813
                removedChans = nil
×
2814
        })
×
2815
        if err != nil {
×
2816
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2817
                        "height: %w", err)
×
2818
        }
×
2819

2820
        for _, channel := range removedChans {
×
2821
                s.rejectCache.remove(channel.ChannelID)
×
2822
                s.chanCache.remove(channel.ChannelID)
×
2823
        }
×
2824

2825
        return removedChans, nil
×
2826
}
2827

2828
// AddEdgeProof sets the proof of an existing edge in the graph database.
2829
//
2830
// NOTE: part of the V1Store interface.
2831
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2832
        proof *models.ChannelAuthProof) error {
×
2833

×
2834
        var (
×
2835
                ctx       = context.TODO()
×
2836
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2837
        )
×
2838

×
2839
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2840
                res, err := db.AddV1ChannelProof(
×
2841
                        ctx, sqlc.AddV1ChannelProofParams{
×
2842
                                Scid:              scidBytes,
×
2843
                                Node1Signature:    proof.NodeSig1Bytes,
×
2844
                                Node2Signature:    proof.NodeSig2Bytes,
×
2845
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2846
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2847
                        },
×
2848
                )
×
2849
                if err != nil {
×
2850
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2851
                }
×
2852

2853
                n, err := res.RowsAffected()
×
2854
                if err != nil {
×
2855
                        return err
×
2856
                }
×
2857

2858
                if n == 0 {
×
2859
                        return fmt.Errorf("no rows affected when adding edge "+
×
2860
                                "proof for SCID %v", scid)
×
2861
                } else if n > 1 {
×
2862
                        return fmt.Errorf("multiple rows affected when adding "+
×
2863
                                "edge proof for SCID %v: %d rows affected",
×
2864
                                scid, n)
×
2865
                }
×
2866

2867
                return nil
×
2868
        }, sqldb.NoOpReset)
2869
        if err != nil {
×
2870
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2871
        }
×
2872

2873
        return nil
×
2874
}
2875

2876
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2877
// that we can ignore channel announcements that we know to be closed without
2878
// having to validate them and fetch a block.
2879
//
2880
// NOTE: part of the V1Store interface.
2881
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2882
        var (
×
2883
                ctx     = context.TODO()
×
2884
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2885
        )
×
2886

×
2887
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2888
                return db.InsertClosedChannel(ctx, chanIDB)
×
2889
        }, sqldb.NoOpReset)
×
2890
}
2891

2892
// IsClosedScid checks whether a channel identified by the passed in scid is
2893
// closed. This helps avoid having to perform expensive validation checks.
2894
//
2895
// NOTE: part of the V1Store interface.
2896
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2897
        var (
×
2898
                ctx      = context.TODO()
×
2899
                isClosed bool
×
2900
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2901
        )
×
2902
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2903
                var err error
×
2904
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2905
                if err != nil {
×
2906
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2907
                                err)
×
2908
                }
×
2909

2910
                return nil
×
2911
        }, sqldb.NoOpReset)
2912
        if err != nil {
×
2913
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
2914
                        err)
×
2915
        }
×
2916

2917
        return isClosed, nil
×
2918
}
2919

2920
// GraphSession will provide the call-back with access to a NodeTraverser
2921
// instance which can be used to perform queries against the channel graph.
2922
//
2923
// NOTE: part of the V1Store interface.
2924
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
2925
        reset func()) error {
×
2926

×
2927
        var ctx = context.TODO()
×
2928

×
2929
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2930
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
2931
        }, reset)
×
2932
}
2933

2934
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2935
// read only transaction for a consistent view of the graph.
2936
type sqlNodeTraverser struct {
2937
        db    SQLQueries
2938
        chain chainhash.Hash
2939
}
2940

2941
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2942
// NodeTraverser interface.
2943
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2944

2945
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2946
func newSQLNodeTraverser(db SQLQueries,
2947
        chain chainhash.Hash) *sqlNodeTraverser {
×
2948

×
2949
        return &sqlNodeTraverser{
×
2950
                db:    db,
×
2951
                chain: chain,
×
2952
        }
×
2953
}
×
2954

2955
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2956
// node.
2957
//
2958
// NOTE: Part of the NodeTraverser interface.
2959
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2960
        cb func(channel *DirectedChannel) error, _ func()) error {
×
2961

×
2962
        ctx := context.TODO()
×
2963

×
2964
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
2965
}
×
2966

2967
// FetchNodeFeatures returns the features of the given node. If the node is
2968
// unknown, assume no additional features are supported.
2969
//
2970
// NOTE: Part of the NodeTraverser interface.
2971
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2972
        *lnwire.FeatureVector, error) {
×
2973

×
2974
        ctx := context.TODO()
×
2975

×
2976
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
2977
}
×
2978

2979
// forEachNodeDirectedChannel iterates through all channels of a given
2980
// node, executing the passed callback on the directed edge representing the
2981
// channel and its incoming policy. If the node is not found, no error is
2982
// returned.
2983
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
2984
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
2985

×
2986
        toNodeCallback := func() route.Vertex {
×
2987
                return nodePub
×
2988
        }
×
2989

2990
        dbID, err := db.GetNodeIDByPubKey(
×
2991
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
2992
                        Version: int16(ProtocolV1),
×
2993
                        PubKey:  nodePub[:],
×
2994
                },
×
2995
        )
×
2996
        if errors.Is(err, sql.ErrNoRows) {
×
2997
                return nil
×
2998
        } else if err != nil {
×
2999
                return fmt.Errorf("unable to fetch node: %w", err)
×
3000
        }
×
3001

3002
        rows, err := db.ListChannelsByNodeID(
×
3003
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3004
                        Version: int16(ProtocolV1),
×
3005
                        NodeID1: dbID,
×
3006
                },
×
3007
        )
×
3008
        if err != nil {
×
3009
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3010
        }
×
3011

3012
        // Exit early if there are no channels for this node so we don't
3013
        // do the unnecessary feature fetching.
3014
        if len(rows) == 0 {
×
3015
                return nil
×
3016
        }
×
3017

3018
        features, err := getNodeFeatures(ctx, db, dbID)
×
3019
        if err != nil {
×
3020
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3021
        }
×
3022

3023
        for _, row := range rows {
×
3024
                node1, node2, err := buildNodeVertices(
×
3025
                        row.Node1Pubkey, row.Node2Pubkey,
×
3026
                )
×
3027
                if err != nil {
×
3028
                        return fmt.Errorf("unable to build node vertices: %w",
×
3029
                                err)
×
3030
                }
×
3031

3032
                edge := buildCacheableChannelInfo(
×
3033
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3034
                        node1, node2,
×
3035
                )
×
3036

×
3037
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3038
                if err != nil {
×
3039
                        return err
×
3040
                }
×
3041

3042
                var p1, p2 *models.CachedEdgePolicy
×
3043
                if dbPol1 != nil {
×
3044
                        policy1, err := buildChanPolicy(
×
3045
                                *dbPol1, edge.ChannelID, nil, node2,
×
3046
                        )
×
3047
                        if err != nil {
×
3048
                                return err
×
3049
                        }
×
3050

3051
                        p1 = models.NewCachedPolicy(policy1)
×
3052
                }
3053
                if dbPol2 != nil {
×
3054
                        policy2, err := buildChanPolicy(
×
3055
                                *dbPol2, edge.ChannelID, nil, node1,
×
3056
                        )
×
3057
                        if err != nil {
×
3058
                                return err
×
3059
                        }
×
3060

3061
                        p2 = models.NewCachedPolicy(policy2)
×
3062
                }
3063

3064
                // Determine the outgoing and incoming policy for this
3065
                // channel and node combo.
3066
                outPolicy, inPolicy := p1, p2
×
3067
                if p1 != nil && node2 == nodePub {
×
3068
                        outPolicy, inPolicy = p2, p1
×
3069
                } else if p2 != nil && node1 != nodePub {
×
3070
                        outPolicy, inPolicy = p2, p1
×
3071
                }
×
3072

3073
                var cachedInPolicy *models.CachedEdgePolicy
×
3074
                if inPolicy != nil {
×
3075
                        cachedInPolicy = inPolicy
×
3076
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3077
                        cachedInPolicy.ToNodeFeatures = features
×
3078
                }
×
3079

3080
                directedChannel := &DirectedChannel{
×
3081
                        ChannelID:    edge.ChannelID,
×
3082
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3083
                        OtherNode:    edge.NodeKey2Bytes,
×
3084
                        Capacity:     edge.Capacity,
×
3085
                        OutPolicySet: outPolicy != nil,
×
3086
                        InPolicy:     cachedInPolicy,
×
3087
                }
×
3088
                if outPolicy != nil {
×
3089
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3090
                                directedChannel.InboundFee = fee
×
3091
                        })
×
3092
                }
3093

3094
                if nodePub == edge.NodeKey2Bytes {
×
3095
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3096
                }
×
3097

3098
                if err := cb(directedChannel); err != nil {
×
3099
                        return err
×
3100
                }
×
3101
        }
3102

3103
        return nil
×
3104
}
3105

3106
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3107
// and executes the provided callback for each node.
3108
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
3109
        cb func(nodeID int64, nodePub route.Vertex) error) error {
×
3110

×
3111
        lastID := int64(-1)
×
3112

×
3113
        for {
×
3114
                nodes, err := db.ListNodeIDsAndPubKeys(
×
3115
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3116
                                Version: int16(ProtocolV1),
×
3117
                                ID:      lastID,
×
3118
                                Limit:   pageSize,
×
3119
                        },
×
3120
                )
×
3121
                if err != nil {
×
3122
                        return fmt.Errorf("unable to fetch nodes: %w", err)
×
3123
                }
×
3124

3125
                if len(nodes) == 0 {
×
3126
                        break
×
3127
                }
3128

3129
                for _, node := range nodes {
×
3130
                        var pub route.Vertex
×
3131
                        copy(pub[:], node.PubKey)
×
3132

×
3133
                        if err := cb(node.ID, pub); err != nil {
×
3134
                                return fmt.Errorf("forEachNodeCacheable "+
×
3135
                                        "callback failed for node(id=%d): %w",
×
3136
                                        node.ID, err)
×
3137
                        }
×
3138

3139
                        lastID = node.ID
×
3140
                }
3141
        }
3142

3143
        return nil
×
3144
}
3145

3146
// forEachNodeChannel iterates through all channels of a node, executing
3147
// the passed callback on each. The call-back is provided with the channel's
3148
// edge information, the outgoing policy and the incoming policy for the
3149
// channel and node combo.
3150
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3151
        chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
3152
                *models.ChannelEdgePolicy,
3153
                *models.ChannelEdgePolicy) error) error {
×
3154

×
3155
        // Get all the V1 channels for this node.Add commentMore actions
×
3156
        rows, err := db.ListChannelsByNodeID(
×
3157
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3158
                        Version: int16(ProtocolV1),
×
3159
                        NodeID1: id,
×
3160
                },
×
3161
        )
×
3162
        if err != nil {
×
3163
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3164
        }
×
3165

3166
        // Call the call-back for each channel and its known policies.
3167
        for _, row := range rows {
×
3168
                node1, node2, err := buildNodeVertices(
×
3169
                        row.Node1Pubkey, row.Node2Pubkey,
×
3170
                )
×
3171
                if err != nil {
×
3172
                        return fmt.Errorf("unable to build node vertices: %w",
×
3173
                                err)
×
3174
                }
×
3175

3176
                edge, err := getAndBuildEdgeInfo(
×
3177
                        ctx, db, chain, row.GraphChannel.ID, row.GraphChannel,
×
3178
                        node1, node2,
×
3179
                )
×
3180
                if err != nil {
×
3181
                        return fmt.Errorf("unable to build channel info: %w",
×
3182
                                err)
×
3183
                }
×
3184

3185
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3186
                if err != nil {
×
3187
                        return fmt.Errorf("unable to extract channel "+
×
3188
                                "policies: %w", err)
×
3189
                }
×
3190

3191
                p1, p2, err := getAndBuildChanPolicies(
×
3192
                        ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3193
                )
×
3194
                if err != nil {
×
3195
                        return fmt.Errorf("unable to build channel "+
×
3196
                                "policies: %w", err)
×
3197
                }
×
3198

3199
                // Determine the outgoing and incoming policy for this
3200
                // channel and node combo.
3201
                p1ToNode := row.GraphChannel.NodeID2
×
3202
                p2ToNode := row.GraphChannel.NodeID1
×
3203
                outPolicy, inPolicy := p1, p2
×
3204
                if (p1 != nil && p1ToNode == id) ||
×
3205
                        (p2 != nil && p2ToNode != id) {
×
3206

×
3207
                        outPolicy, inPolicy = p2, p1
×
3208
                }
×
3209

3210
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3211
                        return err
×
3212
                }
×
3213
        }
3214

3215
        return nil
×
3216
}
3217

3218
// updateChanEdgePolicy upserts the channel policy info we have stored for
3219
// a channel we already know of.
3220
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3221
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3222
        error) {
×
3223

×
3224
        var (
×
3225
                node1Pub, node2Pub route.Vertex
×
3226
                isNode1            bool
×
3227
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3228
        )
×
3229

×
3230
        // Check that this edge policy refers to a channel that we already
×
3231
        // know of. We do this explicitly so that we can return the appropriate
×
3232
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3233
        // abort the transaction which would abort the entire batch.
×
3234
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3235
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3236
                        Scid:    chanIDB,
×
3237
                        Version: int16(ProtocolV1),
×
3238
                },
×
3239
        )
×
3240
        if errors.Is(err, sql.ErrNoRows) {
×
3241
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3242
        } else if err != nil {
×
3243
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3244
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3245
        }
×
3246

3247
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3248
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3249

×
3250
        // Figure out which node this edge is from.
×
3251
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3252
        nodeID := dbChan.NodeID1
×
3253
        if !isNode1 {
×
3254
                nodeID = dbChan.NodeID2
×
3255
        }
×
3256

3257
        var (
×
3258
                inboundBase sql.NullInt64
×
3259
                inboundRate sql.NullInt64
×
3260
        )
×
3261
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3262
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3263
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3264
        })
×
3265

3266
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3267
                Version:     int16(ProtocolV1),
×
3268
                ChannelID:   dbChan.ID,
×
3269
                NodeID:      nodeID,
×
3270
                Timelock:    int32(edge.TimeLockDelta),
×
3271
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3272
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3273
                MinHtlcMsat: int64(edge.MinHTLC),
×
3274
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3275
                Disabled: sql.NullBool{
×
3276
                        Valid: true,
×
3277
                        Bool:  edge.IsDisabled(),
×
3278
                },
×
3279
                MaxHtlcMsat: sql.NullInt64{
×
3280
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3281
                        Int64: int64(edge.MaxHTLC),
×
3282
                },
×
3283
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3284
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3285
                InboundBaseFeeMsat:      inboundBase,
×
3286
                InboundFeeRateMilliMsat: inboundRate,
×
3287
                Signature:               edge.SigBytes,
×
3288
        })
×
3289
        if err != nil {
×
3290
                return node1Pub, node2Pub, isNode1,
×
3291
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3292
        }
×
3293

3294
        // Convert the flat extra opaque data into a map of TLV types to
3295
        // values.
3296
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3297
        if err != nil {
×
3298
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3299
                        "marshal extra opaque data: %w", err)
×
3300
        }
×
3301

3302
        // Update the channel policy's extra signed fields.
3303
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3304
        if err != nil {
×
3305
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3306
                        "policy extra TLVs: %w", err)
×
3307
        }
×
3308

3309
        return node1Pub, node2Pub, isNode1, nil
×
3310
}
3311

3312
// getNodeByPubKey attempts to look up a target node by its public key.
3313
func getNodeByPubKey(ctx context.Context, db SQLQueries,
3314
        pubKey route.Vertex) (int64, *models.LightningNode, error) {
×
3315

×
3316
        dbNode, err := db.GetNodeByPubKey(
×
3317
                ctx, sqlc.GetNodeByPubKeyParams{
×
3318
                        Version: int16(ProtocolV1),
×
3319
                        PubKey:  pubKey[:],
×
3320
                },
×
3321
        )
×
3322
        if errors.Is(err, sql.ErrNoRows) {
×
3323
                return 0, nil, ErrGraphNodeNotFound
×
3324
        } else if err != nil {
×
3325
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3326
        }
×
3327

3328
        node, err := buildNode(ctx, db, &dbNode)
×
3329
        if err != nil {
×
3330
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3331
        }
×
3332

3333
        return dbNode.ID, node, nil
×
3334
}
3335

3336
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3337
// provided parameters.
3338
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3339
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3340

×
3341
        return &models.CachedEdgeInfo{
×
3342
                ChannelID:     byteOrder.Uint64(scid),
×
3343
                NodeKey1Bytes: node1Pub,
×
3344
                NodeKey2Bytes: node2Pub,
×
3345
                Capacity:      btcutil.Amount(capacity),
×
3346
        }
×
3347
}
×
3348

3349
// buildNode constructs a LightningNode instance from the given database node
3350
// record. The node's features, addresses and extra signed fields are also
3351
// fetched from the database and set on the node.
3352
func buildNode(ctx context.Context, db SQLQueries,
NEW
3353
        dbNode *sqlc.GraphNode) (*models.LightningNode, error) {
×
NEW
3354

×
NEW
3355
        // NOTE: buildNode is only used to load the data for a single node, and
×
NEW
3356
        // so no paged queries will be performed. This means that it's ok to
×
NEW
3357
        // used pass in default config values here.
×
NEW
3358
        cfg := sqldb.DefaultPagedQueryConfig()
×
NEW
3359

×
NEW
3360
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
NEW
3361
        if err != nil {
×
NEW
3362
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
NEW
3363
                        err)
×
NEW
3364
        }
×
3365

NEW
3366
        return buildNodeWithBatchData(dbNode, data)
×
3367
}
3368

3369
// buildNodeWithBatchData builds a models.LightningNode instance
3370
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3371
// features/addresses/extra fields, then the corresponding fields are expected
3372
// to be present in the batchNodeData.
3373
func buildNodeWithBatchData(dbNode *sqlc.GraphNode,
NEW
3374
        batchData *batchNodeData) (*models.LightningNode, error) {
×
3375

×
3376
        if dbNode.Version != int16(ProtocolV1) {
×
3377
                return nil, fmt.Errorf("unsupported node version: %d",
×
3378
                        dbNode.Version)
×
3379
        }
×
3380

3381
        var pub [33]byte
×
3382
        copy(pub[:], dbNode.PubKey)
×
3383

×
3384
        node := &models.LightningNode{
×
3385
                PubKeyBytes: pub,
×
3386
                Features:    lnwire.EmptyFeatureVector(),
×
3387
                LastUpdate:  time.Unix(0, 0),
×
3388
        }
×
3389

×
3390
        if len(dbNode.Signature) == 0 {
×
3391
                return node, nil
×
3392
        }
×
3393

3394
        node.HaveNodeAnnouncement = true
×
3395
        node.AuthSigBytes = dbNode.Signature
×
3396
        node.Alias = dbNode.Alias.String
×
3397
        node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3398

×
3399
        var err error
×
3400
        if dbNode.Color.Valid {
×
3401
                node.Color, err = DecodeHexColor(dbNode.Color.String)
×
3402
                if err != nil {
×
3403
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3404
                                err)
×
3405
                }
×
3406
        }
3407

3408
        // Use preloaded features.
NEW
3409
        if features, exists := batchData.features[dbNode.ID]; exists {
×
NEW
3410
                fv := lnwire.EmptyFeatureVector()
×
NEW
3411
                for _, bit := range features {
×
NEW
3412
                        fv.Set(lnwire.FeatureBit(bit))
×
NEW
3413
                }
×
NEW
3414
                node.Features = fv
×
3415
        }
3416

3417
        // Use preloaded addresses.
NEW
3418
        addresses, exists := batchData.addresses[dbNode.ID]
×
NEW
3419
        if exists && len(addresses) > 0 {
×
NEW
3420
                node.Addresses, err = buildNodeAddresses(addresses)
×
NEW
3421
                if err != nil {
×
NEW
3422
                        return nil, fmt.Errorf("unable to build addresses "+
×
NEW
3423
                                "for node(%d): %w", dbNode.ID, err)
×
NEW
3424
                }
×
3425
        }
3426

3427
        // Use preloaded extra fields.
NEW
3428
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
NEW
3429
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
NEW
3430
                if err != nil {
×
NEW
3431
                        return nil, fmt.Errorf("unable to serialize extra "+
×
NEW
3432
                                "signed fields: %w", err)
×
NEW
3433
                }
×
NEW
3434
                if len(recs) != 0 {
×
NEW
3435
                        node.ExtraOpaqueData = recs
×
NEW
3436
                }
×
3437
        }
3438

NEW
3439
        return node, nil
×
3440
}
3441

3442
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3443
// with the preloaded data, and executes the provided callback for each node.
3444
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.PagedQueryConfig,
3445
        db SQLQueries, nodes []sqlc.GraphNode,
NEW
3446
        cb func(dbID int64, node *models.LightningNode) error) error {
×
NEW
3447

×
NEW
3448
        // Extract node IDs for batch loading.
×
NEW
3449
        nodeIDs := make([]int64, len(nodes))
×
NEW
3450
        for i, node := range nodes {
×
NEW
3451
                nodeIDs[i] = node.ID
×
NEW
3452
        }
×
3453

3454
        // Batch load all related data for this page.
NEW
3455
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3456
        if err != nil {
×
NEW
3457
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3458
        }
×
3459

NEW
3460
        for _, dbNode := range nodes {
×
NEW
3461
                node, err := buildNodeWithBatchData(&dbNode, batchData)
×
NEW
3462
                if err != nil {
×
NEW
3463
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
NEW
3464
                                dbNode.ID, err)
×
NEW
3465
                }
×
3466

NEW
3467
                if err := cb(dbNode.ID, node); err != nil {
×
NEW
3468
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
NEW
3469
                                dbNode.ID, err)
×
NEW
3470
                }
×
3471
        }
3472

NEW
3473
        return nil
×
3474
}
3475

3476
// getNodeFeatures fetches the feature bits and constructs the feature vector
3477
// for a node with the given DB ID.
3478
func getNodeFeatures(ctx context.Context, db SQLQueries,
3479
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3480

×
3481
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3482
        if err != nil {
×
3483
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3484
                        nodeID, err)
×
3485
        }
×
3486

3487
        features := lnwire.EmptyFeatureVector()
×
3488
        for _, feature := range rows {
×
3489
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3490
        }
×
3491

3492
        return features, nil
×
3493
}
3494

3495
// upsertNode upserts the node record into the database. If the node already
3496
// exists, then the node's information is updated. If the node doesn't exist,
3497
// then a new node is created. The node's features, addresses and extra TLV
3498
// types are also updated. The node's DB ID is returned.
3499
func upsertNode(ctx context.Context, db SQLQueries,
3500
        node *models.LightningNode) (int64, error) {
×
3501

×
3502
        params := sqlc.UpsertNodeParams{
×
3503
                Version: int16(ProtocolV1),
×
3504
                PubKey:  node.PubKeyBytes[:],
×
3505
        }
×
3506

×
3507
        if node.HaveNodeAnnouncement {
×
3508
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
3509
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
×
3510
                params.Alias = sqldb.SQLStr(node.Alias)
×
3511
                params.Signature = node.AuthSigBytes
×
3512
        }
×
3513

3514
        nodeID, err := db.UpsertNode(ctx, params)
×
3515
        if err != nil {
×
3516
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3517
                        err)
×
3518
        }
×
3519

3520
        // We can exit here if we don't have the announcement yet.
3521
        if !node.HaveNodeAnnouncement {
×
3522
                return nodeID, nil
×
3523
        }
×
3524

3525
        // Update the node's features.
3526
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3527
        if err != nil {
×
3528
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3529
        }
×
3530

3531
        // Update the node's addresses.
3532
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3533
        if err != nil {
×
3534
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3535
        }
×
3536

3537
        // Convert the flat extra opaque data into a map of TLV types to
3538
        // values.
3539
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3540
        if err != nil {
×
3541
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3542
                        err)
×
3543
        }
×
3544

3545
        // Update the node's extra signed fields.
3546
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3547
        if err != nil {
×
3548
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3549
        }
×
3550

3551
        return nodeID, nil
×
3552
}
3553

3554
// upsertNodeFeatures updates the node's features node_features table. This
3555
// includes deleting any feature bits no longer present and inserting any new
3556
// feature bits. If the feature bit does not yet exist in the features table,
3557
// then an entry is created in that table first.
3558
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3559
        features *lnwire.FeatureVector) error {
×
3560

×
3561
        // Get any existing features for the node.
×
3562
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3563
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3564
                return err
×
3565
        }
×
3566

3567
        // Copy the nodes latest set of feature bits.
3568
        newFeatures := make(map[int32]struct{})
×
3569
        if features != nil {
×
3570
                for feature := range features.Features() {
×
3571
                        newFeatures[int32(feature)] = struct{}{}
×
3572
                }
×
3573
        }
3574

3575
        // For any current feature that already exists in the DB, remove it from
3576
        // the in-memory map. For any existing feature that does not exist in
3577
        // the in-memory map, delete it from the database.
3578
        for _, feature := range existingFeatures {
×
3579
                // The feature is still present, so there are no updates to be
×
3580
                // made.
×
3581
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3582
                        delete(newFeatures, feature.FeatureBit)
×
3583
                        continue
×
3584
                }
3585

3586
                // The feature is no longer present, so we remove it from the
3587
                // database.
3588
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3589
                        NodeID:     nodeID,
×
3590
                        FeatureBit: feature.FeatureBit,
×
3591
                })
×
3592
                if err != nil {
×
3593
                        return fmt.Errorf("unable to delete node(%d) "+
×
3594
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3595
                                err)
×
3596
                }
×
3597
        }
3598

3599
        // Any remaining entries in newFeatures are new features that need to be
3600
        // added to the database for the first time.
3601
        for feature := range newFeatures {
×
3602
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3603
                        NodeID:     nodeID,
×
3604
                        FeatureBit: feature,
×
3605
                })
×
3606
                if err != nil {
×
3607
                        return fmt.Errorf("unable to insert node(%d) "+
×
3608
                                "feature(%v): %w", nodeID, feature, err)
×
3609
                }
×
3610
        }
3611

3612
        return nil
×
3613
}
3614

3615
// fetchNodeFeatures fetches the features for a node with the given public key.
3616
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3617
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3618

×
3619
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3620
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3621
                        PubKey:  nodePub[:],
×
3622
                        Version: int16(ProtocolV1),
×
3623
                },
×
3624
        )
×
3625
        if err != nil {
×
3626
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3627
                        nodePub, err)
×
3628
        }
×
3629

3630
        features := lnwire.EmptyFeatureVector()
×
3631
        for _, bit := range rows {
×
3632
                features.Set(lnwire.FeatureBit(bit))
×
3633
        }
×
3634

3635
        return features, nil
×
3636
}
3637

3638
// dbAddressType is an enum type that represents the different address types
3639
// that we store in the node_addresses table. The address type determines how
3640
// the address is to be serialised/deserialize.
3641
type dbAddressType uint8
3642

3643
const (
3644
        addressTypeIPv4   dbAddressType = 1
3645
        addressTypeIPv6   dbAddressType = 2
3646
        addressTypeTorV2  dbAddressType = 3
3647
        addressTypeTorV3  dbAddressType = 4
3648
        addressTypeOpaque dbAddressType = math.MaxInt8
3649
)
3650

3651
// upsertNodeAddresses updates the node's addresses in the database. This
3652
// includes deleting any existing addresses and inserting the new set of
3653
// addresses. The deletion is necessary since the ordering of the addresses may
3654
// change, and we need to ensure that the database reflects the latest set of
3655
// addresses so that at the time of reconstructing the node announcement, the
3656
// order is preserved and the signature over the message remains valid.
3657
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3658
        addresses []net.Addr) error {
×
3659

×
3660
        // Delete any existing addresses for the node. This is required since
×
3661
        // even if the new set of addresses is the same, the ordering may have
×
3662
        // changed for a given address type.
×
3663
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3664
        if err != nil {
×
3665
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3666
                        nodeID, err)
×
3667
        }
×
3668

3669
        // Copy the nodes latest set of addresses.
3670
        newAddresses := map[dbAddressType][]string{
×
3671
                addressTypeIPv4:   {},
×
3672
                addressTypeIPv6:   {},
×
3673
                addressTypeTorV2:  {},
×
3674
                addressTypeTorV3:  {},
×
3675
                addressTypeOpaque: {},
×
3676
        }
×
3677
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3678
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3679
        }
×
3680

3681
        for _, address := range addresses {
×
3682
                switch addr := address.(type) {
×
3683
                case *net.TCPAddr:
×
3684
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3685
                                addAddr(addressTypeIPv4, addr)
×
3686
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3687
                                addAddr(addressTypeIPv6, addr)
×
3688
                        } else {
×
3689
                                return fmt.Errorf("unhandled IP address: %v",
×
3690
                                        addr)
×
3691
                        }
×
3692

3693
                case *tor.OnionAddr:
×
3694
                        switch len(addr.OnionService) {
×
3695
                        case tor.V2Len:
×
3696
                                addAddr(addressTypeTorV2, addr)
×
3697
                        case tor.V3Len:
×
3698
                                addAddr(addressTypeTorV3, addr)
×
3699
                        default:
×
3700
                                return fmt.Errorf("invalid length for a tor " +
×
3701
                                        "address")
×
3702
                        }
3703

3704
                case *lnwire.OpaqueAddrs:
×
3705
                        addAddr(addressTypeOpaque, addr)
×
3706

3707
                default:
×
3708
                        return fmt.Errorf("unhandled address type: %T", addr)
×
3709
                }
3710
        }
3711

3712
        // Any remaining entries in newAddresses are new addresses that need to
3713
        // be added to the database for the first time.
3714
        for addrType, addrList := range newAddresses {
×
3715
                for position, addr := range addrList {
×
3716
                        err := db.InsertNodeAddress(
×
3717
                                ctx, sqlc.InsertNodeAddressParams{
×
3718
                                        NodeID:   nodeID,
×
3719
                                        Type:     int16(addrType),
×
3720
                                        Address:  addr,
×
3721
                                        Position: int32(position),
×
3722
                                },
×
3723
                        )
×
3724
                        if err != nil {
×
3725
                                return fmt.Errorf("unable to insert "+
×
3726
                                        "node(%d) address(%v): %w", nodeID,
×
3727
                                        addr, err)
×
3728
                        }
×
3729
                }
3730
        }
3731

3732
        return nil
×
3733
}
3734

3735
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3736
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3737
        error) {
×
3738

×
3739
        // GetNodeAddresses ensures that the addresses for a given type are
×
3740
        // returned in the same order as they were inserted.
×
3741
        rows, err := db.GetNodeAddresses(ctx, id)
×
3742
        if err != nil {
×
3743
                return nil, err
×
3744
        }
×
3745

3746
        addresses := make([]net.Addr, 0, len(rows))
×
3747
        for _, row := range rows {
×
3748
                address := row.Address
×
3749

×
NEW
3750
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
NEW
3751
                if err != nil {
×
NEW
3752
                        return nil, fmt.Errorf("unable to parse address "+
×
NEW
3753
                                "for node(%d): %v: %w", id, address, err)
×
UNCOV
3754
                }
×
3755

NEW
3756
                addresses = append(addresses, addr)
×
3757
        }
3758

3759
        // If we have no addresses, then we'll return nil instead of an
3760
        // empty slice.
3761
        if len(addresses) == 0 {
×
3762
                addresses = nil
×
3763
        }
×
3764

3765
        return addresses, nil
×
3766
}
3767

3768
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3769
// database. This includes updating any existing types, inserting any new types,
3770
// and deleting any types that are no longer present.
3771
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3772
        nodeID int64, extraFields map[uint64][]byte) error {
×
3773

×
3774
        // Get any existing extra signed fields for the node.
×
3775
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3776
        if err != nil {
×
3777
                return err
×
3778
        }
×
3779

3780
        // Make a lookup map of the existing field types so that we can use it
3781
        // to keep track of any fields we should delete.
3782
        m := make(map[uint64]bool)
×
3783
        for _, field := range existingFields {
×
3784
                m[uint64(field.Type)] = true
×
3785
        }
×
3786

3787
        // For all the new fields, we'll upsert them and remove them from the
3788
        // map of existing fields.
3789
        for tlvType, value := range extraFields {
×
3790
                err = db.UpsertNodeExtraType(
×
3791
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3792
                                NodeID: nodeID,
×
3793
                                Type:   int64(tlvType),
×
3794
                                Value:  value,
×
3795
                        },
×
3796
                )
×
3797
                if err != nil {
×
3798
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3799
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3800
                }
×
3801

3802
                // Remove the field from the map of existing fields if it was
3803
                // present.
3804
                delete(m, tlvType)
×
3805
        }
3806

3807
        // For all the fields that are left in the map of existing fields, we'll
3808
        // delete them as they are no longer present in the new set of fields.
3809
        for tlvType := range m {
×
3810
                err = db.DeleteExtraNodeType(
×
3811
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3812
                                NodeID: nodeID,
×
3813
                                Type:   int64(tlvType),
×
3814
                        },
×
3815
                )
×
3816
                if err != nil {
×
3817
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3818
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3819
                }
×
3820
        }
3821

3822
        return nil
×
3823
}
3824

3825
// srcNodeInfo holds the information about the source node of the graph.
3826
type srcNodeInfo struct {
3827
        // id is the DB level ID of the source node entry in the "nodes" table.
3828
        id int64
3829

3830
        // pub is the public key of the source node.
3831
        pub route.Vertex
3832
}
3833

3834
// sourceNode returns the DB node ID and pub key of the source node for the
3835
// specified protocol version.
3836
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3837
        version ProtocolVersion) (int64, route.Vertex, error) {
×
3838

×
3839
        s.srcNodeMu.Lock()
×
3840
        defer s.srcNodeMu.Unlock()
×
3841

×
3842
        // If we already have the source node ID and pub key cached, then
×
3843
        // return them.
×
3844
        if info, ok := s.srcNodes[version]; ok {
×
3845
                return info.id, info.pub, nil
×
3846
        }
×
3847

3848
        var pubKey route.Vertex
×
3849

×
3850
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3851
        if err != nil {
×
3852
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3853
                        err)
×
3854
        }
×
3855

3856
        if len(nodes) == 0 {
×
3857
                return 0, pubKey, ErrSourceNodeNotSet
×
3858
        } else if len(nodes) > 1 {
×
3859
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
3860
                        "protocol %s found", version)
×
3861
        }
×
3862

3863
        copy(pubKey[:], nodes[0].PubKey)
×
3864

×
3865
        s.srcNodes[version] = &srcNodeInfo{
×
3866
                id:  nodes[0].NodeID,
×
3867
                pub: pubKey,
×
3868
        }
×
3869

×
3870
        return nodes[0].NodeID, pubKey, nil
×
3871
}
3872

3873
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
3874
// This then produces a map from TLV type to value. If the input is not a
3875
// valid TLV stream, then an error is returned.
3876
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
3877
        r := bytes.NewReader(data)
×
3878

×
3879
        tlvStream, err := tlv.NewStream()
×
3880
        if err != nil {
×
3881
                return nil, err
×
3882
        }
×
3883

3884
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
3885
        // pass it into the P2P decoding variant.
3886
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
3887
        if err != nil {
×
3888
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
3889
        }
×
3890
        if len(parsedTypes) == 0 {
×
3891
                return nil, nil
×
3892
        }
×
3893

3894
        records := make(map[uint64][]byte)
×
3895
        for k, v := range parsedTypes {
×
3896
                records[uint64(k)] = v
×
3897
        }
×
3898

3899
        return records, nil
×
3900
}
3901

3902
// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
3903
// channel.
3904
type dbChanInfo struct {
3905
        channelID int64
3906
        node1ID   int64
3907
        node2ID   int64
3908
}
3909

3910
// insertChannel inserts a new channel record into the database.
3911
func insertChannel(ctx context.Context, db SQLQueries,
3912
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
×
3913

×
3914
        chanIDB := channelIDToBytes(edge.ChannelID)
×
3915

×
3916
        // Make sure that the channel doesn't already exist. We do this
×
3917
        // explicitly instead of relying on catching a unique constraint error
×
3918
        // because relying on SQL to throw that error would abort the entire
×
3919
        // batch of transactions.
×
3920
        _, err := db.GetChannelBySCID(
×
3921
                ctx, sqlc.GetChannelBySCIDParams{
×
3922
                        Scid:    chanIDB,
×
3923
                        Version: int16(ProtocolV1),
×
3924
                },
×
3925
        )
×
3926
        if err == nil {
×
3927
                return nil, ErrEdgeAlreadyExist
×
3928
        } else if !errors.Is(err, sql.ErrNoRows) {
×
3929
                return nil, fmt.Errorf("unable to fetch channel: %w", err)
×
3930
        }
×
3931

3932
        // Make sure that at least a "shell" entry for each node is present in
3933
        // the nodes table.
3934
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
3935
        if err != nil {
×
3936
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3937
        }
×
3938

3939
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
3940
        if err != nil {
×
3941
                return nil, fmt.Errorf("unable to create shell node: %w", err)
×
3942
        }
×
3943

3944
        var capacity sql.NullInt64
×
3945
        if edge.Capacity != 0 {
×
3946
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
3947
        }
×
3948

3949
        createParams := sqlc.CreateChannelParams{
×
3950
                Version:     int16(ProtocolV1),
×
3951
                Scid:        chanIDB,
×
3952
                NodeID1:     node1DBID,
×
3953
                NodeID2:     node2DBID,
×
3954
                Outpoint:    edge.ChannelPoint.String(),
×
3955
                Capacity:    capacity,
×
3956
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
3957
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
3958
        }
×
3959

×
3960
        if edge.AuthProof != nil {
×
3961
                proof := edge.AuthProof
×
3962

×
3963
                createParams.Node1Signature = proof.NodeSig1Bytes
×
3964
                createParams.Node2Signature = proof.NodeSig2Bytes
×
3965
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
3966
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
3967
        }
×
3968

3969
        // Insert the new channel record.
3970
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
3971
        if err != nil {
×
3972
                return nil, err
×
3973
        }
×
3974

3975
        // Insert any channel features.
3976
        for feature := range edge.Features.Features() {
×
3977
                err = db.InsertChannelFeature(
×
3978
                        ctx, sqlc.InsertChannelFeatureParams{
×
3979
                                ChannelID:  dbChanID,
×
3980
                                FeatureBit: int32(feature),
×
3981
                        },
×
3982
                )
×
3983
                if err != nil {
×
3984
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
×
3985
                                "feature(%v): %w", dbChanID, feature, err)
×
3986
                }
×
3987
        }
3988

3989
        // Finally, insert any extra TLV fields in the channel announcement.
3990
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3991
        if err != nil {
×
3992
                return nil, fmt.Errorf("unable to marshal extra opaque "+
×
3993
                        "data: %w", err)
×
3994
        }
×
3995

3996
        for tlvType, value := range extra {
×
3997
                err := db.CreateChannelExtraType(
×
3998
                        ctx, sqlc.CreateChannelExtraTypeParams{
×
3999
                                ChannelID: dbChanID,
×
4000
                                Type:      int64(tlvType),
×
4001
                                Value:     value,
×
4002
                        },
×
4003
                )
×
4004
                if err != nil {
×
4005
                        return nil, fmt.Errorf("unable to upsert "+
×
4006
                                "channel(%d) extra signed field(%v): %w",
×
4007
                                edge.ChannelID, tlvType, err)
×
4008
                }
×
4009
        }
4010

4011
        return &dbChanInfo{
×
4012
                channelID: dbChanID,
×
4013
                node1ID:   node1DBID,
×
4014
                node2ID:   node2DBID,
×
4015
        }, nil
×
4016
}
4017

4018
// maybeCreateShellNode checks if a shell node entry exists for the
4019
// given public key. If it does not exist, then a new shell node entry is
4020
// created. The ID of the node is returned. A shell node only has a protocol
4021
// version and public key persisted.
4022
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4023
        pubKey route.Vertex) (int64, error) {
×
4024

×
4025
        dbNode, err := db.GetNodeByPubKey(
×
4026
                ctx, sqlc.GetNodeByPubKeyParams{
×
4027
                        PubKey:  pubKey[:],
×
4028
                        Version: int16(ProtocolV1),
×
4029
                },
×
4030
        )
×
4031
        // The node exists. Return the ID.
×
4032
        if err == nil {
×
4033
                return dbNode.ID, nil
×
4034
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4035
                return 0, err
×
4036
        }
×
4037

4038
        // Otherwise, the node does not exist, so we create a shell entry for
4039
        // it.
4040
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4041
                Version: int16(ProtocolV1),
×
4042
                PubKey:  pubKey[:],
×
4043
        })
×
4044
        if err != nil {
×
4045
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4046
        }
×
4047

4048
        return id, nil
×
4049
}
4050

4051
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4052
// the database. This includes deleting any existing types and then inserting
4053
// the new types.
4054
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4055
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4056

×
4057
        // Delete all existing extra signed fields for the channel policy.
×
4058
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4059
        if err != nil {
×
4060
                return fmt.Errorf("unable to delete "+
×
4061
                        "existing policy extra signed fields for policy %d: %w",
×
4062
                        chanPolicyID, err)
×
4063
        }
×
4064

4065
        // Insert all new extra signed fields for the channel policy.
4066
        for tlvType, value := range extraFields {
×
4067
                err = db.InsertChanPolicyExtraType(
×
4068
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
×
4069
                                ChannelPolicyID: chanPolicyID,
×
4070
                                Type:            int64(tlvType),
×
4071
                                Value:           value,
×
4072
                        },
×
4073
                )
×
4074
                if err != nil {
×
4075
                        return fmt.Errorf("unable to insert "+
×
4076
                                "channel_policy(%d) extra signed field(%v): %w",
×
4077
                                chanPolicyID, tlvType, err)
×
4078
                }
×
4079
        }
4080

4081
        return nil
×
4082
}
4083

4084
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4085
// provided dbChanRow and also fetches any other required information
4086
// to construct the edge info.
4087
func getAndBuildEdgeInfo(ctx context.Context, db SQLQueries,
4088
        chain chainhash.Hash, dbChanID int64, dbChan sqlc.GraphChannel, node1,
4089
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4090

×
4091
        if dbChan.Version != int16(ProtocolV1) {
×
4092
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4093
                        dbChan.Version)
×
4094
        }
×
4095

4096
        fv, extras, err := getChanFeaturesAndExtras(
×
4097
                ctx, db, dbChanID,
×
4098
        )
×
4099
        if err != nil {
×
4100
                return nil, err
×
4101
        }
×
4102

4103
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4104
        if err != nil {
×
4105
                return nil, err
×
4106
        }
×
4107

4108
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4109
        if err != nil {
×
4110
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4111
                        "fields: %w", err)
×
4112
        }
×
4113
        if recs == nil {
×
4114
                recs = make([]byte, 0)
×
4115
        }
×
4116

4117
        var btcKey1, btcKey2 route.Vertex
×
4118
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4119
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4120

×
4121
        channel := &models.ChannelEdgeInfo{
×
4122
                ChainHash:        chain,
×
4123
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4124
                NodeKey1Bytes:    node1,
×
4125
                NodeKey2Bytes:    node2,
×
4126
                BitcoinKey1Bytes: btcKey1,
×
4127
                BitcoinKey2Bytes: btcKey2,
×
4128
                ChannelPoint:     *op,
×
4129
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4130
                Features:         fv,
×
4131
                ExtraOpaqueData:  recs,
×
4132
        }
×
4133

×
4134
        // We always set all the signatures at the same time, so we can
×
4135
        // safely check if one signature is present to determine if we have the
×
4136
        // rest of the signatures for the auth proof.
×
4137
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4138
                channel.AuthProof = &models.ChannelAuthProof{
×
4139
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4140
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4141
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4142
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4143
                }
×
4144
        }
×
4145

4146
        return channel, nil
×
4147
}
4148

4149
// buildNodeVertices is a helper that converts raw node public keys
4150
// into route.Vertex instances.
4151
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4152
        route.Vertex, error) {
×
4153

×
4154
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4155
        if err != nil {
×
4156
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4157
                        "create vertex from node1 pubkey: %w", err)
×
4158
        }
×
4159

4160
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4161
        if err != nil {
×
4162
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4163
                        "create vertex from node2 pubkey: %w", err)
×
4164
        }
×
4165

4166
        return node1Vertex, node2Vertex, nil
×
4167
}
4168

4169
// getChanFeaturesAndExtras fetches the channel features and extra TLV types
4170
// for a channel with the given ID.
4171
func getChanFeaturesAndExtras(ctx context.Context, db SQLQueries,
4172
        id int64) (*lnwire.FeatureVector, map[uint64][]byte, error) {
×
4173

×
4174
        rows, err := db.GetChannelFeaturesAndExtras(ctx, id)
×
4175
        if err != nil {
×
4176
                return nil, nil, fmt.Errorf("unable to fetch channel "+
×
4177
                        "features and extras: %w", err)
×
4178
        }
×
4179

4180
        var (
×
4181
                fv     = lnwire.EmptyFeatureVector()
×
4182
                extras = make(map[uint64][]byte)
×
4183
        )
×
4184
        for _, row := range rows {
×
4185
                if row.IsFeature {
×
4186
                        fv.Set(lnwire.FeatureBit(row.FeatureBit))
×
4187

×
4188
                        continue
×
4189
                }
4190

4191
                tlvType, ok := row.ExtraKey.(int64)
×
4192
                if !ok {
×
4193
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4194
                                "TLV type: %T", row.ExtraKey)
×
4195
                }
×
4196

4197
                valueBytes, ok := row.Value.([]byte)
×
4198
                if !ok {
×
4199
                        return nil, nil, fmt.Errorf("unexpected type for "+
×
4200
                                "Value: %T", row.Value)
×
4201
                }
×
4202

4203
                extras[uint64(tlvType)] = valueBytes
×
4204
        }
4205

4206
        return fv, extras, nil
×
4207
}
4208

4209
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4210
// retrieves all the extra info required to build the complete
4211
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4212
// the provided sqlc.GraphChannelPolicy records are nil.
4213
func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
4214
        dbPol1, dbPol2 *sqlc.GraphChannelPolicy, channelID uint64, node1,
4215
        node2 route.Vertex) (*models.ChannelEdgePolicy,
4216
        *models.ChannelEdgePolicy, error) {
×
4217

×
4218
        if dbPol1 == nil && dbPol2 == nil {
×
4219
                return nil, nil, nil
×
4220
        }
×
4221

4222
        var (
×
4223
                policy1ID int64
×
4224
                policy2ID int64
×
4225
        )
×
4226
        if dbPol1 != nil {
×
4227
                policy1ID = dbPol1.ID
×
4228
        }
×
4229
        if dbPol2 != nil {
×
4230
                policy2ID = dbPol2.ID
×
4231
        }
×
4232
        rows, err := db.GetChannelPolicyExtraTypes(
×
4233
                ctx, sqlc.GetChannelPolicyExtraTypesParams{
×
4234
                        ID:   policy1ID,
×
4235
                        ID_2: policy2ID,
×
4236
                },
×
4237
        )
×
4238
        if err != nil {
×
4239
                return nil, nil, err
×
4240
        }
×
4241

4242
        var (
×
4243
                dbPol1Extras = make(map[uint64][]byte)
×
4244
                dbPol2Extras = make(map[uint64][]byte)
×
4245
        )
×
4246
        for _, row := range rows {
×
4247
                switch row.PolicyID {
×
4248
                case policy1ID:
×
4249
                        dbPol1Extras[uint64(row.Type)] = row.Value
×
4250
                case policy2ID:
×
4251
                        dbPol2Extras[uint64(row.Type)] = row.Value
×
4252
                default:
×
4253
                        return nil, nil, fmt.Errorf("unexpected policy ID %d "+
×
4254
                                "in row: %v", row.PolicyID, row)
×
4255
                }
4256
        }
4257

4258
        var pol1, pol2 *models.ChannelEdgePolicy
×
4259
        if dbPol1 != nil {
×
4260
                pol1, err = buildChanPolicy(
×
4261
                        *dbPol1, channelID, dbPol1Extras, node2,
×
4262
                )
×
4263
                if err != nil {
×
4264
                        return nil, nil, err
×
4265
                }
×
4266
        }
4267
        if dbPol2 != nil {
×
4268
                pol2, err = buildChanPolicy(
×
4269
                        *dbPol2, channelID, dbPol2Extras, node1,
×
4270
                )
×
4271
                if err != nil {
×
4272
                        return nil, nil, err
×
4273
                }
×
4274
        }
4275

4276
        return pol1, pol2, nil
×
4277
}
4278

4279
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4280
// provided sqlc.GraphChannelPolicy and other required information.
4281
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4282
        extras map[uint64][]byte,
4283
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4284

×
4285
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4286
        if err != nil {
×
4287
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4288
                        "fields: %w", err)
×
4289
        }
×
4290

4291
        var inboundFee fn.Option[lnwire.Fee]
×
4292
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4293
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4294

×
4295
                inboundFee = fn.Some(lnwire.Fee{
×
4296
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4297
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4298
                })
×
4299
        }
×
4300

4301
        return &models.ChannelEdgePolicy{
×
4302
                SigBytes:  dbPolicy.Signature,
×
4303
                ChannelID: channelID,
×
4304
                LastUpdate: time.Unix(
×
4305
                        dbPolicy.LastUpdate.Int64, 0,
×
4306
                ),
×
4307
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4308
                        dbPolicy.MessageFlags,
×
4309
                ),
×
4310
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4311
                        dbPolicy.ChannelFlags,
×
4312
                ),
×
4313
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4314
                MinHTLC: lnwire.MilliSatoshi(
×
4315
                        dbPolicy.MinHtlcMsat,
×
4316
                ),
×
4317
                MaxHTLC: lnwire.MilliSatoshi(
×
4318
                        dbPolicy.MaxHtlcMsat.Int64,
×
4319
                ),
×
4320
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4321
                        dbPolicy.BaseFeeMsat,
×
4322
                ),
×
4323
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4324
                ToNode:                    toNode,
×
4325
                InboundFee:                inboundFee,
×
4326
                ExtraOpaqueData:           recs,
×
4327
        }, nil
×
4328
}
4329

4330
// buildNodes builds the models.LightningNode instances for the
4331
// given row which is expected to be a sqlc type that contains node information.
4332
func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
4333
        dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode,
4334
        error) {
×
4335

×
4336
        node1, err := buildNode(ctx, db, &dbNode1)
×
4337
        if err != nil {
×
4338
                return nil, nil, err
×
4339
        }
×
4340

4341
        node2, err := buildNode(ctx, db, &dbNode2)
×
4342
        if err != nil {
×
4343
                return nil, nil, err
×
4344
        }
×
4345

4346
        return node1, node2, nil
×
4347
}
4348

4349
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
// given row. The row is expected to be one of the sqlc row types handled by
// the type switch below, all of which contain channel policy information. It
// returns two policies, either of which is nil if the corresponding policy
// information is not present in the row. Any other row type results in an
// error.
//
// For ListChannelsWithPoliciesForCachePaginatedRow:
//   - a policy is treated as present when its timelock column is non-NULL;
//   - only a subset of the policy fields is copied.
//
// For every other supported row type:
//   - a policy is treated as present when its ID column is non-NULL;
//   - all policy fields are copied, with ChannelID taken from the row's
//     GraphChannel.
//
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
	*sqlc.GraphChannelPolicy, error) {

	var policy1, policy2 *sqlc.GraphChannelPolicy
	switch r := row.(type) {
	case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
		// Only a subset of the policy columns is copied for this row
		// type. ID, Version, ChannelID, NodeID, LastUpdate and Signature
		// are left at their zero values. Presence is detected via the
		// timelock column rather than the policy ID.
		if r.Policy1Timelock.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
			}
		}
		if r.Policy2Timelock.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
			}
		}

		return policy1, policy2, nil

	// Each of the remaining row types copies the full set of policy columns
	// for both directions. A non-NULL policy ID marks a present policy.
	case sqlc.GetChannelsBySCIDWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelByOutpointWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelBySCIDWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsByNodeIDRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsWithPoliciesPaginatedRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil
	default:
		// Any row type not listed above is a programming error in the
		// caller.
		return nil, nil, fmt.Errorf("unexpected row type in "+
			"extractChannelPolicies: %T", r)
	}
}
4660

4661
// channelIDToBytes converts a channel ID (SCID) to a byte array
4662
// representation.
4663
func channelIDToBytes(channelID uint64) []byte {
×
4664
        var chanIDB [8]byte
×
4665
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4666

×
4667
        return chanIDB[:]
×
4668
}
×
4669

4670
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
NEW
4671
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
NEW
4672
        if len(addresses) == 0 {
×
NEW
4673
                return nil, nil
×
NEW
4674
        }
×
4675

NEW
4676
        result := make([]net.Addr, 0, len(addresses))
×
NEW
4677
        for _, addr := range addresses {
×
NEW
4678
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
NEW
4679
                if err != nil {
×
NEW
4680
                        return nil, fmt.Errorf("unable to parse address %s "+
×
NEW
4681
                                "of type %d: %w", addr.address, addr.addrType,
×
NEW
4682
                                err)
×
NEW
4683
                }
×
NEW
4684
                if netAddr != nil {
×
NEW
4685
                        result = append(result, netAddr)
×
NEW
4686
                }
×
4687
        }
4688

4689
        // If we have no valid addresses, return nil instead of empty slice.
NEW
4690
        if len(result) == 0 {
×
NEW
4691
                return nil, nil
×
NEW
4692
        }
×
4693

NEW
4694
        return result, nil
×
4695
}
4696

4697
// parseAddress parses the given address string based on the address type
4698
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4699
// and opaque addresses.
NEW
4700
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
NEW
4701
        switch addrType {
×
NEW
4702
        case addressTypeIPv4:
×
NEW
4703
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
NEW
4704
                if err != nil {
×
NEW
4705
                        return nil, err
×
NEW
4706
                }
×
4707

NEW
4708
                tcp.IP = tcp.IP.To4()
×
NEW
4709

×
NEW
4710
                return tcp, nil
×
4711

NEW
4712
        case addressTypeIPv6:
×
NEW
4713
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
NEW
4714
                if err != nil {
×
NEW
4715
                        return nil, err
×
NEW
4716
                }
×
4717

NEW
4718
                return tcp, nil
×
4719

NEW
4720
        case addressTypeTorV3, addressTypeTorV2:
×
NEW
4721
                service, portStr, err := net.SplitHostPort(address)
×
NEW
4722
                if err != nil {
×
NEW
4723
                        return nil, fmt.Errorf("unable to split tor "+
×
NEW
4724
                                "address: %v", address)
×
NEW
4725
                }
×
4726

NEW
4727
                port, err := strconv.Atoi(portStr)
×
NEW
4728
                if err != nil {
×
NEW
4729
                        return nil, err
×
NEW
4730
                }
×
4731

NEW
4732
                return &tor.OnionAddr{
×
NEW
4733
                        OnionService: service,
×
NEW
4734
                        Port:         port,
×
NEW
4735
                }, nil
×
4736

NEW
4737
        case addressTypeOpaque:
×
NEW
4738
                opaque, err := hex.DecodeString(address)
×
NEW
4739
                if err != nil {
×
NEW
4740
                        return nil, fmt.Errorf("unable to decode opaque "+
×
NEW
4741
                                "address: %v", address)
×
NEW
4742
                }
×
4743

NEW
4744
                return &lnwire.OpaqueAddrs{
×
NEW
4745
                        Payload: opaque,
×
NEW
4746
                }, nil
×
4747

NEW
4748
        default:
×
NEW
4749
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4750
        }
4751
}
4752

4753
// batchNodeData holds all the related data for a batch of nodes.
4754
type batchNodeData struct {
4755
        // features is a map from a DB node ID to the feature bits for that
4756
        // node.
4757
        features map[int64][]int
4758

4759
        // addresses is a map from a DB node ID to the node's addresses.
4760
        addresses map[int64][]nodeAddress
4761

4762
        // extraFields is a map from a DB node ID to the extra signed fields
4763
        // for that node.
4764
        extraFields map[int64]map[uint64][]byte
4765
}
4766

4767
// nodeAddress holds the address type, position and address string for a
4768
// node. This is used to batch the fetching of node addresses.
4769
type nodeAddress struct {
4770
        addrType dbAddressType
4771
        position int32
4772
        address  string
4773
}
4774

4775
// batchLoadNodeData loads all related data for a batch of node IDs using the
4776
// provided SQLQueries interface. It returns a batchNodeData instance containing
4777
// the node features, addresses and extra signed fields.
4778
func batchLoadNodeData(ctx context.Context, cfg *sqldb.PagedQueryConfig,
NEW
4779
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
NEW
4780

×
NEW
4781
        // Batch load the node features.
×
NEW
4782
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
NEW
4783
        if err != nil {
×
NEW
4784
                return nil, fmt.Errorf("unable to batch load node "+
×
NEW
4785
                        "features: %w", err)
×
NEW
4786
        }
×
4787

4788
        // Batch load the node addresses.
NEW
4789
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
NEW
4790
        if err != nil {
×
NEW
4791
                return nil, fmt.Errorf("unable to batch load node "+
×
NEW
4792
                        "addresses: %w", err)
×
NEW
4793
        }
×
4794

4795
        // Batch load the node extra signed fields.
NEW
4796
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
NEW
4797
        if err != nil {
×
NEW
4798
                return nil, fmt.Errorf("unable to batch load node extra "+
×
NEW
4799
                        "signed fields: %w", err)
×
NEW
4800
        }
×
4801

NEW
4802
        return &batchNodeData{
×
NEW
4803
                features:    features,
×
NEW
4804
                addresses:   addrs,
×
NEW
4805
                extraFields: extraTypes,
×
NEW
4806
        }, nil
×
4807
}
4808

4809
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
4810
// using ExecutePagedQuery wrapper around the GetNodeFeaturesBatch query.
4811
func batchLoadNodeFeaturesHelper(ctx context.Context,
4812
        cfg *sqldb.PagedQueryConfig, db SQLQueries,
NEW
4813
        nodeIDs []int64) (map[int64][]int, error) {
×
NEW
4814

×
NEW
4815
        features := make(map[int64][]int)
×
NEW
4816

×
NEW
4817
        return features, sqldb.ExecutePagedQuery(
×
NEW
4818
                ctx, cfg, nodeIDs,
×
NEW
4819
                func(id int64) int64 {
×
NEW
4820
                        return id
×
NEW
4821
                },
×
4822
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
NEW
4823
                        error) {
×
NEW
4824

×
NEW
4825
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
NEW
4826
                },
×
NEW
4827
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
NEW
4828
                        features[feature.NodeID] = append(
×
NEW
4829
                                features[feature.NodeID],
×
NEW
4830
                                int(feature.FeatureBit),
×
NEW
4831
                        )
×
NEW
4832

×
NEW
4833
                        return nil
×
NEW
4834
                },
×
4835
        )
4836
}
4837

4838
// batchLoadNodeAddressesHelper loads node addresses using ExecutePagedQuery
4839
// wrapper around the GetNodeAddressesBatch query. It returns a map from
4840
// node ID to a slice of nodeAddress structs.
4841
func batchLoadNodeAddressesHelper(ctx context.Context,
4842
        cfg *sqldb.PagedQueryConfig, db SQLQueries,
NEW
4843
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
NEW
4844

×
NEW
4845
        addrs := make(map[int64][]nodeAddress)
×
NEW
4846

×
NEW
4847
        return addrs, sqldb.ExecutePagedQuery(
×
NEW
4848
                ctx, cfg, nodeIDs,
×
NEW
4849
                func(id int64) int64 {
×
NEW
4850
                        return id
×
NEW
4851
                },
×
4852
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
NEW
4853
                        error) {
×
NEW
4854

×
NEW
4855
                        return db.GetNodeAddressesBatch(ctx, ids)
×
NEW
4856
                },
×
NEW
4857
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
NEW
4858
                        addrs[addr.NodeID] = append(
×
NEW
4859
                                addrs[addr.NodeID], nodeAddress{
×
NEW
4860
                                        addrType: dbAddressType(addr.Type),
×
NEW
4861
                                        position: addr.Position,
×
NEW
4862
                                        address:  addr.Address,
×
NEW
4863
                                },
×
NEW
4864
                        )
×
NEW
4865

×
NEW
4866
                        return nil
×
NEW
4867
                },
×
4868
        )
4869
}
4870

4871
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
4872
// node IDs using ExecutePagedQuery wrapper around the GetNodeExtraTypesBatch
4873
// query.
4874
func batchLoadNodeExtraTypesHelper(ctx context.Context,
4875
        cfg *sqldb.PagedQueryConfig, db SQLQueries,
NEW
4876
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
NEW
4877

×
NEW
4878
        extraFields := make(map[int64]map[uint64][]byte)
×
NEW
4879

×
NEW
4880
        callback := func(ctx context.Context,
×
NEW
4881
                field sqlc.GraphNodeExtraType) error {
×
NEW
4882

×
NEW
4883
                if extraFields[field.NodeID] == nil {
×
NEW
4884
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
NEW
4885
                }
×
NEW
4886
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
NEW
4887

×
NEW
4888
                return nil
×
4889
        }
4890

NEW
4891
        return extraFields, sqldb.ExecutePagedQuery(
×
NEW
4892
                ctx, cfg, nodeIDs,
×
NEW
4893
                func(id int64) int64 {
×
NEW
4894
                        return id
×
NEW
4895
                },
×
4896
                func(ctx context.Context, ids []int64) (
NEW
4897
                        []sqlc.GraphNodeExtraType, error) {
×
NEW
4898

×
NEW
4899
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
NEW
4900
                },
×
4901
                callback,
4902
        )
4903
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc